From 9efd029e26f1764c134fdf0b380f6cb6b34c3e94 Mon Sep 17 00:00:00 2001 From: Yeachan-Heo Date: Wed, 1 Apr 2026 04:40:18 +0000 Subject: [PATCH] wip: hook-pipeline progress --- rust/crates/runtime/src/conversation.rs | 105 ++++++++++++++++----- rust/crates/runtime/src/hooks.rs | 54 ++++------- rust/crates/rusty-claude-cli/src/main.rs | 33 ++++--- rust/crates/rusty-claude-cli/src/render.rs | 4 +- 4 files changed, 124 insertions(+), 72 deletions(-) diff --git a/rust/crates/runtime/src/conversation.rs b/rust/crates/runtime/src/conversation.rs index 2c5f6ea..358e1cc 100644 --- a/rust/crates/runtime/src/conversation.rs +++ b/rust/crates/runtime/src/conversation.rs @@ -5,7 +5,7 @@ use crate::compact::{ compact_session, estimate_session_tokens, CompactionConfig, CompactionResult, }; use crate::config::RuntimeFeatureConfig; -use crate::hooks::{HookAbortSignal, HookProgressReporter, HookRunResult, HookRunner}; +use crate::hooks::{HookAbortSignal, HookRunResult, HookRunner}; use crate::permissions::{ PermissionContext, PermissionOutcome, PermissionPolicy, PermissionPrompter, }; @@ -100,7 +100,6 @@ pub struct ConversationRuntime { usage_tracker: UsageTracker, hook_runner: HookRunner, hook_abort_signal: HookAbortSignal, - hook_progress_reporter: Option>, } impl ConversationRuntime @@ -122,18 +121,19 @@ where tool_executor, permission_policy, system_prompt, - &RuntimeFeatureConfig::default(), + RuntimeFeatureConfig::default(), ) } #[must_use] + #[allow(clippy::needless_pass_by_value)] pub fn new_with_features( session: Session, api_client: C, tool_executor: T, permission_policy: PermissionPolicy, system_prompt: Vec, - feature_config: &RuntimeFeatureConfig, + feature_config: RuntimeFeatureConfig, ) -> Self { let usage_tracker = UsageTracker::from_session(&session); Self { @@ -144,8 +144,9 @@ where system_prompt, max_iterations: usize::MAX, usage_tracker, - hook_runner: HookRunner::from_feature_config(feature_config), + hook_runner: HookRunner::from_feature_config(&feature_config), hook_abort_signal: HookAbortSignal::default(), + hook_progress_reporter: None, } } @@ -220,17 +221,12 @@ where } for (tool_use_id, tool_name, input) in pending_tool_uses { - let pre_hook_result = self.hook_runner.run_pre_tool_use_with_context( - &tool_name, - &input, - Some(&self.hook_abort_signal), - self.hook_progress_reporter.as_deref_mut(), - ); + let pre_hook_result = self.run_pre_tool_use_hook(&tool_name, &input); let effective_input = pre_hook_result - .updated_input_json() + .updated_input() .map_or_else(|| input.clone(), ToOwned::to_owned); let permission_context = PermissionContext::new( - pre_hook_result.permission_decision(), + pre_hook_result.permission_override(), pre_hook_result.permission_reason().map(ToOwned::to_owned), ); @@ -274,21 +270,17 @@ where output = merge_hook_feedback(pre_hook_result.messages(), output, false); let post_hook_result = if is_error { - self.hook_runner.run_post_tool_use_failure_with_context( + self.run_post_tool_use_failure_hook( &tool_name, &effective_input, &output, - Some(&self.hook_abort_signal), - self.hook_progress_reporter.as_deref_mut(), ) } else { - self.hook_runner.run_post_tool_use_with_context( + self.run_post_tool_use_hook( &tool_name, &effective_input, &output, false, - Some(&self.hook_abort_signal), - self.hook_progress_reporter.as_deref_mut(), ) }; if post_hook_result.is_denied() || post_hook_result.is_cancelled() { @@ -322,6 +314,77 @@ where }) } + fn run_pre_tool_use_hook(&mut self, tool_name: &str, input: &str) -> HookRunResult { + if let Some(reporter) = self.hook_progress_reporter.as_mut() { + self.hook_runner.run_pre_tool_use_with_context( + tool_name, + input, + Some(&self.hook_abort_signal), + Some(reporter.as_mut()), + ) + } else { + self.hook_runner.run_pre_tool_use_with_context( + tool_name, + input, + Some(&self.hook_abort_signal), + None, + ) + } + } + + fn run_post_tool_use_hook( + &mut self, + tool_name: &str, + input: &str, + output: &str, + is_error: bool, + ) -> HookRunResult { + if let Some(reporter) = self.hook_progress_reporter.as_mut() { + self.hook_runner.run_post_tool_use_with_context( + tool_name, + input, + output, + is_error, + Some(&self.hook_abort_signal), + Some(reporter.as_mut()), + ) + } else { + self.hook_runner.run_post_tool_use_with_context( + tool_name, + input, + output, + is_error, + Some(&self.hook_abort_signal), + None, + ) + } + } + + fn run_post_tool_use_failure_hook( + &mut self, + tool_name: &str, + input: &str, + output: &str, + ) -> HookRunResult { + if let Some(reporter) = self.hook_progress_reporter.as_mut() { + self.hook_runner.run_post_tool_use_failure_with_context( + tool_name, + input, + output, + Some(&self.hook_abort_signal), + Some(reporter.as_mut()), + ) + } else { + self.hook_runner.run_post_tool_use_failure_with_context( + tool_name, + input, + output, + Some(&self.hook_abort_signal), + None, + ) + } + } + #[must_use] pub fn compact(&self, config: CompactionConfig) -> CompactionResult { compact_session(&self.session, config) @@ -669,7 +732,7 @@ mod tests { }), PermissionPolicy::new(PermissionMode::DangerFullAccess), vec!["system".to_string()], - &RuntimeFeatureConfig::default().with_hooks(RuntimeHookConfig::new( + RuntimeFeatureConfig::default().with_hooks(RuntimeHookConfig::new( vec![shell_snippet("printf 'blocked by hook'; exit 2")], Vec::new(), Vec::new(), @@ -736,7 +799,7 @@ mod tests { StaticToolExecutor::new().register("add", |_input| Ok("4".to_string())), PermissionPolicy::new(PermissionMode::DangerFullAccess), vec!["system".to_string()], - &RuntimeFeatureConfig::default().with_hooks(RuntimeHookConfig::new( + RuntimeFeatureConfig::default().with_hooks(RuntimeHookConfig::new( vec![shell_snippet("printf 'pre hook ran'")], vec![shell_snippet("printf 'post hook ran'")], Vec::new(), diff --git a/rust/crates/runtime/src/hooks.rs b/rust/crates/runtime/src/hooks.rs index 3af6cbd..4ef5a6c 100644 --- a/rust/crates/runtime/src/hooks.rs +++ b/rust/crates/runtime/src/hooks.rs @@ -1,16 +1,14 @@ use std::ffi::OsStr; +use std::io::Write; use std::process::{Command, Stdio}; use std::sync::{ atomic::{AtomicBool, Ordering}, Arc, }; +use std::thread; use std::time::Duration; use serde_json::{json, Value}; -use tokio::io::AsyncWriteExt; -use tokio::process::Command as TokioCommand; -use tokio::runtime::Builder; -use tokio::time::sleep; use crate::config::{RuntimeFeatureConfig, RuntimeHookConfig}; use crate::permissions::PermissionOverride; @@ -172,7 +170,7 @@ impl HookRunner { abort_signal: Option<&HookAbortSignal>, reporter: Option<&mut dyn HookProgressReporter>, ) -> HookRunResult { - self.run_commands( + Self::run_commands( HookEvent::PreToolUse, self.config.pre_tool_use(), tool_name, @@ -222,7 +220,7 @@ impl HookRunner { abort_signal: Option<&HookAbortSignal>, reporter: Option<&mut dyn HookProgressReporter>, ) -> HookRunResult { - self.run_commands( + Self::run_commands( HookEvent::PostToolUse, self.config.post_tool_use(), tool_name, @@ -272,7 +270,7 @@ impl HookRunner { abort_signal: Option<&HookAbortSignal>, reporter: Option<&mut dyn HookProgressReporter>, ) -> HookRunResult { - self.run_commands( + Self::run_commands( HookEvent::PostToolUseFailure, self.config.post_tool_use_failure(), tool_name, @@ -303,7 +301,6 @@ impl HookRunner { #[allow(clippy::too_many_arguments)] fn run_commands( - &self, event: HookEvent, commands: &[String], tool_name: &str, @@ -675,36 +672,23 @@ impl CommandWithStdin { stdin: &[u8], abort_signal: Option<&HookAbortSignal>, ) -> std::io::Result { - let runtime = Builder::new_current_thread().enable_all().build()?; - let mut command = - TokioCommand::from(std::mem::replace(&mut self.command, Command::new("true"))); - let stdin = stdin.to_vec(); - let abort_signal = abort_signal.cloned(); - runtime.block_on(async move { - let mut child = command.spawn()?; - if let Some(mut child_stdin) = child.stdin.take() { - child_stdin.write_all(&stdin).await?; + let mut child = self.command.spawn()?; + if let Some(mut child_stdin) = child.stdin.take() { + child_stdin.write_all(stdin)?; + } + + loop { + if abort_signal.is_some_and(HookAbortSignal::is_aborted) { + let _ = child.kill(); + let _ = child.wait_with_output(); + return Ok(CommandExecution::Cancelled); } - loop { - if abort_signal - .as_ref() - .is_some_and(HookAbortSignal::is_aborted) - { - let _ = child.start_kill(); - let _ = child.wait().await; - return Ok(CommandExecution::Cancelled); - } - - if let Some(status) = child.try_wait()? { - let output = child.wait_with_output().await?; - debug_assert_eq!(output.status.code(), status.code()); - return Ok(CommandExecution::Finished(output)); - } - - sleep(Duration::from_millis(20)).await; + match child.try_wait()? { + Some(_) => return child.wait_with_output().map(CommandExecution::Finished), + None => thread::sleep(Duration::from_millis(20)), } - }) + } } } diff --git a/rust/crates/rusty-claude-cli/src/main.rs b/rust/crates/rusty-claude-cli/src/main.rs index ff241eb..935bade 100644 --- a/rust/crates/rusty-claude-cli/src/main.rs +++ b/rust/crates/rusty-claude-cli/src/main.rs @@ -1923,14 +1923,15 @@ fn build_runtime( ) -> Result, Box> { let feature_config = build_runtime_feature_config()?; - Ok(ConversationRuntime::new_with_features( + let runtime = ConversationRuntime::new_with_features( session, AnthropicRuntimeClient::new(model, enable_tools, emit_output, allowed_tools.clone())?, CliToolExecutor::new(allowed_tools, emit_output), permission_policy(permission_mode, &feature_config), system_prompt, - &feature_config, - )) + feature_config, + ); + Ok(runtime) } struct CliPermissionPrompter { @@ -1953,6 +1954,9 @@ impl runtime::PermissionPrompter for CliPermissionPrompter { println!(" Tool {}", request.tool_name); println!(" Current mode {}", self.current_mode.as_str()); println!(" Required mode {}", request.required_mode.as_str()); + if let Some(reason) = &request.reason { + println!(" Reason {reason}"); + } println!(" Input {}", request.input); print!("Approve this tool call? [y/N]: "); let _ = io::stdout().flush(); @@ -2365,13 +2369,15 @@ fn format_bash_result(icon: &str, parsed: &serde_json::Value) -> String { .get("backgroundTaskId") .and_then(|value| value.as_str()) { - lines[0].push_str(&format!(" backgrounded ({task_id})")); + use std::fmt::Write as _; + let _ = write!(lines[0], " backgrounded ({task_id})"); } else if let Some(status) = parsed .get("returnCodeInterpretation") .and_then(|value| value.as_str()) .filter(|status| !status.is_empty()) { - lines[0].push_str(&format!(" {status}")); + use std::fmt::Write as _; + let _ = write!(lines[0], " {status}"); } if let Some(stdout) = parsed.get("stdout").and_then(|value| value.as_str()) { @@ -2393,15 +2399,15 @@ fn format_read_result(icon: &str, parsed: &serde_json::Value) -> String { let path = extract_tool_path(file); let start_line = file .get("startLine") - .and_then(|value| value.as_u64()) + .and_then(serde_json::Value::as_u64) .unwrap_or(1); let num_lines = file .get("numLines") - .and_then(|value| value.as_u64()) + .and_then(serde_json::Value::as_u64) .unwrap_or(0); let total_lines = file .get("totalLines") - .and_then(|value| value.as_u64()) + .and_then(serde_json::Value::as_u64) .unwrap_or(num_lines); let content = file .get("content") @@ -2427,8 +2433,7 @@ fn format_write_result(icon: &str, parsed: &serde_json::Value) -> String { let line_count = parsed .get("content") .and_then(|value| value.as_str()) - .map(|content| content.lines().count()) - .unwrap_or(0); + .map_or(0, |content| content.lines().count()); format!( "{icon} \x1b[1;32m✏️ {} {path}\x1b[0m \x1b[2m({line_count} lines)\x1b[0m", if kind == "create" { "Wrote" } else { "Updated" }, @@ -2459,7 +2464,7 @@ fn format_edit_result(icon: &str, parsed: &serde_json::Value) -> String { let path = extract_tool_path(parsed); let suffix = if parsed .get("replaceAll") - .and_then(|value| value.as_bool()) + .and_then(serde_json::Value::as_bool) .unwrap_or(false) { " (replace all)" @@ -2487,7 +2492,7 @@ fn format_edit_result(icon: &str, parsed: &serde_json::Value) -> String { fn format_glob_result(icon: &str, parsed: &serde_json::Value) -> String { let num_files = parsed .get("numFiles") - .and_then(|value| value.as_u64()) + .and_then(serde_json::Value::as_u64) .unwrap_or(0); let filenames = parsed .get("filenames") @@ -2511,11 +2516,11 @@ fn format_glob_result(icon: &str, parsed: &serde_json::Value) -> String { fn format_grep_result(icon: &str, parsed: &serde_json::Value) -> String { let num_matches = parsed .get("numMatches") - .and_then(|value| value.as_u64()) + .and_then(serde_json::Value::as_u64) .unwrap_or(0); let num_files = parsed .get("numFiles") - .and_then(|value| value.as_u64()) + .and_then(serde_json::Value::as_u64) .unwrap_or(0); let content = parsed .get("content") diff --git a/rust/crates/rusty-claude-cli/src/render.rs b/rust/crates/rusty-claude-cli/src/render.rs index 465c5a4..d8d8796 100644 --- a/rust/crates/rusty-claude-cli/src/render.rs +++ b/rust/crates/rusty-claude-cli/src/render.rs @@ -286,7 +286,7 @@ impl TerminalRenderer { ) { match event { Event::Start(Tag::Heading { level, .. }) => { - self.start_heading(state, level as u8, output) + Self::start_heading(state, level as u8, output); } Event::End(TagEnd::Paragraph) => output.push_str("\n\n"), Event::Start(Tag::BlockQuote(..)) => self.start_quote(state, output), @@ -426,7 +426,7 @@ impl TerminalRenderer { } } - fn start_heading(&self, state: &mut RenderState, level: u8, output: &mut String) { + fn start_heading(state: &mut RenderState, level: u8, output: &mut String) { state.heading_level = Some(level); if !output.is_empty() { output.push('\n');