diff --git a/rust/crates/api/src/client.rs b/rust/crates/api/src/client.rs index f90eaf8..c7aca3f 100644 --- a/rust/crates/api/src/client.rs +++ b/rust/crates/api/src/client.rs @@ -1,4 +1,5 @@ use std::collections::VecDeque; +use std::sync::{Arc, Mutex}; use std::time::{Duration, SystemTime, UNIX_EPOCH}; use runtime::{ @@ -8,7 +9,7 @@ use runtime::{ use serde::Deserialize; use crate::error::ApiError; -use crate::prompt_cache::{PromptCache, PromptCacheStats}; +use crate::prompt_cache::{PromptCache, PromptCacheRecord, PromptCacheStats}; use crate::sse::SseParser; use crate::types::{MessageRequest, MessageResponse, StreamEvent, Usage}; @@ -110,6 +111,7 @@ pub struct AnthropicClient { initial_backoff: Duration, max_backoff: Duration, prompt_cache: Option, + last_prompt_cache_record: Arc>>, } impl AnthropicClient { @@ -123,6 +125,7 @@ impl AnthropicClient { initial_backoff: DEFAULT_INITIAL_BACKOFF, max_backoff: DEFAULT_MAX_BACKOFF, prompt_cache: None, + last_prompt_cache_record: Arc::new(Mutex::new(None)), } } @@ -136,6 +139,7 @@ impl AnthropicClient { initial_backoff: DEFAULT_INITIAL_BACKOFF, max_backoff: DEFAULT_MAX_BACKOFF, prompt_cache: None, + last_prompt_cache_record: Arc::new(Mutex::new(None)), } } @@ -209,6 +213,14 @@ impl AnthropicClient { self.prompt_cache.as_ref().map(PromptCache::stats) } + #[must_use] + pub fn take_last_prompt_cache_record(&self) -> Option { + self.last_prompt_cache_record() + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner) + .take() + } + #[must_use] pub fn auth_source(&self) -> &AuthSource { &self.auth @@ -218,12 +230,16 @@ impl AnthropicClient { &self, request: &MessageRequest, ) -> Result { + self.store_last_prompt_cache_record(None); let request = MessageRequest { stream: false, ..request.clone() }; if let Some(prompt_cache) = &self.prompt_cache { if let Some(response) = prompt_cache.lookup_completion(&request) { + self.store_last_prompt_cache_record(Some(prompt_cache_record_from_stats( + prompt_cache.stats(), + ))); return Ok(response); } } @@ -237,7 +253,8 @@ impl AnthropicClient { response.request_id = request_id; } if let Some(prompt_cache) = &self.prompt_cache { - let _ = prompt_cache.record_response(&request, &response); + let record = prompt_cache.record_response(&request, &response); + self.store_last_prompt_cache_record(Some(record)); } Ok(response) } @@ -246,6 +263,7 @@ impl AnthropicClient { &self, request: &MessageRequest, ) -> Result { + self.store_last_prompt_cache_record(None); let response = self .send_with_retry(&request.clone().with_streaming()) .await?; @@ -263,10 +281,22 @@ impl AnthropicClient { request: request.clone().with_streaming(), last_usage: None, finalized: false, + last_record: self.last_prompt_cache_record.clone(), }), }) } + fn store_last_prompt_cache_record(&self, record: Option) { + *self + .last_prompt_cache_record() + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner) = record; + } + + fn last_prompt_cache_record(&self) -> &Arc>> { + &self.last_prompt_cache_record + } + pub async fn exchange_oauth_code( &self, config: &OAuthConfig, @@ -615,6 +645,7 @@ struct StreamCacheTracking { request: MessageRequest, last_usage: Option, finalized: bool, + last_record: Arc>>, } impl StreamCacheTracking { @@ -638,12 +669,23 @@ impl StreamCacheTracking { return; } if let Some(usage) = &self.last_usage { - let _ = self.prompt_cache.record_usage(&self.request, usage); + let record = self.prompt_cache.record_usage(&self.request, usage); + *self + .last_record + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner) = Some(record); } self.finalized = true; } } +fn prompt_cache_record_from_stats(stats: PromptCacheStats) -> PromptCacheRecord { + PromptCacheRecord { + cache_break: None, + stats, + } +} + async fn expect_success(response: reqwest::Response) -> Result { let status = response.status(); if status.is_success() { diff --git a/rust/crates/runtime/src/conversation.rs b/rust/crates/runtime/src/conversation.rs index 4ffbabc..00dbf54 100644 --- a/rust/crates/runtime/src/conversation.rs +++ b/rust/crates/runtime/src/conversation.rs @@ -25,9 +25,19 @@ pub enum AssistantEvent { input: String, }, Usage(TokenUsage), + PromptCache(PromptCacheEvent), MessageStop, } +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct PromptCacheEvent { + pub unexpected: bool, + pub reason: String, + pub previous_cache_read_input_tokens: u32, + pub current_cache_read_input_tokens: u32, + pub token_drop: u32, +} + pub trait ApiClient { fn stream(&mut self, request: ApiRequest) -> Result, RuntimeError>; } @@ -84,6 +94,7 @@ impl std::error::Error for RuntimeError {} pub struct TurnSummary { pub assistant_messages: Vec, pub tool_results: Vec, + pub prompt_cache_events: Vec, pub iterations: usize, pub usage: TokenUsage, } @@ -118,7 +129,7 @@ where tool_executor, permission_policy, system_prompt, - RuntimeFeatureConfig::default(), + &RuntimeFeatureConfig::default(), ) } @@ -129,7 +140,7 @@ where tool_executor: T, permission_policy: PermissionPolicy, system_prompt: Vec, - feature_config: RuntimeFeatureConfig, + feature_config: &RuntimeFeatureConfig, ) -> Self { let usage_tracker = UsageTracker::from_session(&session); Self { @@ -140,7 +151,7 @@ where system_prompt, max_iterations: usize::MAX, usage_tracker, - hook_runner: HookRunner::from_feature_config(&feature_config), + hook_runner: HookRunner::from_feature_config(feature_config), } } @@ -161,6 +172,7 @@ where let mut assistant_messages = Vec::new(); let mut tool_results = Vec::new(); + let mut prompt_cache_events = Vec::new(); let mut iterations = 0; loop { @@ -176,10 +188,12 @@ where messages: self.session.messages.clone(), }; let events = self.api_client.stream(request)?; - let (assistant_message, usage) = build_assistant_message(events)?; + let (assistant_message, usage, turn_prompt_cache_events) = + build_assistant_message(events)?; if let Some(usage) = usage { self.usage_tracker.record(usage); } + prompt_cache_events.extend(turn_prompt_cache_events); let pending_tool_uses = assistant_message .blocks .iter() @@ -257,6 +271,7 @@ where Ok(TurnSummary { assistant_messages, tool_results, + prompt_cache_events, iterations, usage: self.usage_tracker.cumulative_usage(), }) @@ -290,9 +305,17 @@ where fn build_assistant_message( events: Vec, -) -> Result<(ConversationMessage, Option), RuntimeError> { +) -> Result< + ( + ConversationMessage, + Option, + Vec, + ), + RuntimeError, +> { let mut text = String::new(); let mut blocks = Vec::new(); + let mut prompt_cache_events = Vec::new(); let mut finished = false; let mut usage = None; @@ -304,6 +327,7 @@ fn build_assistant_message( blocks.push(ContentBlock::ToolUse { id, name, input }); } AssistantEvent::Usage(value) => usage = Some(value), + AssistantEvent::PromptCache(event) => prompt_cache_events.push(event), AssistantEvent::MessageStop => { finished = true; } @@ -324,6 +348,7 @@ fn build_assistant_message( Ok(( ConversationMessage::assistant_with_usage(blocks, usage), usage, + prompt_cache_events, )) } @@ -396,7 +421,7 @@ impl ToolExecutor for StaticToolExecutor { #[cfg(test)] mod tests { use super::{ - ApiClient, ApiRequest, AssistantEvent, ConversationRuntime, RuntimeError, + ApiClient, ApiRequest, AssistantEvent, ConversationRuntime, PromptCacheEvent, RuntimeError, StaticToolExecutor, }; use crate::compact::CompactionConfig; @@ -453,6 +478,15 @@ mod tests { cache_creation_input_tokens: 1, cache_read_input_tokens: 3, }), + AssistantEvent::PromptCache(PromptCacheEvent { + unexpected: true, + reason: + "cache read tokens dropped while prompt fingerprint remained stable" + .to_string(), + previous_cache_read_input_tokens: 6_000, + current_cache_read_input_tokens: 1_000, + token_drop: 5_000, + }), AssistantEvent::MessageStop, ]) } @@ -506,8 +540,10 @@ mod tests { assert_eq!(summary.iterations, 2); assert_eq!(summary.assistant_messages.len(), 2); assert_eq!(summary.tool_results.len(), 1); + assert_eq!(summary.prompt_cache_events.len(), 1); assert_eq!(runtime.session().messages.len(), 4); assert_eq!(summary.usage.output_tokens, 10); + assert!(summary.prompt_cache_events[0].unexpected); assert!(matches!( runtime.session().messages[1].blocks[1], ContentBlock::ToolUse { .. } @@ -609,7 +645,7 @@ mod tests { }), PermissionPolicy::new(PermissionMode::DangerFullAccess), vec!["system".to_string()], - RuntimeFeatureConfig::default().with_hooks(RuntimeHookConfig::new( + &RuntimeFeatureConfig::default().with_hooks(RuntimeHookConfig::new( vec![shell_snippet("printf 'blocked by hook'; exit 2")], Vec::new(), )), @@ -675,7 +711,7 @@ mod tests { StaticToolExecutor::new().register("add", |_input| Ok("4".to_string())), PermissionPolicy::new(PermissionMode::DangerFullAccess), vec!["system".to_string()], - RuntimeFeatureConfig::default().with_hooks(RuntimeHookConfig::new( + &RuntimeFeatureConfig::default().with_hooks(RuntimeHookConfig::new( vec![shell_snippet("printf 'pre hook ran'")], vec![shell_snippet("printf 'post hook ran'")], )), @@ -697,7 +733,7 @@ mod tests { "post hook should preserve non-error result: {output:?}" ); assert!( - output.contains("4"), + output.contains('4'), "tool output missing value: {output:?}" ); assert!( diff --git a/rust/crates/runtime/src/hooks.rs b/rust/crates/runtime/src/hooks.rs index 36756a0..80770ba 100644 --- a/rust/crates/runtime/src/hooks.rs +++ b/rust/crates/runtime/src/hooks.rs @@ -64,7 +64,7 @@ impl HookRunner { #[must_use] pub fn run_pre_tool_use(&self, tool_name: &str, tool_input: &str) -> HookRunResult { - self.run_commands( + Self::run_commands( HookEvent::PreToolUse, self.config.pre_tool_use(), tool_name, @@ -82,7 +82,7 @@ impl HookRunner { tool_output: &str, is_error: bool, ) -> HookRunResult { - self.run_commands( + Self::run_commands( HookEvent::PostToolUse, self.config.post_tool_use(), tool_name, @@ -93,7 +93,6 @@ impl HookRunner { } fn run_commands( - &self, event: HookEvent, commands: &[String], tool_name: &str, @@ -118,7 +117,7 @@ impl HookRunner { let mut messages = Vec::new(); for command in commands { - match self.run_command( + match Self::run_command( command, event, tool_name, @@ -150,7 +149,6 @@ impl HookRunner { } fn run_command( - &self, command: &str, event: HookEvent, tool_name: &str, diff --git a/rust/crates/runtime/src/lib.rs b/rust/crates/runtime/src/lib.rs index da745e5..856f9f5 100644 --- a/rust/crates/runtime/src/lib.rs +++ b/rust/crates/runtime/src/lib.rs @@ -31,8 +31,8 @@ pub use config::{ ScopedMcpServerConfig, CLAUDE_CODE_SETTINGS_SCHEMA_NAME, }; pub use conversation::{ - ApiClient, ApiRequest, AssistantEvent, ConversationRuntime, RuntimeError, StaticToolExecutor, - ToolError, ToolExecutor, TurnSummary, + ApiClient, ApiRequest, AssistantEvent, ConversationRuntime, PromptCacheEvent, RuntimeError, + StaticToolExecutor, ToolError, ToolExecutor, TurnSummary, }; pub use file_ops::{ edit_file, glob_search, grep_search, read_file, write_file, EditFileOutput, GlobSearchOutput, diff --git a/rust/crates/rusty-claude-cli/src/main.rs b/rust/crates/rusty-claude-cli/src/main.rs index 5f8a7a6..dcce2b0 100644 --- a/rust/crates/rusty-claude-cli/src/main.rs +++ b/rust/crates/rusty-claude-cli/src/main.rs @@ -13,8 +13,9 @@ use std::time::{SystemTime, UNIX_EPOCH}; use api::{ resolve_startup_auth_source, AnthropicClient, AuthSource, ContentBlockDelta, InputContentBlock, - InputMessage, MessageRequest, MessageResponse, OutputContentBlock, - StreamEvent as ApiStreamEvent, ToolChoice, ToolDefinition, ToolResultContentBlock, + InputMessage, MessageRequest, MessageResponse, OutputContentBlock, PromptCache, + PromptCacheRecord, StreamEvent as ApiStreamEvent, ToolChoice, ToolDefinition, + ToolResultContentBlock, }; use commands::{ @@ -28,8 +29,8 @@ use runtime::{ parse_oauth_callback_request_target, save_oauth_credentials, ApiClient, ApiRequest, AssistantEvent, CompactionConfig, ConfigLoader, ConfigSource, ContentBlock, ConversationMessage, ConversationRuntime, MessageRole, OAuthAuthorizationRequest, OAuthConfig, - OAuthTokenExchangeRequest, PermissionMode, PermissionPolicy, ProjectContext, RuntimeError, - Session, TokenUsage, ToolError, ToolExecutor, UsageTracker, + OAuthTokenExchangeRequest, PermissionMode, PermissionPolicy, ProjectContext, PromptCacheEvent, + RuntimeError, Session, TokenUsage, ToolError, ToolExecutor, UsageTracker, }; use serde_json::json; use tools::{execute_tool, mvp_tool_specs, ToolSpec}; @@ -995,6 +996,7 @@ impl LiveCli { let session = create_managed_session_handle()?; let runtime = build_runtime( Session::new(), + session.id.clone(), model.clone(), system_prompt.clone(), enable_tools, @@ -1050,13 +1052,14 @@ impl LiveCli { let mut permission_prompter = CliPermissionPrompter::new(self.permission_mode); let result = self.runtime.run_turn(input, Some(&mut permission_prompter)); match result { - Ok(_) => { + Ok(summary) => { spinner.finish( "✨ Done", TerminalRenderer::new().color_theme(), &mut stdout, )?; println!(); + print_prompt_cache_events(&summary); self.persist_session()?; Ok(()) } @@ -1086,6 +1089,7 @@ impl LiveCli { let session = self.runtime.session().clone(); let mut runtime = build_runtime( session, + self.session.id.clone(), self.model.clone(), self.system_prompt.clone(), true, @@ -1105,6 +1109,7 @@ impl LiveCli { "iterations": summary.iterations, "tool_uses": collect_tool_uses(&summary), "tool_results": collect_tool_results(&summary), + "prompt_cache_events": collect_prompt_cache_events(&summary), "usage": { "input_tokens": summary.usage.input_tokens, "output_tokens": summary.usage.output_tokens, @@ -1232,6 +1237,7 @@ impl LiveCli { let message_count = session.messages.len(); self.runtime = build_runtime( session, + self.session.id.clone(), model.clone(), self.system_prompt.clone(), true, @@ -1275,6 +1281,7 @@ impl LiveCli { self.permission_mode = permission_mode_from_label(normalized); self.runtime = build_runtime( session, + self.session.id.clone(), self.model.clone(), self.system_prompt.clone(), true, @@ -1300,6 +1307,7 @@ impl LiveCli { self.session = create_managed_session_handle()?; self.runtime = build_runtime( Session::new(), + self.session.id.clone(), self.model.clone(), self.system_prompt.clone(), true, @@ -1335,6 +1343,7 @@ impl LiveCli { let message_count = session.messages.len(); self.runtime = build_runtime( session, + handle.id.clone(), self.model.clone(), self.system_prompt.clone(), true, @@ -1407,6 +1416,7 @@ impl LiveCli { let message_count = session.messages.len(); self.runtime = build_runtime( session, + handle.id.clone(), self.model.clone(), self.system_prompt.clone(), true, @@ -1437,6 +1447,7 @@ impl LiveCli { let skipped = removed == 0; self.runtime = build_runtime( result.compacted_session, + self.session.id.clone(), self.model.clone(), self.system_prompt.clone(), true, @@ -1912,8 +1923,10 @@ fn build_runtime_feature_config( .clone()) } +#[allow(clippy::too_many_arguments)] fn build_runtime( session: Session, + session_id: String, model: String, system_prompt: Vec, enable_tools: bool, @@ -1924,11 +1937,17 @@ fn build_runtime( { Ok(ConversationRuntime::new_with_features( session, - AnthropicRuntimeClient::new(model, enable_tools, emit_output, allowed_tools.clone())?, + AnthropicRuntimeClient::new( + model, + enable_tools, + emit_output, + allowed_tools.clone(), + session_id, + )?, CliToolExecutor::new(allowed_tools, emit_output), permission_policy(permission_mode), system_prompt, - build_runtime_feature_config()?, + &build_runtime_feature_config()?, )) } @@ -1993,11 +2012,13 @@ impl AnthropicRuntimeClient { enable_tools: bool, emit_output: bool, allowed_tools: Option, + session_id: impl Into, ) -> Result> { Ok(Self { runtime: tokio::runtime::Runtime::new()?, client: AnthropicClient::from_auth(resolve_cli_auth_source()?) - .with_base_url(api::read_base_url()), + .with_base_url(api::read_base_url()) + .with_prompt_cache(PromptCache::new(session_id)), model, enable_tools, emit_output, @@ -2112,8 +2133,8 @@ impl ApiClient for AnthropicRuntimeClient { events.push(AssistantEvent::Usage(TokenUsage { input_tokens: delta.usage.input_tokens, output_tokens: delta.usage.output_tokens, - cache_creation_input_tokens: 0, - cache_read_input_tokens: 0, + cache_creation_input_tokens: delta.usage.cache_creation_input_tokens, + cache_read_input_tokens: delta.usage.cache_read_input_tokens, })); } ApiStreamEvent::MessageStop(_) => { @@ -2128,6 +2149,8 @@ impl ApiClient for AnthropicRuntimeClient { } } + push_prompt_cache_record(&self.client, &mut events); + if !saw_stop && events.iter().any(|event| { matches!(event, AssistantEvent::TextDelta(text) if !text.is_empty()) @@ -2152,7 +2175,9 @@ impl ApiClient for AnthropicRuntimeClient { }) .await .map_err(|error| RuntimeError::new(error.to_string()))?; - response_to_events(response, out) + let mut events = response_to_events(response, out)?; + push_prompt_cache_record(&self.client, &mut events); + Ok(events) }) } } @@ -2213,6 +2238,39 @@ fn collect_tool_results(summary: &runtime::TurnSummary) -> Vec Vec { + summary + .prompt_cache_events + .iter() + .map(|event| { + json!({ + "unexpected": event.unexpected, + "reason": event.reason, + "previous_cache_read_input_tokens": event.previous_cache_read_input_tokens, + "current_cache_read_input_tokens": event.current_cache_read_input_tokens, + "token_drop": event.token_drop, + }) + }) + .collect() +} + +fn print_prompt_cache_events(summary: &runtime::TurnSummary) { + for event in &summary.prompt_cache_events { + let label = if event.unexpected { + "Prompt cache break" + } else { + "Prompt cache invalidation" + }; + println!( + "{label}: {} (cache read {} -> {}, drop {})", + event.reason, + event.previous_cache_read_input_tokens, + event.current_cache_read_input_tokens, + event.token_drop, + ); + } +} + fn slash_command_completion_candidates() -> Vec { slash_command_specs() .iter() @@ -2359,18 +2417,20 @@ fn first_visible_line(text: &str) -> &str { } fn format_bash_result(icon: &str, parsed: &serde_json::Value) -> String { + use std::fmt::Write as _; + let mut lines = vec![format!("{icon} \x1b[38;5;245mbash\x1b[0m")]; if let Some(task_id) = parsed .get("backgroundTaskId") .and_then(|value| value.as_str()) { - lines[0].push_str(&format!(" backgrounded ({task_id})")); + let _ = write!(lines[0], " backgrounded ({task_id})"); } else if let Some(status) = parsed .get("returnCodeInterpretation") .and_then(|value| value.as_str()) .filter(|status| !status.is_empty()) { - lines[0].push_str(&format!(" {status}")); + let _ = write!(lines[0], " {status}"); } if let Some(stdout) = parsed.get("stdout").and_then(|value| value.as_str()) { @@ -2392,15 +2452,15 @@ fn format_read_result(icon: &str, parsed: &serde_json::Value) -> String { let path = extract_tool_path(file); let start_line = file .get("startLine") - .and_then(|value| value.as_u64()) + .and_then(serde_json::Value::as_u64) .unwrap_or(1); let num_lines = file .get("numLines") - .and_then(|value| value.as_u64()) + .and_then(serde_json::Value::as_u64) .unwrap_or(0); let total_lines = file .get("totalLines") - .and_then(|value| value.as_u64()) + .and_then(serde_json::Value::as_u64) .unwrap_or(num_lines); let content = file .get("content") @@ -2426,8 +2486,7 @@ fn format_write_result(icon: &str, parsed: &serde_json::Value) -> String { let line_count = parsed .get("content") .and_then(|value| value.as_str()) - .map(|content| content.lines().count()) - .unwrap_or(0); + .map_or(0, |content| content.lines().count()); format!( "{icon} \x1b[1;32m✏️ {} {path}\x1b[0m \x1b[2m({line_count} lines)\x1b[0m", if kind == "create" { "Wrote" } else { "Updated" }, @@ -2458,7 +2517,7 @@ fn format_edit_result(icon: &str, parsed: &serde_json::Value) -> String { let path = extract_tool_path(parsed); let suffix = if parsed .get("replaceAll") - .and_then(|value| value.as_bool()) + .and_then(serde_json::Value::as_bool) .unwrap_or(false) { " (replace all)" @@ -2486,7 +2545,7 @@ fn format_edit_result(icon: &str, parsed: &serde_json::Value) -> String { fn format_glob_result(icon: &str, parsed: &serde_json::Value) -> String { let num_files = parsed .get("numFiles") - .and_then(|value| value.as_u64()) + .and_then(serde_json::Value::as_u64) .unwrap_or(0); let filenames = parsed .get("filenames") @@ -2510,11 +2569,11 @@ fn format_glob_result(icon: &str, parsed: &serde_json::Value) -> String { fn format_grep_result(icon: &str, parsed: &serde_json::Value) -> String { let num_matches = parsed .get("numMatches") - .and_then(|value| value.as_u64()) + .and_then(serde_json::Value::as_u64) .unwrap_or(0); let num_files = parsed .get("numFiles") - .and_then(|value| value.as_u64()) + .and_then(serde_json::Value::as_u64) .unwrap_or(0); let content = parsed .get("content") @@ -2621,6 +2680,26 @@ fn response_to_events( Ok(events) } +fn push_prompt_cache_record(client: &AnthropicClient, events: &mut Vec) { + if let Some(event) = client + .take_last_prompt_cache_record() + .and_then(prompt_cache_record_to_runtime_event) + { + events.push(AssistantEvent::PromptCache(event)); + } +} + +fn prompt_cache_record_to_runtime_event(record: PromptCacheRecord) -> Option { + let cache_break = record.cache_break?; + Some(PromptCacheEvent { + unexpected: cache_break.unexpected, + reason: cache_break.reason, + previous_cache_read_input_tokens: cache_break.previous_cache_read_input_tokens, + current_cache_read_input_tokens: cache_break.current_cache_read_input_tokens, + token_drop: cache_break.token_drop, + }) +} + struct CliToolExecutor { renderer: TerminalRenderer, emit_output: bool, diff --git a/rust/crates/rusty-claude-cli/src/render.rs b/rust/crates/rusty-claude-cli/src/render.rs index 465c5a4..d8d8796 100644 --- a/rust/crates/rusty-claude-cli/src/render.rs +++ b/rust/crates/rusty-claude-cli/src/render.rs @@ -286,7 +286,7 @@ impl TerminalRenderer { ) { match event { Event::Start(Tag::Heading { level, .. }) => { - self.start_heading(state, level as u8, output) + Self::start_heading(state, level as u8, output); } Event::End(TagEnd::Paragraph) => output.push_str("\n\n"), Event::Start(Tag::BlockQuote(..)) => self.start_quote(state, output), @@ -426,7 +426,7 @@ impl TerminalRenderer { } } - fn start_heading(&self, state: &mut RenderState, level: u8, output: &mut String) { + fn start_heading(state: &mut RenderState, level: u8, output: &mut String) { state.heading_level = Some(level); if !output.is_empty() { output.push('\n'); diff --git a/rust/crates/tools/src/lib.rs b/rust/crates/tools/src/lib.rs index 8dcd33d..be11e6b 100644 --- a/rust/crates/tools/src/lib.rs +++ b/rust/crates/tools/src/lib.rs @@ -5,15 +5,15 @@ use std::time::{Duration, Instant}; use api::{ read_base_url, AnthropicClient, ContentBlockDelta, InputContentBlock, InputMessage, - MessageRequest, MessageResponse, OutputContentBlock, StreamEvent as ApiStreamEvent, ToolChoice, - ToolDefinition, ToolResultContentBlock, + MessageRequest, MessageResponse, OutputContentBlock, PromptCache, PromptCacheRecord, + StreamEvent as ApiStreamEvent, ToolChoice, ToolDefinition, ToolResultContentBlock, }; use reqwest::blocking::Client; use runtime::{ edit_file, execute_bash, glob_search, grep_search, load_system_prompt, read_file, write_file, ApiClient, ApiRequest, AssistantEvent, BashCommandInput, ContentBlock, ConversationMessage, ConversationRuntime, GrepSearchInput, MessageRole, PermissionMode, PermissionPolicy, - RuntimeError, Session, TokenUsage, ToolError, ToolExecutor, + PromptCacheEvent, RuntimeError, Session, TokenUsage, ToolError, ToolExecutor, }; use serde::{Deserialize, Serialize}; use serde_json::{json, Value}; @@ -1466,7 +1466,8 @@ fn build_agent_runtime( .clone() .unwrap_or_else(|| DEFAULT_AGENT_MODEL.to_string()); let allowed_tools = job.allowed_tools.clone(); - let api_client = AnthropicRuntimeClient::new(model, allowed_tools.clone())?; + let api_client = + AnthropicRuntimeClient::new(model, allowed_tools.clone(), job.manifest.agent_id.clone())?; let tool_executor = SubagentToolExecutor::new(allowed_tools); Ok(ConversationRuntime::new( Session::new(), @@ -1643,10 +1644,15 @@ struct AnthropicRuntimeClient { } impl AnthropicRuntimeClient { - fn new(model: String, allowed_tools: BTreeSet) -> Result { + fn new( + model: String, + allowed_tools: BTreeSet, + session_id: impl Into, + ) -> Result { let client = AnthropicClient::from_env() .map_err(|error| error.to_string())? - .with_base_url(read_base_url()); + .with_base_url(read_base_url()) + .with_prompt_cache(PromptCache::new(session_id)); Ok(Self { runtime: tokio::runtime::Runtime::new().map_err(|error| error.to_string())?, client, @@ -1657,6 +1663,7 @@ impl AnthropicRuntimeClient { } impl ApiClient for AnthropicRuntimeClient { + #[allow(clippy::too_many_lines)] fn stream(&mut self, request: ApiRequest) -> Result, RuntimeError> { let tools = tool_specs_for_allowed_tools(Some(&self.allowed_tools)) .into_iter() @@ -1726,8 +1733,8 @@ impl ApiClient for AnthropicRuntimeClient { events.push(AssistantEvent::Usage(TokenUsage { input_tokens: delta.usage.input_tokens, output_tokens: delta.usage.output_tokens, - cache_creation_input_tokens: 0, - cache_read_input_tokens: 0, + cache_creation_input_tokens: delta.usage.cache_creation_input_tokens, + cache_read_input_tokens: delta.usage.cache_read_input_tokens, })); } ApiStreamEvent::MessageStop(_) => { @@ -1737,6 +1744,8 @@ impl ApiClient for AnthropicRuntimeClient { } } + push_prompt_cache_record(&self.client, &mut events); + if !saw_stop && events.iter().any(|event| { matches!(event, AssistantEvent::TextDelta(text) if !text.is_empty()) @@ -1761,7 +1770,9 @@ impl ApiClient for AnthropicRuntimeClient { }) .await .map_err(|error| RuntimeError::new(error.to_string()))?; - Ok(response_to_events(response)) + let mut events = response_to_events(response); + push_prompt_cache_record(&self.client, &mut events); + Ok(events) }) } } @@ -1884,6 +1895,26 @@ fn response_to_events(response: MessageResponse) -> Vec { events } +fn push_prompt_cache_record(client: &AnthropicClient, events: &mut Vec) { + if let Some(event) = client + .take_last_prompt_cache_record() + .and_then(prompt_cache_record_to_runtime_event) + { + events.push(AssistantEvent::PromptCache(event)); + } +} + +fn prompt_cache_record_to_runtime_event(record: PromptCacheRecord) -> Option { + let cache_break = record.cache_break?; + Some(PromptCacheEvent { + unexpected: cache_break.unexpected, + reason: cache_break.reason, + previous_cache_read_input_tokens: cache_break.previous_cache_read_input_tokens, + current_cache_read_input_tokens: cache_break.current_cache_read_input_tokens, + token_drop: cache_break.token_drop, + }) +} + fn final_assistant_text(summary: &runtime::TurnSummary) -> String { summary .assistant_messages