use std::ffi::OsStr; use std::process::{Command, Stdio}; use std::sync::{ atomic::{AtomicBool, Ordering}, Arc, }; use std::time::Duration; use serde_json::{json, Value}; use tokio::io::AsyncWriteExt; use tokio::process::Command as TokioCommand; use tokio::runtime::Builder; use tokio::time::sleep; use crate::config::{RuntimeFeatureConfig, RuntimeHookConfig}; use crate::permissions::PermissionOverride; pub type HookPermissionDecision = PermissionOverride; #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum HookEvent { PreToolUse, PostToolUse, PostToolUseFailure, } impl HookEvent { #[must_use] pub fn as_str(self) -> &'static str { match self { Self::PreToolUse => "PreToolUse", Self::PostToolUse => "PostToolUse", Self::PostToolUseFailure => "PostToolUseFailure", } } } #[derive(Debug, Clone, PartialEq, Eq)] pub enum HookProgressEvent { Started { event: HookEvent, tool_name: String, command: String, }, Completed { event: HookEvent, tool_name: String, command: String, }, Cancelled { event: HookEvent, tool_name: String, command: String, }, } pub trait HookProgressReporter { fn on_event(&mut self, event: &HookProgressEvent); } #[derive(Debug, Clone, Default)] pub struct HookAbortSignal { aborted: Arc, } impl HookAbortSignal { #[must_use] pub fn new() -> Self { Self::default() } pub fn abort(&self) { self.aborted.store(true, Ordering::SeqCst); } #[must_use] pub fn is_aborted(&self) -> bool { self.aborted.load(Ordering::SeqCst) } } #[derive(Debug, Clone, PartialEq, Eq)] pub struct HookRunResult { denied: bool, cancelled: bool, messages: Vec, permission_override: Option, permission_reason: Option, updated_input: Option, } impl HookRunResult { #[must_use] pub fn allow(messages: Vec) -> Self { Self { denied: false, cancelled: false, messages, permission_override: None, permission_reason: None, updated_input: None, } } #[must_use] pub fn is_denied(&self) -> bool { self.denied } #[must_use] pub fn is_cancelled(&self) -> bool { self.cancelled } #[must_use] pub fn messages(&self) -> &[String] { &self.messages } #[must_use] pub fn permission_override(&self) -> Option { self.permission_override } #[must_use] pub fn permission_decision(&self) -> Option { self.permission_override } #[must_use] pub fn permission_reason(&self) -> Option<&str> { self.permission_reason.as_deref() } #[must_use] pub fn updated_input(&self) -> Option<&str> { self.updated_input.as_deref() } #[must_use] pub fn updated_input_json(&self) -> Option<&str> { self.updated_input() } } #[derive(Debug, Clone, PartialEq, Eq, Default)] pub struct HookRunner { config: RuntimeHookConfig, } impl HookRunner { #[must_use] pub fn new(config: RuntimeHookConfig) -> Self { Self { config } } #[must_use] pub fn from_feature_config(feature_config: &RuntimeFeatureConfig) -> Self { Self::new(feature_config.hooks().clone()) } #[must_use] pub fn run_pre_tool_use(&self, tool_name: &str, tool_input: &str) -> HookRunResult { self.run_pre_tool_use_with_context(tool_name, tool_input, None, None) } #[must_use] pub fn run_pre_tool_use_with_context( &self, tool_name: &str, tool_input: &str, abort_signal: Option<&HookAbortSignal>, reporter: Option<&mut dyn HookProgressReporter>, ) -> HookRunResult { self.run_commands( HookEvent::PreToolUse, self.config.pre_tool_use(), tool_name, tool_input, None, false, abort_signal, reporter, ) } #[must_use] pub fn run_pre_tool_use_with_signal( &self, tool_name: &str, tool_input: &str, abort_signal: Option<&HookAbortSignal>, ) -> HookRunResult { self.run_pre_tool_use_with_context(tool_name, tool_input, abort_signal, None) } #[must_use] pub fn run_post_tool_use( &self, tool_name: &str, tool_input: &str, tool_output: &str, is_error: bool, ) -> HookRunResult { self.run_post_tool_use_with_context( tool_name, tool_input, tool_output, is_error, None, None, ) } #[must_use] pub fn run_post_tool_use_with_context( &self, tool_name: &str, tool_input: &str, tool_output: &str, is_error: bool, abort_signal: Option<&HookAbortSignal>, reporter: Option<&mut dyn HookProgressReporter>, ) -> HookRunResult { self.run_commands( HookEvent::PostToolUse, self.config.post_tool_use(), tool_name, tool_input, Some(tool_output), is_error, abort_signal, reporter, ) } #[must_use] pub fn run_post_tool_use_with_signal( &self, tool_name: &str, tool_input: &str, tool_output: &str, is_error: bool, abort_signal: Option<&HookAbortSignal>, ) -> HookRunResult { self.run_post_tool_use_with_context( tool_name, tool_input, tool_output, is_error, abort_signal, None, ) } #[must_use] pub fn run_post_tool_use_failure( &self, tool_name: &str, tool_input: &str, tool_error: &str, ) -> HookRunResult { self.run_post_tool_use_failure_with_context(tool_name, tool_input, tool_error, None, None) } #[must_use] pub fn run_post_tool_use_failure_with_context( &self, tool_name: &str, tool_input: &str, tool_error: &str, abort_signal: Option<&HookAbortSignal>, reporter: Option<&mut dyn HookProgressReporter>, ) -> HookRunResult { self.run_commands( HookEvent::PostToolUseFailure, self.config.post_tool_use_failure(), tool_name, tool_input, Some(tool_error), true, abort_signal, reporter, ) } #[must_use] pub fn run_post_tool_use_failure_with_signal( &self, tool_name: &str, tool_input: &str, tool_error: &str, abort_signal: Option<&HookAbortSignal>, ) -> HookRunResult { self.run_post_tool_use_failure_with_context( tool_name, tool_input, tool_error, abort_signal, None, ) } #[allow(clippy::too_many_arguments)] fn run_commands( &self, event: HookEvent, commands: &[String], tool_name: &str, tool_input: &str, tool_output: Option<&str>, is_error: bool, abort_signal: Option<&HookAbortSignal>, mut reporter: Option<&mut dyn HookProgressReporter>, ) -> HookRunResult { if commands.is_empty() { return HookRunResult::allow(Vec::new()); } if abort_signal.is_some_and(HookAbortSignal::is_aborted) { return HookRunResult { denied: false, cancelled: true, messages: vec![format!( "{} hook cancelled before execution", event.as_str() )], permission_override: None, permission_reason: None, updated_input: None, }; } let payload = hook_payload(event, tool_name, tool_input, tool_output, is_error).to_string(); let mut result = HookRunResult::allow(Vec::new()); for command in commands { if let Some(reporter) = reporter.as_deref_mut() { reporter.on_event(&HookProgressEvent::Started { event, tool_name: tool_name.to_string(), command: command.clone(), }); } match Self::run_command( command, event, tool_name, tool_input, tool_output, is_error, &payload, abort_signal, ) { HookCommandOutcome::Allow { parsed } => { if let Some(reporter) = reporter.as_deref_mut() { reporter.on_event(&HookProgressEvent::Completed { event, tool_name: tool_name.to_string(), command: command.clone(), }); } merge_parsed_hook_output(&mut result, parsed); } HookCommandOutcome::Deny { parsed } => { if let Some(reporter) = reporter.as_deref_mut() { reporter.on_event(&HookProgressEvent::Completed { event, tool_name: tool_name.to_string(), command: command.clone(), }); } merge_parsed_hook_output(&mut result, parsed); result.denied = true; return result; } HookCommandOutcome::Warn { message } => { if let Some(reporter) = reporter.as_deref_mut() { reporter.on_event(&HookProgressEvent::Completed { event, tool_name: tool_name.to_string(), command: command.clone(), }); } result.messages.push(message); } HookCommandOutcome::Cancelled { message } => { if let Some(reporter) = reporter.as_deref_mut() { reporter.on_event(&HookProgressEvent::Cancelled { event, tool_name: tool_name.to_string(), command: command.clone(), }); } result.cancelled = true; result.messages.push(message); return result; } } } result } #[allow(clippy::too_many_arguments)] fn run_command( command: &str, event: HookEvent, tool_name: &str, tool_input: &str, tool_output: Option<&str>, is_error: bool, payload: &str, abort_signal: Option<&HookAbortSignal>, ) -> HookCommandOutcome { let mut child = shell_command(command); child.stdin(Stdio::piped()); child.stdout(Stdio::piped()); child.stderr(Stdio::piped()); child.env("HOOK_EVENT", event.as_str()); child.env("HOOK_TOOL_NAME", tool_name); child.env("HOOK_TOOL_INPUT", tool_input); child.env("HOOK_TOOL_IS_ERROR", if is_error { "1" } else { "0" }); if let Some(tool_output) = tool_output { child.env("HOOK_TOOL_OUTPUT", tool_output); } match child.output_with_stdin(payload.as_bytes(), abort_signal) { Ok(CommandExecution::Finished(output)) => { let stdout = String::from_utf8_lossy(&output.stdout).trim().to_string(); let stderr = String::from_utf8_lossy(&output.stderr).trim().to_string(); let parsed = parse_hook_output(&stdout); match output.status.code() { Some(0) => { if parsed.deny { HookCommandOutcome::Deny { parsed } } else { HookCommandOutcome::Allow { parsed } } } Some(2) => HookCommandOutcome::Deny { parsed: parsed.with_fallback_message(format!( "{} hook denied tool `{tool_name}`", event.as_str() )), }, Some(code) => HookCommandOutcome::Warn { message: format_hook_warning( command, code, parsed.primary_message(), stderr.as_str(), ), }, None => HookCommandOutcome::Warn { message: format!( "{} hook `{command}` terminated by signal while handling `{tool_name}`", event.as_str() ), }, } } Ok(CommandExecution::Cancelled) => HookCommandOutcome::Cancelled { message: format!( "{} hook `{command}` cancelled while handling `{tool_name}`", event.as_str() ), }, Err(error) => HookCommandOutcome::Warn { message: format!( "{} hook `{command}` failed to start for `{tool_name}`: {error}", event.as_str() ), }, } } } enum HookCommandOutcome { Allow { parsed: ParsedHookOutput }, Deny { parsed: ParsedHookOutput }, Warn { message: String }, Cancelled { message: String }, } #[derive(Debug, Clone, PartialEq, Eq, Default)] struct ParsedHookOutput { messages: Vec, deny: bool, permission_override: Option, permission_reason: Option, updated_input: Option, } impl ParsedHookOutput { fn with_fallback_message(mut self, fallback: String) -> Self { if self.messages.is_empty() { self.messages.push(fallback); } self } fn primary_message(&self) -> Option<&str> { self.messages.first().map(String::as_str) } } fn merge_parsed_hook_output(target: &mut HookRunResult, parsed: ParsedHookOutput) { target.messages.extend(parsed.messages); if parsed.permission_override.is_some() { target.permission_override = parsed.permission_override; } if parsed.permission_reason.is_some() { target.permission_reason = parsed.permission_reason; } if parsed.updated_input.is_some() { target.updated_input = parsed.updated_input; } } fn parse_hook_output(stdout: &str) -> ParsedHookOutput { if stdout.is_empty() { return ParsedHookOutput::default(); } let Ok(Value::Object(root)) = serde_json::from_str::(stdout) else { return ParsedHookOutput { messages: vec![stdout.to_string()], ..ParsedHookOutput::default() }; }; let mut parsed = ParsedHookOutput::default(); if let Some(message) = root.get("systemMessage").and_then(Value::as_str) { parsed.messages.push(message.to_string()); } if let Some(message) = root.get("reason").and_then(Value::as_str) { parsed.messages.push(message.to_string()); } if root.get("continue").and_then(Value::as_bool) == Some(false) || root.get("decision").and_then(Value::as_str) == Some("block") { parsed.deny = true; } if let Some(Value::Object(specific)) = root.get("hookSpecificOutput") { if let Some(Value::String(additional_context)) = specific.get("additionalContext") { parsed.messages.push(additional_context.clone()); } if let Some(decision) = specific.get("permissionDecision").and_then(Value::as_str) { parsed.permission_override = match decision { "allow" => Some(PermissionOverride::Allow), "deny" => Some(PermissionOverride::Deny), "ask" => Some(PermissionOverride::Ask), _ => None, }; } if let Some(reason) = specific .get("permissionDecisionReason") .and_then(Value::as_str) { parsed.permission_reason = Some(reason.to_string()); } if let Some(updated_input) = specific.get("updatedInput") { parsed.updated_input = serde_json::to_string(updated_input).ok(); } } if parsed.messages.is_empty() { parsed.messages.push(stdout.to_string()); } parsed } fn hook_payload( event: HookEvent, tool_name: &str, tool_input: &str, tool_output: Option<&str>, is_error: bool, ) -> Value { match event { HookEvent::PostToolUseFailure => json!({ "hook_event_name": event.as_str(), "tool_name": tool_name, "tool_input": parse_tool_input(tool_input), "tool_input_json": tool_input, "tool_error": tool_output, "tool_result_is_error": true, }), _ => json!({ "hook_event_name": event.as_str(), "tool_name": tool_name, "tool_input": parse_tool_input(tool_input), "tool_input_json": tool_input, "tool_output": tool_output, "tool_result_is_error": is_error, }), } } fn parse_tool_input(tool_input: &str) -> Value { serde_json::from_str(tool_input).unwrap_or_else(|_| json!({ "raw": tool_input })) } fn format_hook_warning(command: &str, code: i32, stdout: Option<&str>, stderr: &str) -> String { let mut message = format!("Hook `{command}` exited with status {code}; allowing tool execution to continue"); if let Some(stdout) = stdout.filter(|stdout| !stdout.is_empty()) { message.push_str(": "); message.push_str(stdout); } else if !stderr.is_empty() { message.push_str(": "); message.push_str(stderr); } message } fn shell_command(command: &str) -> CommandWithStdin { #[cfg(windows)] let mut command_builder = { let mut command_builder = Command::new("cmd"); command_builder.arg("/C").arg(command); CommandWithStdin::new(command_builder) }; #[cfg(not(windows))] let command_builder = { let mut command_builder = Command::new("sh"); command_builder.arg("-lc").arg(command); CommandWithStdin::new(command_builder) }; command_builder } struct CommandWithStdin { command: Command, } impl CommandWithStdin { fn new(command: Command) -> Self { Self { command } } fn stdin(&mut self, cfg: Stdio) -> &mut Self { self.command.stdin(cfg); self } fn stdout(&mut self, cfg: Stdio) -> &mut Self { self.command.stdout(cfg); self } fn stderr(&mut self, cfg: Stdio) -> &mut Self { self.command.stderr(cfg); self } fn env(&mut self, key: K, value: V) -> &mut Self where K: AsRef, V: AsRef, { self.command.env(key, value); self } fn output_with_stdin( &mut self, stdin: &[u8], abort_signal: Option<&HookAbortSignal>, ) -> std::io::Result { let runtime = Builder::new_current_thread().enable_all().build()?; let mut command = TokioCommand::from(std::mem::replace(&mut self.command, Command::new("true"))); let stdin = stdin.to_vec(); let abort_signal = abort_signal.cloned(); runtime.block_on(async move { let mut child = command.spawn()?; if let Some(mut child_stdin) = child.stdin.take() { child_stdin.write_all(&stdin).await?; } loop { if abort_signal .as_ref() .is_some_and(HookAbortSignal::is_aborted) { let _ = child.start_kill(); let _ = child.wait().await; return Ok(CommandExecution::Cancelled); } if let Some(status) = child.try_wait()? { let output = child.wait_with_output().await?; debug_assert_eq!(output.status.code(), status.code()); return Ok(CommandExecution::Finished(output)); } sleep(Duration::from_millis(20)).await; } }) } } enum CommandExecution { Finished(std::process::Output), Cancelled, } #[cfg(test)] mod tests { use std::thread; use std::time::Duration; use super::{ HookAbortSignal, HookEvent, HookProgressEvent, HookProgressReporter, HookRunResult, HookRunner, }; use crate::config::{RuntimeFeatureConfig, RuntimeHookConfig}; use crate::permissions::PermissionOverride; struct RecordingReporter { events: Vec, } impl HookProgressReporter for RecordingReporter { fn on_event(&mut self, event: &HookProgressEvent) { self.events.push(event.clone()); } } #[test] fn allows_exit_code_zero_and_captures_stdout() { let runner = HookRunner::new(RuntimeHookConfig::new( vec![shell_snippet("printf 'pre ok'")], Vec::new(), Vec::new(), )); let result = runner.run_pre_tool_use("Read", r#"{"path":"README.md"}"#); assert_eq!(result, HookRunResult::allow(vec!["pre ok".to_string()])); } #[test] fn denies_exit_code_two() { let runner = HookRunner::new(RuntimeHookConfig::new( vec![shell_snippet("printf 'blocked by hook'; exit 2")], Vec::new(), Vec::new(), )); let result = runner.run_pre_tool_use("Bash", r#"{"command":"pwd"}"#); assert!(result.is_denied()); assert_eq!(result.messages(), &["blocked by hook".to_string()]); } #[test] fn warns_for_other_non_zero_statuses() { let runner = HookRunner::from_feature_config(&RuntimeFeatureConfig::default().with_hooks( RuntimeHookConfig::new( vec![shell_snippet("printf 'warning hook'; exit 1")], Vec::new(), Vec::new(), ), )); let result = runner.run_pre_tool_use("Edit", r#"{"file":"src/lib.rs"}"#); assert!(!result.is_denied()); assert!(result .messages() .iter() .any(|message| message.contains("allowing tool execution to continue"))); } #[test] fn parses_pre_hook_permission_override_and_updated_input() { let runner = HookRunner::new(RuntimeHookConfig::new( vec![shell_snippet( r#"printf '%s' '{"systemMessage":"updated","hookSpecificOutput":{"permissionDecision":"allow","permissionDecisionReason":"hook ok","updatedInput":{"command":"git status"}}}'"#, )], Vec::new(), Vec::new(), )); let result = runner.run_pre_tool_use("bash", r#"{"command":"pwd"}"#); assert_eq!( result.permission_override(), Some(PermissionOverride::Allow) ); assert_eq!(result.permission_reason(), Some("hook ok")); assert_eq!(result.updated_input(), Some(r#"{"command":"git status"}"#)); assert!(result.messages().iter().any(|message| message == "updated")); } #[test] fn runs_post_tool_use_failure_hooks() { let runner = HookRunner::new(RuntimeHookConfig::new( Vec::new(), Vec::new(), vec![shell_snippet("printf 'failure hook ran'")], )); let result = runner.run_post_tool_use_failure("bash", r#"{"command":"false"}"#, "command failed"); assert!(!result.is_denied()); assert_eq!(result.messages(), &["failure hook ran".to_string()]); } #[test] fn abort_signal_cancels_long_running_hook_and_reports_progress() { let runner = HookRunner::new(RuntimeHookConfig::new( vec![shell_snippet("sleep 5")], Vec::new(), Vec::new(), )); let abort_signal = HookAbortSignal::new(); let abort_signal_for_thread = abort_signal.clone(); let mut reporter = RecordingReporter { events: Vec::new() }; thread::spawn(move || { thread::sleep(Duration::from_millis(100)); abort_signal_for_thread.abort(); }); let result = runner.run_pre_tool_use_with_context( "bash", r#"{"command":"sleep 5"}"#, Some(&abort_signal), Some(&mut reporter), ); assert!(result.is_cancelled()); assert!(reporter.events.iter().any(|event| matches!( event, HookProgressEvent::Started { event: HookEvent::PreToolUse, .. } ))); assert!(reporter.events.iter().any(|event| matches!( event, HookProgressEvent::Cancelled { event: HookEvent::PreToolUse, .. } ))); } #[cfg(windows)] fn shell_snippet(script: &str) -> String { script.replace('\'', "\"") } #[cfg(not(windows))] fn shell_snippet(script: &str) -> String { script.to_string() } }