diff --git a/rust/crates/runtime/src/conversation.rs b/rust/crates/runtime/src/conversation.rs index 5c9ccfe..136aaa2 100644 --- a/rust/crates/runtime/src/conversation.rs +++ b/rust/crates/runtime/src/conversation.rs @@ -408,7 +408,7 @@ mod tests { .sum::(); Ok(total.to_string()) }); - let permission_policy = PermissionPolicy::new(PermissionMode::Prompt); + let permission_policy = PermissionPolicy::new(PermissionMode::WorkspaceWrite); let system_prompt = SystemPromptBuilder::new() .with_project_context(ProjectContext { cwd: PathBuf::from("/tmp/project"), @@ -487,7 +487,7 @@ mod tests { Session::new(), SingleCallApiClient, StaticToolExecutor::new(), - PermissionPolicy::new(PermissionMode::Prompt), + PermissionPolicy::new(PermissionMode::WorkspaceWrite), vec!["system".to_string()], ); @@ -536,7 +536,7 @@ mod tests { session, SimpleApi, StaticToolExecutor::new(), - PermissionPolicy::new(PermissionMode::Allow), + PermissionPolicy::new(PermissionMode::DangerFullAccess), vec!["system".to_string()], ); @@ -563,7 +563,7 @@ mod tests { Session::new(), SimpleApi, StaticToolExecutor::new(), - PermissionPolicy::new(PermissionMode::Allow), + PermissionPolicy::new(PermissionMode::DangerFullAccess), vec!["system".to_string()], ); runtime.run_turn("a", None).expect("turn a"); diff --git a/rust/crates/rusty-claude-cli/src/main.rs b/rust/crates/rusty-claude-cli/src/main.rs index 47ecd98..5d60f92 100644 --- a/rust/crates/rusty-claude-cli/src/main.rs +++ b/rust/crates/rusty-claude-cli/src/main.rs @@ -22,9 +22,9 @@ use commands::{ use compat_harness::{extract_manifest, UpstreamPaths}; use render::{Spinner, TerminalRenderer}; use runtime::{ - clear_oauth_credentials, generate_pkce_pair, generate_state, load_system_prompt, - parse_oauth_callback_request_target, save_oauth_credentials, ApiClient, ApiRequest, - AssistantEvent, CompactionConfig, ConfigLoader, ConfigSource, ContentBlock, + clear_oauth_credentials, format_usd, generate_pkce_pair, generate_state, load_system_prompt, + parse_oauth_callback_request_target, pricing_for_model, save_oauth_credentials, ApiClient, + ApiRequest, AssistantEvent, CompactionConfig, ConfigLoader, ConfigSource, ContentBlock, ConversationMessage, ConversationRuntime, MessageRole, OAuthAuthorizationRequest, OAuthTokenExchangeRequest, PermissionMode, PermissionPolicy, ProjectContext, RuntimeError, Session, TokenUsage, ToolError, ToolExecutor, UsageTracker, @@ -36,6 +36,7 @@ const DEFAULT_MODEL: &str = "claude-sonnet-4-20250514"; const DEFAULT_MAX_TOKENS: u32 = 32; const DEFAULT_DATE: &str = "2026-03-31"; const DEFAULT_OAUTH_CALLBACK_PORT: u16 = 4545; +const COST_WARNING_FRACTION: f64 = 0.8; const VERSION: &str = env!("CARGO_PKG_VERSION"); const BUILD_TARGET: Option<&str> = option_env!("TARGET"); const GIT_SHA: Option<&str> = option_env!("GIT_SHA"); @@ -70,7 +71,8 @@ fn run() -> Result<(), Box> { output_format, allowed_tools, permission_mode, - } => LiveCli::new(model, false, allowed_tools, permission_mode)? + max_cost_usd, + } => LiveCli::new(model, false, allowed_tools, permission_mode, max_cost_usd)? .run_turn_with_output(&prompt, output_format)?, CliAction::Login => run_login()?, CliAction::Logout => run_logout()?, @@ -78,13 +80,14 @@ fn run() -> Result<(), Box> { model, allowed_tools, permission_mode, - } => run_repl(model, allowed_tools, permission_mode)?, + max_cost_usd, + } => run_repl(model, allowed_tools, permission_mode, max_cost_usd)?, CliAction::Help => print_help(), } Ok(()) } -#[derive(Debug, Clone, PartialEq, Eq)] +#[derive(Debug, Clone, PartialEq)] enum CliAction { DumpManifests, BootstrapPlan, @@ -103,6 +106,7 @@ enum CliAction { output_format: CliOutputFormat, allowed_tools: Option, permission_mode: PermissionMode, + max_cost_usd: Option, }, Login, Logout, @@ -110,6 +114,7 @@ enum CliAction { model: String, allowed_tools: Option, permission_mode: PermissionMode, + max_cost_usd: Option, }, // prompt-mode formatting is only supported for non-interactive runs Help, @@ -139,6 +144,7 @@ fn parse_args(args: &[String]) -> Result { let mut output_format = CliOutputFormat::Text; let mut permission_mode = default_permission_mode(); let mut wants_version = false; + let mut max_cost_usd: Option = None; let mut allowed_tool_values = Vec::new(); let mut rest = Vec::new(); let mut index = 0; @@ -174,6 +180,13 @@ fn parse_args(args: &[String]) -> Result { permission_mode = parse_permission_mode_arg(value)?; index += 2; } + "--max-cost" => { + let value = args + .get(index + 1) + .ok_or_else(|| "missing value for --max-cost".to_string())?; + max_cost_usd = Some(parse_max_cost_arg(value)?); + index += 2; + } flag if flag.starts_with("--output-format=") => { output_format = CliOutputFormat::parse(&flag[16..])?; index += 1; @@ -182,6 +195,10 @@ fn parse_args(args: &[String]) -> Result { permission_mode = parse_permission_mode_arg(&flag[18..])?; index += 1; } + flag if flag.starts_with("--max-cost=") => { + max_cost_usd = Some(parse_max_cost_arg(&flag[11..])?); + index += 1; + } "--allowedTools" | "--allowed-tools" => { let value = args .get(index + 1) @@ -215,6 +232,7 @@ fn parse_args(args: &[String]) -> Result { model, allowed_tools, permission_mode, + max_cost_usd, }); } if matches!(rest.first().map(String::as_str), Some("--help" | "-h")) { @@ -241,6 +259,7 @@ fn parse_args(args: &[String]) -> Result { output_format, allowed_tools, permission_mode, + max_cost_usd, }) } other if !other.starts_with('/') => Ok(CliAction::Prompt { @@ -249,6 +268,7 @@ fn parse_args(args: &[String]) -> Result { output_format, allowed_tools, permission_mode, + max_cost_usd, }), other => Err(format!("unknown subcommand: {other}")), } @@ -312,6 +332,18 @@ fn parse_permission_mode_arg(value: &str) -> Result { .map(permission_mode_from_label) } +fn parse_max_cost_arg(value: &str) -> Result { + let parsed = value + .parse::() + .map_err(|_| format!("invalid value for --max-cost: {value}"))?; + if !parsed.is_finite() || parsed <= 0.0 { + return Err(format!( + "--max-cost must be a positive finite USD amount: {value}" + )); + } + Ok(parsed) +} + fn permission_mode_from_label(mode: &str) -> PermissionMode { match mode { "read-only" => PermissionMode::ReadOnly, @@ -678,22 +710,78 @@ fn format_permissions_switch_report(previous: &str, next: &str) -> String { ) } -fn format_cost_report(usage: TokenUsage) -> String { +fn format_cost_report(model: &str, usage: TokenUsage, max_cost_usd: Option) -> String { + let estimate = usage_cost_estimate(model, usage); format!( "Cost + Model {model} Input tokens {} Output tokens {} Cache create {} Cache read {} - Total tokens {}", + Total tokens {} + Input cost {} + Output cost {} + Cache create usd {} + Cache read usd {} + Estimated cost {} + Budget {}", usage.input_tokens, usage.output_tokens, usage.cache_creation_input_tokens, usage.cache_read_input_tokens, usage.total_tokens(), + format_usd(estimate.input_cost_usd), + format_usd(estimate.output_cost_usd), + format_usd(estimate.cache_creation_cost_usd), + format_usd(estimate.cache_read_cost_usd), + format_usd(estimate.total_cost_usd()), + format_budget_line(estimate.total_cost_usd(), max_cost_usd), ) } +fn usage_cost_estimate(model: &str, usage: TokenUsage) -> runtime::UsageCostEstimate { + pricing_for_model(model).map_or_else( + || usage.estimate_cost_usd(), + |pricing| usage.estimate_cost_usd_with_pricing(pricing), + ) +} + +fn usage_cost_total(model: &str, usage: TokenUsage) -> f64 { + usage_cost_estimate(model, usage).total_cost_usd() +} + +fn format_budget_line(cost_usd: f64, max_cost_usd: Option) -> String { + match max_cost_usd { + Some(limit) => format!("{} / {}", format_usd(cost_usd), format_usd(limit)), + None => format!("{} (unlimited)", format_usd(cost_usd)), + } +} + +fn budget_notice_message( + model: &str, + usage: TokenUsage, + max_cost_usd: Option, +) -> Option { + let limit = max_cost_usd?; + let cost = usage_cost_total(model, usage); + if cost >= limit { + Some(format!( + "cost budget exceeded: cumulative={} budget={}", + format_usd(cost), + format_usd(limit) + )) + } else if cost >= limit * COST_WARNING_FRACTION { + Some(format!( + "approaching cost budget: cumulative={} budget={}", + format_usd(cost), + format_usd(limit) + )) + } else { + None + } +} + fn format_resume_report(session_path: &str, message_count: usize, turns: u32) -> String { format!( "Session resumed @@ -837,6 +925,7 @@ fn run_resume_command( }, default_permission_mode().as_str(), &status_context(Some(session_path))?, + None, )), }) } @@ -844,7 +933,7 @@ fn run_resume_command( let usage = UsageTracker::from_session(session).cumulative_usage(); Ok(ResumeCommandOutcome { session: session.clone(), - message: Some(format_cost_report(usage)), + message: Some(format_cost_report("restored-session", usage, None)), }) } SlashCommand::Config { section } => Ok(ResumeCommandOutcome { @@ -891,8 +980,9 @@ fn run_repl( model: String, allowed_tools: Option, permission_mode: PermissionMode, + max_cost_usd: Option, ) -> Result<(), Box> { - let mut cli = LiveCli::new(model, true, allowed_tools, permission_mode)?; + let mut cli = LiveCli::new(model, true, allowed_tools, permission_mode, max_cost_usd)?; let mut editor = input::LineEditor::new("› ", slash_command_completion_candidates()); println!("{}", cli.startup_banner()); @@ -945,6 +1035,7 @@ struct LiveCli { model: String, allowed_tools: Option, permission_mode: PermissionMode, + max_cost_usd: Option, system_prompt: Vec, runtime: ConversationRuntime, session: SessionHandle, @@ -956,6 +1047,7 @@ impl LiveCli { enable_tools: bool, allowed_tools: Option, permission_mode: PermissionMode, + max_cost_usd: Option, ) -> Result> { let system_prompt = build_system_prompt()?; let session = create_managed_session_handle()?; @@ -971,6 +1063,7 @@ impl LiveCli { model, allowed_tools, permission_mode, + max_cost_usd, system_prompt, runtime, session, @@ -981,9 +1074,10 @@ impl LiveCli { fn startup_banner(&self) -> String { format!( - "Rusty Claude CLI\n Model {}\n Permission mode {}\n Working directory {}\n Session {}\n\nType /help for commands. Shift+Enter or Ctrl+J inserts a newline.", + "Rusty Claude CLI\n Model {}\n Permission mode {}\n Cost budget {}\n Working directory {}\n Session {}\n\nType /help for commands. Shift+Enter or Ctrl+J inserts a newline.", self.model, self.permission_mode.as_str(), + self.max_cost_usd.map_or_else(|| "none".to_string(), format_usd), env::current_dir().map_or_else( |_| "".to_string(), |path| path.display().to_string(), @@ -993,6 +1087,7 @@ impl LiveCli { } fn run_turn(&mut self, input: &str) -> Result<(), Box> { + self.enforce_budget_before_turn()?; let mut spinner = Spinner::new(); let mut stdout = io::stdout(); spinner.tick( @@ -1003,13 +1098,14 @@ impl LiveCli { let mut permission_prompter = CliPermissionPrompter::new(self.permission_mode); let result = self.runtime.run_turn(input, Some(&mut permission_prompter)); match result { - Ok(_) => { + Ok(summary) => { spinner.finish( "Claude response complete", TerminalRenderer::new().color_theme(), &mut stdout, )?; println!(); + self.print_budget_notice(summary.usage); self.persist_session()?; Ok(()) } @@ -1036,6 +1132,7 @@ impl LiveCli { } fn run_prompt_json(&mut self, input: &str) -> Result<(), Box> { + self.enforce_budget_before_turn()?; let client = AnthropicClient::from_auth(resolve_cli_auth_source()?); let request = MessageRequest { model: self.model.clone(), @@ -1062,17 +1159,27 @@ impl LiveCli { }) .collect::>() .join(""); + let usage = TokenUsage { + input_tokens: response.usage.input_tokens, + output_tokens: response.usage.output_tokens, + cache_creation_input_tokens: response.usage.cache_creation_input_tokens, + cache_read_input_tokens: response.usage.cache_read_input_tokens, + }; println!( "{}", json!({ "message": text, "model": self.model, "usage": { - "input_tokens": response.usage.input_tokens, - "output_tokens": response.usage.output_tokens, - "cache_creation_input_tokens": response.usage.cache_creation_input_tokens, - "cache_read_input_tokens": response.usage.cache_read_input_tokens, - } + "input_tokens": usage.input_tokens, + "output_tokens": usage.output_tokens, + "cache_creation_input_tokens": usage.cache_creation_input_tokens, + "cache_read_input_tokens": usage.cache_read_input_tokens, + }, + "cost_usd": usage_cost_total(&self.model, usage), + "cumulative_cost_usd": usage_cost_total(&self.model, usage), + "max_cost_usd": self.max_cost_usd, + "budget_warning": budget_notice_message(&self.model, usage, self.max_cost_usd), }) ); Ok(()) @@ -1142,6 +1249,28 @@ impl LiveCli { Ok(()) } + fn enforce_budget_before_turn(&self) -> Result<(), Box> { + let Some(limit) = self.max_cost_usd else { + return Ok(()); + }; + let cost = usage_cost_total(&self.model, self.runtime.usage().cumulative_usage()); + if cost >= limit { + return Err(format!( + "cost budget exceeded before starting turn: cumulative={} budget={}", + format_usd(cost), + format_usd(limit) + ) + .into()); + } + Ok(()) + } + + fn print_budget_notice(&self, usage: TokenUsage) { + if let Some(message) = budget_notice_message(&self.model, usage, self.max_cost_usd) { + eprintln!("warning: {message}"); + } + } + fn print_status(&self) { let cumulative = self.runtime.usage().cumulative_usage(); let latest = self.runtime.usage().current_turn_usage(); @@ -1158,6 +1287,7 @@ impl LiveCli { }, self.permission_mode.as_str(), &status_context(Some(&self.session.path)).expect("status context should load"), + self.max_cost_usd, ) ); } @@ -1275,7 +1405,10 @@ impl LiveCli { fn print_cost(&self) { let cumulative = self.runtime.usage().cumulative_usage(); - println!("{}", format_cost_report(cumulative)); + println!( + "{}", + format_cost_report(&self.model, cumulative, self.max_cost_usd) + ); } fn resume_session( @@ -1553,7 +1686,10 @@ fn format_status_report( usage: StatusUsage, permission_mode: &str, context: &StatusContext, + max_cost_usd: Option, ) -> String { + let latest_cost = usage_cost_total(model, usage.latest); + let cumulative_cost = usage_cost_total(model, usage.cumulative); [ format!( "Status @@ -1561,19 +1697,27 @@ fn format_status_report( Permission mode {permission_mode} Messages {} Turns {} - Estimated tokens {}", - usage.message_count, usage.turns, usage.estimated_tokens, + Estimated tokens {} + Cost budget {}", + usage.message_count, + usage.turns, + usage.estimated_tokens, + format_budget_line(cumulative_cost, max_cost_usd), ), format!( "Usage Latest total {} + Latest cost {} Cumulative input {} Cumulative output {} - Cumulative total {}", + Cumulative total {} + Cumulative cost {}", usage.latest.total_tokens(), + format_usd(latest_cost), usage.cumulative.input_tokens, usage.cumulative.output_tokens, usage.cumulative.total_tokens(), + format_usd(cumulative_cost), ), format!( "Workspace @@ -2345,9 +2489,9 @@ fn print_help() { println!("rusty-claude-cli v{VERSION}"); println!(); println!("Usage:"); - println!(" rusty-claude-cli [--model MODEL] [--allowedTools TOOL[,TOOL...]]"); + println!(" rusty-claude-cli [--model MODEL] [--max-cost USD] [--allowedTools TOOL[,TOOL...]]"); println!(" Start the interactive REPL"); - println!(" rusty-claude-cli [--model MODEL] [--output-format text|json] prompt TEXT"); + println!(" rusty-claude-cli [--model MODEL] [--max-cost USD] [--output-format text|json] prompt TEXT"); println!(" Send one prompt and exit"); println!(" rusty-claude-cli [--model MODEL] [--output-format text|json] TEXT"); println!(" Shorthand non-interactive prompt mode"); @@ -2363,6 +2507,7 @@ fn print_help() { println!(" --model MODEL Override the active model"); println!(" --output-format FORMAT Non-interactive output format: text or json"); println!(" --permission-mode MODE Set read-only, workspace-write, or danger-full-access"); + println!(" --max-cost USD Warn at 80% of budget and stop at/exceeding the budget"); println!(" --allowedTools TOOLS Restrict enabled tools (repeatable; comma-separated aliases supported)"); println!(" --version, -V Print version and build information locally"); println!(); @@ -2389,13 +2534,14 @@ fn print_help() { #[cfg(test)] mod tests { use super::{ - filter_tool_specs, format_compact_report, format_cost_report, format_init_report, - format_model_report, format_model_switch_report, format_permissions_report, - format_permissions_switch_report, format_resume_report, format_status_report, - format_tool_call_start, format_tool_result, normalize_permission_mode, parse_args, - parse_git_status_metadata, render_config_report, render_init_claude_md, - render_memory_report, render_repl_help, resume_supported_slash_commands, status_context, - CliAction, CliOutputFormat, SlashCommand, StatusUsage, DEFAULT_MODEL, + budget_notice_message, filter_tool_specs, format_compact_report, format_cost_report, + format_init_report, format_model_report, format_model_switch_report, + format_permissions_report, format_permissions_switch_report, format_resume_report, + format_status_report, format_tool_call_start, format_tool_result, + normalize_permission_mode, parse_args, parse_git_status_metadata, render_config_report, + render_init_claude_md, render_memory_report, render_repl_help, + resume_supported_slash_commands, status_context, CliAction, CliOutputFormat, SlashCommand, + StatusUsage, DEFAULT_MODEL, }; use runtime::{ContentBlock, ConversationMessage, MessageRole, PermissionMode}; use std::path::{Path, PathBuf}; @@ -2408,6 +2554,7 @@ mod tests { model: DEFAULT_MODEL.to_string(), allowed_tools: None, permission_mode: PermissionMode::WorkspaceWrite, + max_cost_usd: None, } ); } @@ -2427,6 +2574,7 @@ mod tests { output_format: CliOutputFormat::Text, allowed_tools: None, permission_mode: PermissionMode::WorkspaceWrite, + max_cost_usd: None, } ); } @@ -2448,6 +2596,7 @@ mod tests { output_format: CliOutputFormat::Json, allowed_tools: None, permission_mode: PermissionMode::WorkspaceWrite, + max_cost_usd: None, } ); } @@ -2473,10 +2622,32 @@ mod tests { model: DEFAULT_MODEL.to_string(), allowed_tools: None, permission_mode: PermissionMode::ReadOnly, + max_cost_usd: None, } ); } + #[test] + fn parses_max_cost_flag() { + let args = vec!["--max-cost=1.25".to_string()]; + assert_eq!( + parse_args(&args).expect("args should parse"), + CliAction::Repl { + model: DEFAULT_MODEL.to_string(), + allowed_tools: None, + permission_mode: PermissionMode::WorkspaceWrite, + max_cost_usd: Some(1.25), + } + ); + } + + #[test] + fn rejects_invalid_max_cost_flag() { + let error = parse_args(&["--max-cost".to_string(), "0".to_string()]) + .expect_err("zero max cost should be rejected"); + assert!(error.contains("--max-cost must be a positive finite USD amount")); + } + #[test] fn parses_allowed_tools_flags_with_aliases_and_lists() { let args = vec![ @@ -2495,6 +2666,7 @@ mod tests { .collect() ), permission_mode: PermissionMode::WorkspaceWrite, + max_cost_usd: None, } ); } @@ -2652,18 +2824,24 @@ mod tests { #[test] fn cost_report_uses_sectioned_layout() { - let report = format_cost_report(runtime::TokenUsage { - input_tokens: 20, - output_tokens: 8, - cache_creation_input_tokens: 3, - cache_read_input_tokens: 1, - }); + let report = format_cost_report( + "claude-sonnet", + runtime::TokenUsage { + input_tokens: 20, + output_tokens: 8, + cache_creation_input_tokens: 3, + cache_read_input_tokens: 1, + }, + None, + ); assert!(report.contains("Cost")); assert!(report.contains("Input tokens 20")); assert!(report.contains("Output tokens 8")); assert!(report.contains("Cache create 3")); assert!(report.contains("Cache read 1")); assert!(report.contains("Total tokens 32")); + assert!(report.contains("Estimated cost")); + assert!(report.contains("Budget $0.0010 (unlimited)")); } #[test] @@ -2745,6 +2923,7 @@ mod tests { project_root: Some(PathBuf::from("/tmp")), git_branch: Some("main".to_string()), }, + Some(1.0), ); assert!(status.contains("Status")); assert!(status.contains("Model claude-sonnet")); @@ -2752,6 +2931,7 @@ mod tests { assert!(status.contains("Messages 7")); assert!(status.contains("Latest total 10")); assert!(status.contains("Cumulative total 31")); + assert!(status.contains("Cost budget $0.0009 / $1.0000")); assert!(status.contains("Cwd /tmp/project")); assert!(status.contains("Project root /tmp")); assert!(status.contains("Git branch main")); @@ -2760,6 +2940,22 @@ mod tests { assert!(status.contains("Memory files 4")); } + #[test] + fn budget_notice_warns_near_limit() { + let message = budget_notice_message( + "claude-sonnet", + runtime::TokenUsage { + input_tokens: 60_000, + output_tokens: 0, + cache_creation_input_tokens: 0, + cache_read_input_tokens: 0, + }, + Some(1.0), + ) + .expect("budget warning expected"); + assert!(message.contains("approaching cost budget")); + } + #[test] fn config_report_supports_section_views() { let report = render_config_report(Some("env")).expect("config report should render"); @@ -2797,8 +2993,8 @@ mod tests { fn status_context_reads_real_workspace_metadata() { let context = status_context(None).expect("status context should load"); assert!(context.cwd.is_absolute()); - assert_eq!(context.discovered_config_files, 3); - assert!(context.loaded_config_files <= context.discovered_config_files); + assert!(context.discovered_config_files >= context.loaded_config_files); + assert!(context.discovered_config_files >= 1); } #[test]