wip: hook-pipeline progress

This commit is contained in:
Yeachan-Heo
2026-04-01 04:40:18 +00:00
parent eb89fc95e7
commit 9efd029e26
4 changed files with 124 additions and 72 deletions

View File

@@ -5,7 +5,7 @@ use crate::compact::{
compact_session, estimate_session_tokens, CompactionConfig, CompactionResult, compact_session, estimate_session_tokens, CompactionConfig, CompactionResult,
}; };
use crate::config::RuntimeFeatureConfig; use crate::config::RuntimeFeatureConfig;
use crate::hooks::{HookAbortSignal, HookProgressReporter, HookRunResult, HookRunner}; use crate::hooks::{HookAbortSignal, HookRunResult, HookRunner};
use crate::permissions::{ use crate::permissions::{
PermissionContext, PermissionOutcome, PermissionPolicy, PermissionPrompter, PermissionContext, PermissionOutcome, PermissionPolicy, PermissionPrompter,
}; };
@@ -100,7 +100,6 @@ pub struct ConversationRuntime<C, T> {
usage_tracker: UsageTracker, usage_tracker: UsageTracker,
hook_runner: HookRunner, hook_runner: HookRunner,
hook_abort_signal: HookAbortSignal, hook_abort_signal: HookAbortSignal,
hook_progress_reporter: Option<Box<dyn HookProgressReporter>>,
} }
impl<C, T> ConversationRuntime<C, T> impl<C, T> ConversationRuntime<C, T>
@@ -122,18 +121,19 @@ where
tool_executor, tool_executor,
permission_policy, permission_policy,
system_prompt, system_prompt,
&RuntimeFeatureConfig::default(), RuntimeFeatureConfig::default(),
) )
} }
#[must_use] #[must_use]
#[allow(clippy::needless_pass_by_value)]
pub fn new_with_features( pub fn new_with_features(
session: Session, session: Session,
api_client: C, api_client: C,
tool_executor: T, tool_executor: T,
permission_policy: PermissionPolicy, permission_policy: PermissionPolicy,
system_prompt: Vec<String>, system_prompt: Vec<String>,
feature_config: &RuntimeFeatureConfig, feature_config: RuntimeFeatureConfig,
) -> Self { ) -> Self {
let usage_tracker = UsageTracker::from_session(&session); let usage_tracker = UsageTracker::from_session(&session);
Self { Self {
@@ -144,8 +144,9 @@ where
system_prompt, system_prompt,
max_iterations: usize::MAX, max_iterations: usize::MAX,
usage_tracker, usage_tracker,
hook_runner: HookRunner::from_feature_config(feature_config), hook_runner: HookRunner::from_feature_config(&feature_config),
hook_abort_signal: HookAbortSignal::default(), hook_abort_signal: HookAbortSignal::default(),
hook_progress_reporter: None,
} }
} }
@@ -220,17 +221,12 @@ where
} }
for (tool_use_id, tool_name, input) in pending_tool_uses { for (tool_use_id, tool_name, input) in pending_tool_uses {
let pre_hook_result = self.hook_runner.run_pre_tool_use_with_context( let pre_hook_result = self.run_pre_tool_use_hook(&tool_name, &input);
&tool_name,
&input,
Some(&self.hook_abort_signal),
self.hook_progress_reporter.as_deref_mut(),
);
let effective_input = pre_hook_result let effective_input = pre_hook_result
.updated_input_json() .updated_input()
.map_or_else(|| input.clone(), ToOwned::to_owned); .map_or_else(|| input.clone(), ToOwned::to_owned);
let permission_context = PermissionContext::new( let permission_context = PermissionContext::new(
pre_hook_result.permission_decision(), pre_hook_result.permission_override(),
pre_hook_result.permission_reason().map(ToOwned::to_owned), pre_hook_result.permission_reason().map(ToOwned::to_owned),
); );
@@ -274,21 +270,17 @@ where
output = merge_hook_feedback(pre_hook_result.messages(), output, false); output = merge_hook_feedback(pre_hook_result.messages(), output, false);
let post_hook_result = if is_error { let post_hook_result = if is_error {
self.hook_runner.run_post_tool_use_failure_with_context( self.run_post_tool_use_failure_hook(
&tool_name, &tool_name,
&effective_input, &effective_input,
&output, &output,
Some(&self.hook_abort_signal),
self.hook_progress_reporter.as_deref_mut(),
) )
} else { } else {
self.hook_runner.run_post_tool_use_with_context( self.run_post_tool_use_hook(
&tool_name, &tool_name,
&effective_input, &effective_input,
&output, &output,
false, false,
Some(&self.hook_abort_signal),
self.hook_progress_reporter.as_deref_mut(),
) )
}; };
if post_hook_result.is_denied() || post_hook_result.is_cancelled() { if post_hook_result.is_denied() || post_hook_result.is_cancelled() {
@@ -322,6 +314,77 @@ where
}) })
} }
fn run_pre_tool_use_hook(&mut self, tool_name: &str, input: &str) -> HookRunResult {
if let Some(reporter) = self.hook_progress_reporter.as_mut() {
self.hook_runner.run_pre_tool_use_with_context(
tool_name,
input,
Some(&self.hook_abort_signal),
Some(reporter.as_mut()),
)
} else {
self.hook_runner.run_pre_tool_use_with_context(
tool_name,
input,
Some(&self.hook_abort_signal),
None,
)
}
}
fn run_post_tool_use_hook(
&mut self,
tool_name: &str,
input: &str,
output: &str,
is_error: bool,
) -> HookRunResult {
if let Some(reporter) = self.hook_progress_reporter.as_mut() {
self.hook_runner.run_post_tool_use_with_context(
tool_name,
input,
output,
is_error,
Some(&self.hook_abort_signal),
Some(reporter.as_mut()),
)
} else {
self.hook_runner.run_post_tool_use_with_context(
tool_name,
input,
output,
is_error,
Some(&self.hook_abort_signal),
None,
)
}
}
fn run_post_tool_use_failure_hook(
&mut self,
tool_name: &str,
input: &str,
output: &str,
) -> HookRunResult {
if let Some(reporter) = self.hook_progress_reporter.as_mut() {
self.hook_runner.run_post_tool_use_failure_with_context(
tool_name,
input,
output,
Some(&self.hook_abort_signal),
Some(reporter.as_mut()),
)
} else {
self.hook_runner.run_post_tool_use_failure_with_context(
tool_name,
input,
output,
Some(&self.hook_abort_signal),
None,
)
}
}
#[must_use] #[must_use]
pub fn compact(&self, config: CompactionConfig) -> CompactionResult { pub fn compact(&self, config: CompactionConfig) -> CompactionResult {
compact_session(&self.session, config) compact_session(&self.session, config)
@@ -669,7 +732,7 @@ mod tests {
}), }),
PermissionPolicy::new(PermissionMode::DangerFullAccess), PermissionPolicy::new(PermissionMode::DangerFullAccess),
vec!["system".to_string()], vec!["system".to_string()],
&RuntimeFeatureConfig::default().with_hooks(RuntimeHookConfig::new( RuntimeFeatureConfig::default().with_hooks(RuntimeHookConfig::new(
vec![shell_snippet("printf 'blocked by hook'; exit 2")], vec![shell_snippet("printf 'blocked by hook'; exit 2")],
Vec::new(), Vec::new(),
Vec::new(), Vec::new(),
@@ -736,7 +799,7 @@ mod tests {
StaticToolExecutor::new().register("add", |_input| Ok("4".to_string())), StaticToolExecutor::new().register("add", |_input| Ok("4".to_string())),
PermissionPolicy::new(PermissionMode::DangerFullAccess), PermissionPolicy::new(PermissionMode::DangerFullAccess),
vec!["system".to_string()], vec!["system".to_string()],
&RuntimeFeatureConfig::default().with_hooks(RuntimeHookConfig::new( RuntimeFeatureConfig::default().with_hooks(RuntimeHookConfig::new(
vec![shell_snippet("printf 'pre hook ran'")], vec![shell_snippet("printf 'pre hook ran'")],
vec![shell_snippet("printf 'post hook ran'")], vec![shell_snippet("printf 'post hook ran'")],
Vec::new(), Vec::new(),

View File

@@ -1,16 +1,14 @@
use std::ffi::OsStr; use std::ffi::OsStr;
use std::io::Write;
use std::process::{Command, Stdio}; use std::process::{Command, Stdio};
use std::sync::{ use std::sync::{
atomic::{AtomicBool, Ordering}, atomic::{AtomicBool, Ordering},
Arc, Arc,
}; };
use std::thread;
use std::time::Duration; use std::time::Duration;
use serde_json::{json, Value}; use serde_json::{json, Value};
use tokio::io::AsyncWriteExt;
use tokio::process::Command as TokioCommand;
use tokio::runtime::Builder;
use tokio::time::sleep;
use crate::config::{RuntimeFeatureConfig, RuntimeHookConfig}; use crate::config::{RuntimeFeatureConfig, RuntimeHookConfig};
use crate::permissions::PermissionOverride; use crate::permissions::PermissionOverride;
@@ -172,7 +170,7 @@ impl HookRunner {
abort_signal: Option<&HookAbortSignal>, abort_signal: Option<&HookAbortSignal>,
reporter: Option<&mut dyn HookProgressReporter>, reporter: Option<&mut dyn HookProgressReporter>,
) -> HookRunResult { ) -> HookRunResult {
self.run_commands( Self::run_commands(
HookEvent::PreToolUse, HookEvent::PreToolUse,
self.config.pre_tool_use(), self.config.pre_tool_use(),
tool_name, tool_name,
@@ -222,7 +220,7 @@ impl HookRunner {
abort_signal: Option<&HookAbortSignal>, abort_signal: Option<&HookAbortSignal>,
reporter: Option<&mut dyn HookProgressReporter>, reporter: Option<&mut dyn HookProgressReporter>,
) -> HookRunResult { ) -> HookRunResult {
self.run_commands( Self::run_commands(
HookEvent::PostToolUse, HookEvent::PostToolUse,
self.config.post_tool_use(), self.config.post_tool_use(),
tool_name, tool_name,
@@ -272,7 +270,7 @@ impl HookRunner {
abort_signal: Option<&HookAbortSignal>, abort_signal: Option<&HookAbortSignal>,
reporter: Option<&mut dyn HookProgressReporter>, reporter: Option<&mut dyn HookProgressReporter>,
) -> HookRunResult { ) -> HookRunResult {
self.run_commands( Self::run_commands(
HookEvent::PostToolUseFailure, HookEvent::PostToolUseFailure,
self.config.post_tool_use_failure(), self.config.post_tool_use_failure(),
tool_name, tool_name,
@@ -303,7 +301,6 @@ impl HookRunner {
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
fn run_commands( fn run_commands(
&self,
event: HookEvent, event: HookEvent,
commands: &[String], commands: &[String],
tool_name: &str, tool_name: &str,
@@ -675,36 +672,23 @@ impl CommandWithStdin {
stdin: &[u8], stdin: &[u8],
abort_signal: Option<&HookAbortSignal>, abort_signal: Option<&HookAbortSignal>,
) -> std::io::Result<CommandExecution> { ) -> std::io::Result<CommandExecution> {
let runtime = Builder::new_current_thread().enable_all().build()?; let mut child = self.command.spawn()?;
let mut command = if let Some(mut child_stdin) = child.stdin.take() {
TokioCommand::from(std::mem::replace(&mut self.command, Command::new("true"))); child_stdin.write_all(stdin)?;
let stdin = stdin.to_vec(); }
let abort_signal = abort_signal.cloned();
runtime.block_on(async move { loop {
let mut child = command.spawn()?; if abort_signal.is_some_and(HookAbortSignal::is_aborted) {
if let Some(mut child_stdin) = child.stdin.take() { let _ = child.kill();
child_stdin.write_all(&stdin).await?; let _ = child.wait_with_output();
return Ok(CommandExecution::Cancelled);
} }
loop { match child.try_wait()? {
if abort_signal Some(_) => return child.wait_with_output().map(CommandExecution::Finished),
.as_ref() None => thread::sleep(Duration::from_millis(20)),
.is_some_and(HookAbortSignal::is_aborted)
{
let _ = child.start_kill();
let _ = child.wait().await;
return Ok(CommandExecution::Cancelled);
}
if let Some(status) = child.try_wait()? {
let output = child.wait_with_output().await?;
debug_assert_eq!(output.status.code(), status.code());
return Ok(CommandExecution::Finished(output));
}
sleep(Duration::from_millis(20)).await;
} }
}) }
} }
} }

View File

@@ -1923,14 +1923,15 @@ fn build_runtime(
) -> Result<ConversationRuntime<AnthropicRuntimeClient, CliToolExecutor>, Box<dyn std::error::Error>> ) -> Result<ConversationRuntime<AnthropicRuntimeClient, CliToolExecutor>, Box<dyn std::error::Error>>
{ {
let feature_config = build_runtime_feature_config()?; let feature_config = build_runtime_feature_config()?;
Ok(ConversationRuntime::new_with_features( let runtime = ConversationRuntime::new_with_features(
session, session,
AnthropicRuntimeClient::new(model, enable_tools, emit_output, allowed_tools.clone())?, AnthropicRuntimeClient::new(model, enable_tools, emit_output, allowed_tools.clone())?,
CliToolExecutor::new(allowed_tools, emit_output), CliToolExecutor::new(allowed_tools, emit_output),
permission_policy(permission_mode, &feature_config), permission_policy(permission_mode, &feature_config),
system_prompt, system_prompt,
&feature_config, feature_config,
)) );
Ok(runtime)
} }
struct CliPermissionPrompter { struct CliPermissionPrompter {
@@ -1953,6 +1954,9 @@ impl runtime::PermissionPrompter for CliPermissionPrompter {
println!(" Tool {}", request.tool_name); println!(" Tool {}", request.tool_name);
println!(" Current mode {}", self.current_mode.as_str()); println!(" Current mode {}", self.current_mode.as_str());
println!(" Required mode {}", request.required_mode.as_str()); println!(" Required mode {}", request.required_mode.as_str());
if let Some(reason) = &request.reason {
println!(" Reason {reason}");
}
println!(" Input {}", request.input); println!(" Input {}", request.input);
print!("Approve this tool call? [y/N]: "); print!("Approve this tool call? [y/N]: ");
let _ = io::stdout().flush(); let _ = io::stdout().flush();
@@ -2365,13 +2369,15 @@ fn format_bash_result(icon: &str, parsed: &serde_json::Value) -> String {
.get("backgroundTaskId") .get("backgroundTaskId")
.and_then(|value| value.as_str()) .and_then(|value| value.as_str())
{ {
lines[0].push_str(&format!(" backgrounded ({task_id})")); use std::fmt::Write as _;
let _ = write!(lines[0], " backgrounded ({task_id})");
} else if let Some(status) = parsed } else if let Some(status) = parsed
.get("returnCodeInterpretation") .get("returnCodeInterpretation")
.and_then(|value| value.as_str()) .and_then(|value| value.as_str())
.filter(|status| !status.is_empty()) .filter(|status| !status.is_empty())
{ {
lines[0].push_str(&format!(" {status}")); use std::fmt::Write as _;
let _ = write!(lines[0], " {status}");
} }
if let Some(stdout) = parsed.get("stdout").and_then(|value| value.as_str()) { if let Some(stdout) = parsed.get("stdout").and_then(|value| value.as_str()) {
@@ -2393,15 +2399,15 @@ fn format_read_result(icon: &str, parsed: &serde_json::Value) -> String {
let path = extract_tool_path(file); let path = extract_tool_path(file);
let start_line = file let start_line = file
.get("startLine") .get("startLine")
.and_then(|value| value.as_u64()) .and_then(serde_json::Value::as_u64)
.unwrap_or(1); .unwrap_or(1);
let num_lines = file let num_lines = file
.get("numLines") .get("numLines")
.and_then(|value| value.as_u64()) .and_then(serde_json::Value::as_u64)
.unwrap_or(0); .unwrap_or(0);
let total_lines = file let total_lines = file
.get("totalLines") .get("totalLines")
.and_then(|value| value.as_u64()) .and_then(serde_json::Value::as_u64)
.unwrap_or(num_lines); .unwrap_or(num_lines);
let content = file let content = file
.get("content") .get("content")
@@ -2427,8 +2433,7 @@ fn format_write_result(icon: &str, parsed: &serde_json::Value) -> String {
let line_count = parsed let line_count = parsed
.get("content") .get("content")
.and_then(|value| value.as_str()) .and_then(|value| value.as_str())
.map(|content| content.lines().count()) .map_or(0, |content| content.lines().count());
.unwrap_or(0);
format!( format!(
"{icon} \x1b[1;32m✏ {} {path}\x1b[0m \x1b[2m({line_count} lines)\x1b[0m", "{icon} \x1b[1;32m✏ {} {path}\x1b[0m \x1b[2m({line_count} lines)\x1b[0m",
if kind == "create" { "Wrote" } else { "Updated" }, if kind == "create" { "Wrote" } else { "Updated" },
@@ -2459,7 +2464,7 @@ fn format_edit_result(icon: &str, parsed: &serde_json::Value) -> String {
let path = extract_tool_path(parsed); let path = extract_tool_path(parsed);
let suffix = if parsed let suffix = if parsed
.get("replaceAll") .get("replaceAll")
.and_then(|value| value.as_bool()) .and_then(serde_json::Value::as_bool)
.unwrap_or(false) .unwrap_or(false)
{ {
" (replace all)" " (replace all)"
@@ -2487,7 +2492,7 @@ fn format_edit_result(icon: &str, parsed: &serde_json::Value) -> String {
fn format_glob_result(icon: &str, parsed: &serde_json::Value) -> String { fn format_glob_result(icon: &str, parsed: &serde_json::Value) -> String {
let num_files = parsed let num_files = parsed
.get("numFiles") .get("numFiles")
.and_then(|value| value.as_u64()) .and_then(serde_json::Value::as_u64)
.unwrap_or(0); .unwrap_or(0);
let filenames = parsed let filenames = parsed
.get("filenames") .get("filenames")
@@ -2511,11 +2516,11 @@ fn format_glob_result(icon: &str, parsed: &serde_json::Value) -> String {
fn format_grep_result(icon: &str, parsed: &serde_json::Value) -> String { fn format_grep_result(icon: &str, parsed: &serde_json::Value) -> String {
let num_matches = parsed let num_matches = parsed
.get("numMatches") .get("numMatches")
.and_then(|value| value.as_u64()) .and_then(serde_json::Value::as_u64)
.unwrap_or(0); .unwrap_or(0);
let num_files = parsed let num_files = parsed
.get("numFiles") .get("numFiles")
.and_then(|value| value.as_u64()) .and_then(serde_json::Value::as_u64)
.unwrap_or(0); .unwrap_or(0);
let content = parsed let content = parsed
.get("content") .get("content")

View File

@@ -286,7 +286,7 @@ impl TerminalRenderer {
) { ) {
match event { match event {
Event::Start(Tag::Heading { level, .. }) => { Event::Start(Tag::Heading { level, .. }) => {
self.start_heading(state, level as u8, output) Self::start_heading(state, level as u8, output);
} }
Event::End(TagEnd::Paragraph) => output.push_str("\n\n"), Event::End(TagEnd::Paragraph) => output.push_str("\n\n"),
Event::Start(Tag::BlockQuote(..)) => self.start_quote(state, output), Event::Start(Tag::BlockQuote(..)) => self.start_quote(state, output),
@@ -426,7 +426,7 @@ impl TerminalRenderer {
} }
} }
fn start_heading(&self, state: &mut RenderState, level: u8, output: &mut String) { fn start_heading(state: &mut RenderState, level: u8, output: &mut String) {
state.heading_level = Some(level); state.heading_level = Some(level);
if !output.is_empty() { if !output.is_empty() {
output.push('\n'); output.push('\n');