Make Rust cost reporting aware of the active model

This replaces the single default pricing assumption with a small model-aware pricing table for Sonnet, Opus, and Haiku so CLI usage output better matches the selected model. Unknown models still fall back cleanly with explicit labeling.

The change keeps pricing lightweight and local while improving the usefulness of usage/cost reporting for resumed sessions and live turns.

Constraint: Keep pricing local and dependency-free

Constraint: Preserve graceful fallback behavior for unknown model IDs

Rejected: Add a remote pricing source now | unnecessary coupling and risk for this slice

Confidence: high

Scope-risk: narrow

Reversibility: clean

Directive: If pricing tables expand later, prefer explicit model-family matching and keep fallback labeling visible

Tested: cargo fmt; cargo clippy --all-targets --all-features -- -D warnings; cargo test -q

Not-tested: Validation against live provider billing exports
This commit is contained in:
Yeachan-Heo
2026-03-31 19:42:31 +00:00
parent add5513ac5
commit 6fe404329d
3 changed files with 117 additions and 12 deletions

View File

@@ -38,4 +38,6 @@ pub use prompt::{
SystemPromptBuilder, FRONTIER_MODEL_NAME, SYSTEM_PROMPT_DYNAMIC_BOUNDARY, SystemPromptBuilder, FRONTIER_MODEL_NAME, SYSTEM_PROMPT_DYNAMIC_BOUNDARY,
}; };
pub use session::{ContentBlock, ConversationMessage, MessageRole, Session, SessionError}; pub use session::{ContentBlock, ConversationMessage, MessageRole, Session, SessionError};
pub use usage::{format_usd, TokenUsage, UsageCostEstimate, UsageTracker}; pub use usage::{
format_usd, pricing_for_model, ModelPricing, TokenUsage, UsageCostEstimate, UsageTracker,
};

View File

@@ -5,6 +5,26 @@ const DEFAULT_OUTPUT_COST_PER_MILLION: f64 = 75.0;
const DEFAULT_CACHE_CREATION_COST_PER_MILLION: f64 = 18.75; const DEFAULT_CACHE_CREATION_COST_PER_MILLION: f64 = 18.75;
const DEFAULT_CACHE_READ_COST_PER_MILLION: f64 = 1.5; const DEFAULT_CACHE_READ_COST_PER_MILLION: f64 = 1.5;
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct ModelPricing {
pub input_cost_per_million: f64,
pub output_cost_per_million: f64,
pub cache_creation_cost_per_million: f64,
pub cache_read_cost_per_million: f64,
}
impl ModelPricing {
#[must_use]
pub const fn default_sonnet_tier() -> Self {
Self {
input_cost_per_million: DEFAULT_INPUT_COST_PER_MILLION,
output_cost_per_million: DEFAULT_OUTPUT_COST_PER_MILLION,
cache_creation_cost_per_million: DEFAULT_CACHE_CREATION_COST_PER_MILLION,
cache_read_cost_per_million: DEFAULT_CACHE_READ_COST_PER_MILLION,
}
}
}
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)] #[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct TokenUsage { pub struct TokenUsage {
pub input_tokens: u32, pub input_tokens: u32,
@@ -31,6 +51,31 @@ impl UsageCostEstimate {
} }
} }
#[must_use]
pub fn pricing_for_model(model: &str) -> Option<ModelPricing> {
let normalized = model.to_ascii_lowercase();
if normalized.contains("haiku") {
return Some(ModelPricing {
input_cost_per_million: 1.0,
output_cost_per_million: 5.0,
cache_creation_cost_per_million: 1.25,
cache_read_cost_per_million: 0.1,
});
}
if normalized.contains("opus") {
return Some(ModelPricing {
input_cost_per_million: 15.0,
output_cost_per_million: 75.0,
cache_creation_cost_per_million: 18.75,
cache_read_cost_per_million: 1.5,
});
}
if normalized.contains("sonnet") {
return Some(ModelPricing::default_sonnet_tier());
}
None
}
impl TokenUsage { impl TokenUsage {
#[must_use] #[must_use]
pub fn total_tokens(self) -> u32 { pub fn total_tokens(self) -> u32 {
@@ -42,32 +87,57 @@ impl TokenUsage {
#[must_use] #[must_use]
pub fn estimate_cost_usd(self) -> UsageCostEstimate { pub fn estimate_cost_usd(self) -> UsageCostEstimate {
self.estimate_cost_usd_with_pricing(ModelPricing::default_sonnet_tier())
}
#[must_use]
pub fn estimate_cost_usd_with_pricing(self, pricing: ModelPricing) -> UsageCostEstimate {
UsageCostEstimate { UsageCostEstimate {
input_cost_usd: cost_for_tokens(self.input_tokens, DEFAULT_INPUT_COST_PER_MILLION), input_cost_usd: cost_for_tokens(self.input_tokens, pricing.input_cost_per_million),
output_cost_usd: cost_for_tokens(self.output_tokens, DEFAULT_OUTPUT_COST_PER_MILLION), output_cost_usd: cost_for_tokens(self.output_tokens, pricing.output_cost_per_million),
cache_creation_cost_usd: cost_for_tokens( cache_creation_cost_usd: cost_for_tokens(
self.cache_creation_input_tokens, self.cache_creation_input_tokens,
DEFAULT_CACHE_CREATION_COST_PER_MILLION, pricing.cache_creation_cost_per_million,
), ),
cache_read_cost_usd: cost_for_tokens( cache_read_cost_usd: cost_for_tokens(
self.cache_read_input_tokens, self.cache_read_input_tokens,
DEFAULT_CACHE_READ_COST_PER_MILLION, pricing.cache_read_cost_per_million,
), ),
} }
} }
#[must_use] #[must_use]
pub fn summary_lines(self, label: &str) -> Vec<String> { pub fn summary_lines(self, label: &str) -> Vec<String> {
let cost = self.estimate_cost_usd(); self.summary_lines_for_model(label, None)
}
#[must_use]
pub fn summary_lines_for_model(self, label: &str, model: Option<&str>) -> Vec<String> {
let pricing = model.and_then(pricing_for_model);
let cost = pricing.map_or_else(
|| self.estimate_cost_usd(),
|pricing| self.estimate_cost_usd_with_pricing(pricing),
);
let model_suffix =
model.map_or_else(String::new, |model_name| format!(" model={model_name}"));
let pricing_suffix = if pricing.is_some() {
""
} else if model.is_some() {
" pricing=estimated-default"
} else {
""
};
vec![ vec![
format!( format!(
"{label}: total_tokens={} input={} output={} cache_write={} cache_read={} estimated_cost={}", "{label}: total_tokens={} input={} output={} cache_write={} cache_read={} estimated_cost={}{}{}",
self.total_tokens(), self.total_tokens(),
self.input_tokens, self.input_tokens,
self.output_tokens, self.output_tokens,
self.cache_creation_input_tokens, self.cache_creation_input_tokens,
self.cache_read_input_tokens, self.cache_read_input_tokens,
format_usd(cost.total_cost_usd()), format_usd(cost.total_cost_usd()),
model_suffix,
pricing_suffix,
), ),
format!( format!(
" cost breakdown: input={} output={} cache_write={} cache_read={}", " cost breakdown: input={} output={} cache_write={} cache_read={}",
@@ -140,7 +210,7 @@ impl UsageTracker {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::{format_usd, TokenUsage, UsageTracker}; use super::{format_usd, pricing_for_model, TokenUsage, UsageTracker};
use crate::session::{ContentBlock, ConversationMessage, MessageRole, Session}; use crate::session::{ContentBlock, ConversationMessage, MessageRole, Session};
#[test] #[test]
@@ -179,11 +249,41 @@ mod tests {
let cost = usage.estimate_cost_usd(); let cost = usage.estimate_cost_usd();
assert_eq!(format_usd(cost.input_cost_usd), "$15.0000"); assert_eq!(format_usd(cost.input_cost_usd), "$15.0000");
assert_eq!(format_usd(cost.output_cost_usd), "$37.5000"); assert_eq!(format_usd(cost.output_cost_usd), "$37.5000");
let lines = usage.summary_lines("usage"); let lines = usage.summary_lines_for_model("usage", Some("claude-sonnet-4-20250514"));
assert!(lines[0].contains("estimated_cost=$54.6750")); assert!(lines[0].contains("estimated_cost=$54.6750"));
assert!(lines[0].contains("model=claude-sonnet-4-20250514"));
assert!(lines[1].contains("cache_read=$0.3000")); assert!(lines[1].contains("cache_read=$0.3000"));
} }
#[test]
fn supports_model_specific_pricing() {
let usage = TokenUsage {
input_tokens: 1_000_000,
output_tokens: 500_000,
cache_creation_input_tokens: 0,
cache_read_input_tokens: 0,
};
let haiku = pricing_for_model("claude-haiku-4-5-20251001").expect("haiku pricing");
let opus = pricing_for_model("claude-opus-4-6").expect("opus pricing");
let haiku_cost = usage.estimate_cost_usd_with_pricing(haiku);
let opus_cost = usage.estimate_cost_usd_with_pricing(opus);
assert_eq!(format_usd(haiku_cost.total_cost_usd()), "$3.5000");
assert_eq!(format_usd(opus_cost.total_cost_usd()), "$52.5000");
}
#[test]
fn marks_unknown_model_pricing_as_fallback() {
let usage = TokenUsage {
input_tokens: 100,
output_tokens: 100,
cache_creation_input_tokens: 0,
cache_read_input_tokens: 0,
};
let lines = usage.summary_lines_for_model("usage", Some("custom-model"));
assert!(lines[0].contains("pricing=estimated-default"));
}
#[test] #[test]
fn reconstructs_usage_from_session_messages() { fn reconstructs_usage_from_session_messages() {
let session = Session { let session = Session {

View File

@@ -386,6 +386,9 @@ fn inspect_session(target: &str) {
println!("- size_bytes: {bytes}"); println!("- size_bytes: {bytes}");
println!("- messages: {}", session.messages.len()); println!("- messages: {}", session.messages.len());
println!("- total_tokens: {}", usage.total_tokens()); println!("- total_tokens: {}", usage.total_tokens());
for line in usage.summary_lines_for_model("- usage", None) {
println!("{line}");
}
println!("- preview: {}", session_preview(&session)); println!("- preview: {}", session_preview(&session));
if let Some(user_text) = latest_text_for_role(&session, MessageRole::User) { if let Some(user_text) = latest_text_for_role(&session, MessageRole::User) {
@@ -499,7 +502,7 @@ impl LiveCli {
self.runtime.usage().turns(), self.runtime.usage().turns(),
self.runtime.estimated_tokens() self.runtime.estimated_tokens()
); );
for line in usage.summary_lines("usage") { for line in usage.summary_lines_for_model("usage", Some(&self.model)) {
println!("{line}"); println!("{line}");
} }
} }
@@ -507,11 +510,11 @@ impl LiveCli {
fn print_turn_usage(&self, cumulative_usage: TokenUsage) { fn print_turn_usage(&self, cumulative_usage: TokenUsage) {
let latest = self.runtime.usage().current_turn_usage(); let latest = self.runtime.usage().current_turn_usage();
println!("\nTurn usage:"); println!("\nTurn usage:");
for line in latest.summary_lines(" latest") { for line in latest.summary_lines_for_model(" latest", Some(&self.model)) {
println!("{line}"); println!("{line}");
} }
println!("Cumulative usage:"); println!("Cumulative usage:");
for line in cumulative_usage.summary_lines(" total") { for line in cumulative_usage.summary_lines_for_model(" total", Some(&self.model)) {
println!("{line}"); println!("{line}");
} }
} }