diff --git a/rust/crates/runtime/src/lib.rs b/rust/crates/runtime/src/lib.rs index b4f52cb..573f858 100644 --- a/rust/crates/runtime/src/lib.rs +++ b/rust/crates/runtime/src/lib.rs @@ -38,4 +38,6 @@ pub use prompt::{ SystemPromptBuilder, FRONTIER_MODEL_NAME, SYSTEM_PROMPT_DYNAMIC_BOUNDARY, }; pub use session::{ContentBlock, ConversationMessage, MessageRole, Session, SessionError}; -pub use usage::{format_usd, TokenUsage, UsageCostEstimate, UsageTracker}; +pub use usage::{ + format_usd, pricing_for_model, ModelPricing, TokenUsage, UsageCostEstimate, UsageTracker, +}; diff --git a/rust/crates/runtime/src/usage.rs b/rust/crates/runtime/src/usage.rs index 08f2d9a..04e28df 100644 --- a/rust/crates/runtime/src/usage.rs +++ b/rust/crates/runtime/src/usage.rs @@ -5,6 +5,26 @@ const DEFAULT_OUTPUT_COST_PER_MILLION: f64 = 75.0; const DEFAULT_CACHE_CREATION_COST_PER_MILLION: f64 = 18.75; const DEFAULT_CACHE_READ_COST_PER_MILLION: f64 = 1.5; +#[derive(Debug, Clone, Copy, PartialEq)] +pub struct ModelPricing { + pub input_cost_per_million: f64, + pub output_cost_per_million: f64, + pub cache_creation_cost_per_million: f64, + pub cache_read_cost_per_million: f64, +} + +impl ModelPricing { + #[must_use] + pub const fn default_sonnet_tier() -> Self { + Self { + input_cost_per_million: DEFAULT_INPUT_COST_PER_MILLION, + output_cost_per_million: DEFAULT_OUTPUT_COST_PER_MILLION, + cache_creation_cost_per_million: DEFAULT_CACHE_CREATION_COST_PER_MILLION, + cache_read_cost_per_million: DEFAULT_CACHE_READ_COST_PER_MILLION, + } + } +} + #[derive(Debug, Clone, Copy, Default, PartialEq, Eq)] pub struct TokenUsage { pub input_tokens: u32, @@ -31,6 +51,31 @@ impl UsageCostEstimate { } } +#[must_use] +pub fn pricing_for_model(model: &str) -> Option { + let normalized = model.to_ascii_lowercase(); + if normalized.contains("haiku") { + return Some(ModelPricing { + input_cost_per_million: 1.0, + output_cost_per_million: 5.0, + cache_creation_cost_per_million: 1.25, + cache_read_cost_per_million: 0.1, + }); + } + if normalized.contains("opus") { + return Some(ModelPricing { + input_cost_per_million: 15.0, + output_cost_per_million: 75.0, + cache_creation_cost_per_million: 18.75, + cache_read_cost_per_million: 1.5, + }); + } + if normalized.contains("sonnet") { + return Some(ModelPricing::default_sonnet_tier()); + } + None +} + impl TokenUsage { #[must_use] pub fn total_tokens(self) -> u32 { @@ -42,32 +87,57 @@ impl TokenUsage { #[must_use] pub fn estimate_cost_usd(self) -> UsageCostEstimate { + self.estimate_cost_usd_with_pricing(ModelPricing::default_sonnet_tier()) + } + + #[must_use] + pub fn estimate_cost_usd_with_pricing(self, pricing: ModelPricing) -> UsageCostEstimate { UsageCostEstimate { - input_cost_usd: cost_for_tokens(self.input_tokens, DEFAULT_INPUT_COST_PER_MILLION), - output_cost_usd: cost_for_tokens(self.output_tokens, DEFAULT_OUTPUT_COST_PER_MILLION), + input_cost_usd: cost_for_tokens(self.input_tokens, pricing.input_cost_per_million), + output_cost_usd: cost_for_tokens(self.output_tokens, pricing.output_cost_per_million), cache_creation_cost_usd: cost_for_tokens( self.cache_creation_input_tokens, - DEFAULT_CACHE_CREATION_COST_PER_MILLION, + pricing.cache_creation_cost_per_million, ), cache_read_cost_usd: cost_for_tokens( self.cache_read_input_tokens, - DEFAULT_CACHE_READ_COST_PER_MILLION, + pricing.cache_read_cost_per_million, ), } } #[must_use] pub fn summary_lines(self, label: &str) -> Vec { - let cost = self.estimate_cost_usd(); + self.summary_lines_for_model(label, None) + } + + #[must_use] + pub fn summary_lines_for_model(self, label: &str, model: Option<&str>) -> Vec { + let pricing = model.and_then(pricing_for_model); + let cost = pricing.map_or_else( + || self.estimate_cost_usd(), + |pricing| self.estimate_cost_usd_with_pricing(pricing), + ); + let model_suffix = + model.map_or_else(String::new, |model_name| format!(" model={model_name}")); + let pricing_suffix = if pricing.is_some() { + "" + } else if model.is_some() { + " pricing=estimated-default" + } else { + "" + }; vec![ format!( - "{label}: total_tokens={} input={} output={} cache_write={} cache_read={} estimated_cost={}", + "{label}: total_tokens={} input={} output={} cache_write={} cache_read={} estimated_cost={}{}{}", self.total_tokens(), self.input_tokens, self.output_tokens, self.cache_creation_input_tokens, self.cache_read_input_tokens, format_usd(cost.total_cost_usd()), + model_suffix, + pricing_suffix, ), format!( " cost breakdown: input={} output={} cache_write={} cache_read={}", @@ -140,7 +210,7 @@ impl UsageTracker { #[cfg(test)] mod tests { - use super::{format_usd, TokenUsage, UsageTracker}; + use super::{format_usd, pricing_for_model, TokenUsage, UsageTracker}; use crate::session::{ContentBlock, ConversationMessage, MessageRole, Session}; #[test] @@ -179,11 +249,41 @@ mod tests { let cost = usage.estimate_cost_usd(); assert_eq!(format_usd(cost.input_cost_usd), "$15.0000"); assert_eq!(format_usd(cost.output_cost_usd), "$37.5000"); - let lines = usage.summary_lines("usage"); + let lines = usage.summary_lines_for_model("usage", Some("claude-sonnet-4-20250514")); assert!(lines[0].contains("estimated_cost=$54.6750")); + assert!(lines[0].contains("model=claude-sonnet-4-20250514")); assert!(lines[1].contains("cache_read=$0.3000")); } + #[test] + fn supports_model_specific_pricing() { + let usage = TokenUsage { + input_tokens: 1_000_000, + output_tokens: 500_000, + cache_creation_input_tokens: 0, + cache_read_input_tokens: 0, + }; + + let haiku = pricing_for_model("claude-haiku-4-5-20251001").expect("haiku pricing"); + let opus = pricing_for_model("claude-opus-4-6").expect("opus pricing"); + let haiku_cost = usage.estimate_cost_usd_with_pricing(haiku); + let opus_cost = usage.estimate_cost_usd_with_pricing(opus); + assert_eq!(format_usd(haiku_cost.total_cost_usd()), "$3.5000"); + assert_eq!(format_usd(opus_cost.total_cost_usd()), "$52.5000"); + } + + #[test] + fn marks_unknown_model_pricing_as_fallback() { + let usage = TokenUsage { + input_tokens: 100, + output_tokens: 100, + cache_creation_input_tokens: 0, + cache_read_input_tokens: 0, + }; + let lines = usage.summary_lines_for_model("usage", Some("custom-model")); + assert!(lines[0].contains("pricing=estimated-default")); + } + #[test] fn reconstructs_usage_from_session_messages() { let session = Session { diff --git a/rust/crates/rusty-claude-cli/src/main.rs b/rust/crates/rusty-claude-cli/src/main.rs index 0816ec3..9db600f 100644 --- a/rust/crates/rusty-claude-cli/src/main.rs +++ b/rust/crates/rusty-claude-cli/src/main.rs @@ -386,6 +386,9 @@ fn inspect_session(target: &str) { println!("- size_bytes: {bytes}"); println!("- messages: {}", session.messages.len()); println!("- total_tokens: {}", usage.total_tokens()); + for line in usage.summary_lines_for_model("- usage", None) { + println!("{line}"); + } println!("- preview: {}", session_preview(&session)); if let Some(user_text) = latest_text_for_role(&session, MessageRole::User) { @@ -499,7 +502,7 @@ impl LiveCli { self.runtime.usage().turns(), self.runtime.estimated_tokens() ); - for line in usage.summary_lines("usage") { + for line in usage.summary_lines_for_model("usage", Some(&self.model)) { println!("{line}"); } } @@ -507,11 +510,11 @@ impl LiveCli { fn print_turn_usage(&self, cumulative_usage: TokenUsage) { let latest = self.runtime.usage().current_turn_usage(); println!("\nTurn usage:"); - for line in latest.summary_lines(" latest") { + for line in latest.summary_lines_for_model(" latest", Some(&self.model)) { println!("{line}"); } println!("Cumulative usage:"); - for line in cumulative_usage.summary_lines(" total") { + for line in cumulative_usage.summary_lines_for_model(" total", Some(&self.model)) { println!("{line}"); } }