diff --git a/rust/crates/plugins/src/hooks.rs b/rust/crates/plugins/src/hooks.rs index d473da8..a03e7e8 100644 --- a/rust/crates/plugins/src/hooks.rs +++ b/rust/crates/plugins/src/hooks.rs @@ -10,6 +10,7 @@ use crate::{PluginError, PluginHooks, PluginRegistry}; pub enum HookEvent { PreToolUse, PostToolUse, + PostToolUseFailure, } impl HookEvent { @@ -17,6 +18,7 @@ impl HookEvent { match self { Self::PreToolUse => "PreToolUse", Self::PostToolUse => "PostToolUse", + Self::PostToolUseFailure => "PostToolUseFailure", } } } @@ -24,6 +26,7 @@ impl HookEvent { #[derive(Debug, Clone, PartialEq, Eq)] pub struct HookRunResult { denied: bool, + failed: bool, messages: Vec, } @@ -32,6 +35,7 @@ impl HookRunResult { pub fn allow(messages: Vec) -> Self { Self { denied: false, + failed: false, messages, } } @@ -41,6 +45,11 @@ impl HookRunResult { self.denied } + #[must_use] + pub fn is_failed(&self) -> bool { + self.failed + } + #[must_use] pub fn messages(&self) -> &[String] { &self.messages @@ -92,6 +101,23 @@ impl HookRunner { ) } + #[must_use] + pub fn run_post_tool_use_failure( + &self, + tool_name: &str, + tool_input: &str, + tool_error: &str, + ) -> HookRunResult { + self.run_commands( + HookEvent::PostToolUseFailure, + &self.hooks.post_tool_use_failure, + tool_name, + tool_input, + Some(tool_error), + true, + ) + } + fn run_commands( &self, event: HookEvent, @@ -105,15 +131,7 @@ impl HookRunner { return HookRunResult::allow(Vec::new()); } - let payload = json!({ - "hook_event_name": event.as_str(), - "tool_name": tool_name, - "tool_input": parse_tool_input(tool_input), - "tool_input_json": tool_input, - "tool_output": tool_output, - "tool_result_is_error": is_error, - }) - .to_string(); + let payload = hook_payload(event, tool_name, tool_input, tool_output, is_error).to_string(); let mut messages = Vec::new(); @@ -138,10 +156,18 @@ impl HookRunner { })); return HookRunResult { denied: true, + failed: false, + messages, + }; + } + HookCommandOutcome::Failed { message } => { + messages.push(message); + return HookRunResult { + denied: false, + failed: true, messages, }; } - HookCommandOutcome::Warn { message } => messages.push(message), } } @@ -179,7 +205,7 @@ impl HookRunner { match output.status.code() { Some(0) => HookCommandOutcome::Allow { message }, Some(2) => HookCommandOutcome::Deny { message }, - Some(code) => HookCommandOutcome::Warn { + Some(code) => HookCommandOutcome::Failed { message: format_hook_warning( command, code, @@ -187,7 +213,7 @@ impl HookRunner { stderr.as_str(), ), }, - None => HookCommandOutcome::Warn { + None => HookCommandOutcome::Failed { message: format!( "{} hook `{command}` terminated by signal while handling `{tool_name}`", event.as_str() @@ -195,7 +221,7 @@ impl HookRunner { }, } } - Err(error) => HookCommandOutcome::Warn { + Err(error) => HookCommandOutcome::Failed { message: format!( "{} hook `{command}` failed to start for `{tool_name}`: {error}", event.as_str() @@ -208,7 +234,34 @@ impl HookRunner { enum HookCommandOutcome { Allow { message: Option }, Deny { message: Option }, - Warn { message: String }, + Failed { message: String }, +} + +fn hook_payload( + event: HookEvent, + tool_name: &str, + tool_input: &str, + tool_output: Option<&str>, + is_error: bool, +) -> serde_json::Value { + match event { + HookEvent::PostToolUseFailure => json!({ + "hook_event_name": event.as_str(), + "tool_name": tool_name, + "tool_input": parse_tool_input(tool_input), + "tool_input_json": tool_input, + "tool_error": tool_output, + "tool_result_is_error": true, + }), + _ => json!({ + "hook_event_name": event.as_str(), + "tool_name": tool_name, + "tool_input": parse_tool_input(tool_input), + "tool_input_json": tool_input, + "tool_output": tool_output, + "tool_result_is_error": is_error, + }), + } } fn parse_tool_input(tool_input: &str) -> serde_json::Value { @@ -216,8 +269,7 @@ fn parse_tool_input(tool_input: &str) -> serde_json::Value { } fn format_hook_warning(command: &str, code: i32, stdout: Option<&str>, stderr: &str) -> String { - let mut message = - format!("Hook `{command}` exited with status {code}; allowing tool execution to continue"); + let mut message = format!("Hook `{command}` exited with status {code}"); if let Some(stdout) = stdout.filter(|stdout| !stdout.is_empty()) { message.push_str(": "); message.push_str(stdout); @@ -309,7 +361,13 @@ mod tests { std::env::temp_dir().join(format!("plugins-hook-runner-{label}-{nanos}")) } - fn write_hook_plugin(root: &Path, name: &str, pre_message: &str, post_message: &str) { + fn write_hook_plugin( + root: &Path, + name: &str, + pre_message: &str, + post_message: &str, + failure_message: &str, + ) { fs::create_dir_all(root.join(".claude-plugin")).expect("manifest dir"); fs::create_dir_all(root.join("hooks")).expect("hooks dir"); fs::write( @@ -322,10 +380,15 @@ mod tests { format!("#!/bin/sh\nprintf '%s\\n' '{post_message}'\n"), ) .expect("write post hook"); + fs::write( + root.join("hooks").join("failure.sh"), + format!("#!/bin/sh\nprintf '%s\\n' '{failure_message}'\n"), + ) + .expect("write failure hook"); fs::write( root.join(".claude-plugin").join("plugin.json"), format!( - "{{\n \"name\": \"{name}\",\n \"version\": \"1.0.0\",\n \"description\": \"hook plugin\",\n \"hooks\": {{\n \"PreToolUse\": [\"./hooks/pre.sh\"],\n \"PostToolUse\": [\"./hooks/post.sh\"]\n }}\n}}" + "{{\n \"name\": \"{name}\",\n \"version\": \"1.0.0\",\n \"description\": \"hook plugin\",\n \"hooks\": {{\n \"PreToolUse\": [\"./hooks/pre.sh\"],\n \"PostToolUse\": [\"./hooks/post.sh\"],\n \"PostToolUseFailure\": [\"./hooks/failure.sh\"]\n }}\n}}" ), ) .expect("write plugin manifest"); @@ -333,6 +396,7 @@ mod tests { #[test] fn collects_and_runs_hooks_from_enabled_plugins() { + // given let config_home = temp_dir("config"); let first_source_root = temp_dir("source-a"); let second_source_root = temp_dir("source-b"); @@ -341,12 +405,14 @@ mod tests { "first", "plugin pre one", "plugin post one", + "plugin failure one", ); write_hook_plugin( &second_source_root, "second", "plugin pre two", "plugin post two", + "plugin failure two", ); let mut manager = PluginManager::new(PluginManagerConfig::new(&config_home)); @@ -358,8 +424,10 @@ mod tests { .expect("second plugin install should succeed"); let registry = manager.plugin_registry().expect("registry should build"); + // when let runner = HookRunner::from_registry(®istry).expect("plugin hooks should load"); + // then assert_eq!( runner.run_pre_tool_use("Read", r#"{"path":"README.md"}"#), HookRunResult::allow(vec![ @@ -374,6 +442,13 @@ mod tests { "plugin post two".to_string(), ]) ); + assert_eq!( + runner.run_post_tool_use_failure("Read", r#"{"path":"README.md"}"#, "tool failed",), + HookRunResult::allow(vec![ + "plugin failure one".to_string(), + "plugin failure two".to_string(), + ]) + ); let _ = fs::remove_dir_all(config_home); let _ = fs::remove_dir_all(first_source_root); @@ -382,14 +457,45 @@ mod tests { #[test] fn pre_tool_use_denies_when_plugin_hook_exits_two() { + // given let runner = HookRunner::new(crate::PluginHooks { pre_tool_use: vec!["printf 'blocked by plugin'; exit 2".to_string()], post_tool_use: Vec::new(), + post_tool_use_failure: Vec::new(), }); + // when let result = runner.run_pre_tool_use("Bash", r#"{"command":"pwd"}"#); + // then assert!(result.is_denied()); assert_eq!(result.messages(), &["blocked by plugin".to_string()]); } + + #[test] + fn propagates_plugin_hook_failures() { + // given + let runner = HookRunner::new(crate::PluginHooks { + pre_tool_use: vec![ + "printf 'broken plugin hook'; exit 1".to_string(), + "printf 'later plugin hook'".to_string(), + ], + post_tool_use: Vec::new(), + post_tool_use_failure: Vec::new(), + }); + + // when + let result = runner.run_pre_tool_use("Bash", r#"{"command":"pwd"}"#); + + // then + assert!(result.is_failed()); + assert!(result + .messages() + .iter() + .any(|message| message.contains("broken plugin hook"))); + assert!(!result + .messages() + .iter() + .any(|message| message == "later plugin hook")); + } } diff --git a/rust/crates/plugins/src/lib.rs b/rust/crates/plugins/src/lib.rs index 8ab819c..321070a 100644 --- a/rust/crates/plugins/src/lib.rs +++ b/rust/crates/plugins/src/lib.rs @@ -67,12 +67,16 @@ pub struct PluginHooks { pub pre_tool_use: Vec, #[serde(rename = "PostToolUse", default)] pub post_tool_use: Vec, + #[serde(rename = "PostToolUseFailure", default)] + pub post_tool_use_failure: Vec, } impl PluginHooks { #[must_use] pub fn is_empty(&self) -> bool { - self.pre_tool_use.is_empty() && self.post_tool_use.is_empty() + self.pre_tool_use.is_empty() + && self.post_tool_use.is_empty() + && self.post_tool_use_failure.is_empty() } #[must_use] @@ -85,6 +89,9 @@ impl PluginHooks { .post_tool_use .extend(other.post_tool_use.iter().cloned()); merged + .post_tool_use_failure + .extend(other.post_tool_use_failure.iter().cloned()); + merged } } @@ -1691,6 +1698,11 @@ fn resolve_hooks(root: &Path, hooks: &PluginHooks) -> PluginHooks { .iter() .map(|entry| resolve_hook_entry(root, entry)) .collect(), + post_tool_use_failure: hooks + .post_tool_use_failure + .iter() + .map(|entry| resolve_hook_entry(root, entry)) + .collect(), } } @@ -1739,7 +1751,12 @@ fn validate_hook_paths(root: Option<&Path>, hooks: &PluginHooks) -> Result<(), P let Some(root) = root else { return Ok(()); }; - for entry in hooks.pre_tool_use.iter().chain(hooks.post_tool_use.iter()) { + for entry in hooks + .pre_tool_use + .iter() + .chain(hooks.post_tool_use.iter()) + .chain(hooks.post_tool_use_failure.iter()) + { validate_command_path(root, entry, "hook")?; } Ok(()) diff --git a/rust/crates/runtime/src/conversation.rs b/rust/crates/runtime/src/conversation.rs index 681c51f..c93e217 100644 --- a/rust/crates/runtime/src/conversation.rs +++ b/rust/crates/runtime/src/conversation.rs @@ -374,6 +374,13 @@ where &format!("PreToolUse hook cancelled tool `{tool_name}`"), ), } + } else if pre_hook_result.is_failed() { + PermissionOutcome::Deny { + reason: format_hook_message( + &pre_hook_result, + &format!("PreToolUse hook failed for tool `{tool_name}`"), + ), + } } else if pre_hook_result.is_denied() { PermissionOutcome::Deny { reason: format_hook_message( @@ -421,13 +428,18 @@ where false, ) }; - if post_hook_result.is_denied() || post_hook_result.is_cancelled() { + if post_hook_result.is_denied() + || post_hook_result.is_failed() + || post_hook_result.is_cancelled() + { is_error = true; } output = merge_hook_feedback( post_hook_result.messages(), output, - post_hook_result.is_denied() || post_hook_result.is_cancelled(), + post_hook_result.is_denied() + || post_hook_result.is_failed() + || post_hook_result.is_cancelled(), ); ConversationMessage::tool_result(tool_use_id, tool_name, output, is_error) @@ -707,7 +719,7 @@ fn format_hook_message(result: &HookRunResult, fallback: &str) -> String { } } -fn merge_hook_feedback(messages: &[String], output: String, denied: bool) -> String { +fn merge_hook_feedback(messages: &[String], output: String, is_error: bool) -> String { if messages.is_empty() { return output; } @@ -716,8 +728,8 @@ fn merge_hook_feedback(messages: &[String], output: String, denied: bool) -> Str if !output.trim().is_empty() { sections.push(output); } - let label = if denied { - "Hook feedback (denied)" + let label = if is_error { + "Hook feedback (error)" } else { "Hook feedback" }; @@ -1050,6 +1062,71 @@ mod tests { ); } + #[test] + fn denies_tool_use_when_pre_tool_hook_fails() { + struct SingleCallApiClient; + impl ApiClient for SingleCallApiClient { + fn stream(&mut self, request: ApiRequest) -> Result, RuntimeError> { + if request + .messages + .iter() + .any(|message| message.role == MessageRole::Tool) + { + return Ok(vec![ + AssistantEvent::TextDelta("failed".to_string()), + AssistantEvent::MessageStop, + ]); + } + Ok(vec![ + AssistantEvent::ToolUse { + id: "tool-1".to_string(), + name: "blocked".to_string(), + input: r#"{"path":"secret.txt"}"#.to_string(), + }, + AssistantEvent::MessageStop, + ]) + } + } + + // given + let mut runtime = ConversationRuntime::new_with_features( + Session::new(), + SingleCallApiClient, + StaticToolExecutor::new().register("blocked", |_input| { + panic!("tool should not execute when hook fails") + }), + PermissionPolicy::new(PermissionMode::DangerFullAccess), + vec!["system".to_string()], + &RuntimeFeatureConfig::default().with_hooks(RuntimeHookConfig::new( + vec![shell_snippet("printf 'broken hook'; exit 1")], + Vec::new(), + Vec::new(), + )), + ); + + // when + let summary = runtime + .run_turn("use the tool", None) + .expect("conversation should continue after hook failure"); + + // then + assert_eq!(summary.tool_results.len(), 1); + let ContentBlock::ToolResult { + is_error, output, .. + } = &summary.tool_results[0].blocks[0] + else { + panic!("expected tool result block"); + }; + assert!( + *is_error, + "hook failure should produce an error result: {output}" + ); + assert!( + output.contains("exited with status 1") || output.contains("broken hook"), + "unexpected hook failure output: {output:?}" + ); + } + #[test] fn appends_post_tool_hook_feedback_to_tool_result() { struct TwoCallApiClient { diff --git a/rust/crates/runtime/src/hooks.rs b/rust/crates/runtime/src/hooks.rs index 739065d..f0a32fe 100644 --- a/rust/crates/runtime/src/hooks.rs +++ b/rust/crates/runtime/src/hooks.rs @@ -80,6 +80,7 @@ impl HookAbortSignal { #[derive(Debug, Clone, PartialEq, Eq)] pub struct HookRunResult { denied: bool, + failed: bool, cancelled: bool, messages: Vec, permission_override: Option, @@ -92,6 +93,7 @@ impl HookRunResult { pub fn allow(messages: Vec) -> Self { Self { denied: false, + failed: false, cancelled: false, messages, permission_override: None, @@ -105,6 +107,11 @@ impl HookRunResult { self.denied } + #[must_use] + pub fn is_failed(&self) -> bool { + self.failed + } + #[must_use] pub fn is_cancelled(&self) -> bool { self.cancelled @@ -317,6 +324,7 @@ impl HookRunner { if abort_signal.is_some_and(HookAbortSignal::is_aborted) { return HookRunResult { denied: false, + failed: false, cancelled: true, messages: vec![format!( "{} hook cancelled before execution", @@ -372,7 +380,7 @@ impl HookRunner { result.denied = true; return result; } - HookCommandOutcome::Warn { message } => { + HookCommandOutcome::Failed { parsed } => { if let Some(reporter) = reporter.as_deref_mut() { reporter.on_event(&HookProgressEvent::Completed { event, @@ -380,7 +388,9 @@ impl HookRunner { command: command.clone(), }); } - result.messages.push(message); + merge_parsed_hook_output(&mut result, parsed); + result.failed = true; + return result; } HookCommandOutcome::Cancelled { message } => { if let Some(reporter) = reporter.as_deref_mut() { @@ -428,6 +438,7 @@ impl HookRunner { let stdout = String::from_utf8_lossy(&output.stdout).trim().to_string(); let stderr = String::from_utf8_lossy(&output.stderr).trim().to_string(); let parsed = parse_hook_output(&stdout); + let primary_message = parsed.primary_message().map(ToOwned::to_owned); match output.status.code() { Some(0) => { if parsed.deny { @@ -442,20 +453,20 @@ impl HookRunner { event.as_str() )), }, - Some(code) => HookCommandOutcome::Warn { - message: format_hook_warning( + Some(code) => HookCommandOutcome::Failed { + parsed: parsed.with_fallback_message(format_hook_failure( command, code, - parsed.primary_message(), + primary_message.as_deref(), stderr.as_str(), - ), + )), }, - None => HookCommandOutcome::Warn { - message: format!( + None => HookCommandOutcome::Failed { + parsed: parsed.with_fallback_message(format!( "{} hook `{command}` terminated by signal while handling `{}`", event.as_str(), tool_name - ), + )), }, } } @@ -465,12 +476,15 @@ impl HookRunner { event.as_str() ), }, - Err(error) => HookCommandOutcome::Warn { - message: format!( - "{} hook `{command}` failed to start for `{}`: {error}", - event.as_str(), - tool_name - ), + Err(error) => HookCommandOutcome::Failed { + parsed: ParsedHookOutput { + messages: vec![format!( + "{} hook `{command}` failed to start for `{}`: {error}", + event.as_str(), + tool_name + )], + ..ParsedHookOutput::default() + }, }, } } @@ -479,7 +493,7 @@ impl HookRunner { enum HookCommandOutcome { Allow { parsed: ParsedHookOutput }, Deny { parsed: ParsedHookOutput }, - Warn { message: String }, + Failed { parsed: ParsedHookOutput }, Cancelled { message: String }, } @@ -605,9 +619,8 @@ fn parse_tool_input(tool_input: &str) -> Value { serde_json::from_str(tool_input).unwrap_or_else(|_| json!({ "raw": tool_input })) } -fn format_hook_warning(command: &str, code: i32, stdout: Option<&str>, stderr: &str) -> String { - let mut message = - format!("Hook `{command}` exited with status {code}; allowing tool execution to continue"); +fn format_hook_failure(command: &str, code: i32, stdout: Option<&str>, stderr: &str) -> String { + let mut message = format!("Hook `{command}` exited with status {code}"); if let Some(stdout) = stdout.filter(|stdout| !stdout.is_empty()) { message.push_str(": "); message.push_str(stdout); @@ -749,7 +762,7 @@ mod tests { } #[test] - fn warns_for_other_non_zero_statuses() { + fn propagates_other_non_zero_statuses_as_failures() { let runner = HookRunner::from_feature_config(&RuntimeFeatureConfig::default().with_hooks( RuntimeHookConfig::new( vec![shell_snippet("printf 'warning hook'; exit 1")], @@ -758,13 +771,16 @@ mod tests { ), )); + // given + // when let result = runner.run_pre_tool_use("Edit", r#"{"file":"src/lib.rs"}"#); - assert!(!result.is_denied()); + // then + assert!(result.is_failed()); assert!(result .messages() .iter() - .any(|message| message.contains("allowing tool execution to continue"))); + .any(|message| message.contains("warning hook"))); } #[test] @@ -803,6 +819,91 @@ mod tests { assert_eq!(result.messages(), &["failure hook ran".to_string()]); } + #[test] + fn executes_hooks_in_configured_order() { + // given + let runner = HookRunner::new(RuntimeHookConfig::new( + vec![ + shell_snippet("printf 'first'"), + shell_snippet("printf 'second'"), + ], + Vec::new(), + Vec::new(), + )); + let mut reporter = RecordingReporter { events: Vec::new() }; + + // when + let result = runner.run_pre_tool_use_with_context( + "Read", + r#"{"path":"README.md"}"#, + None, + Some(&mut reporter), + ); + + // then + assert_eq!( + result, + HookRunResult::allow(vec!["first".to_string(), "second".to_string()]) + ); + assert_eq!(reporter.events.len(), 4); + assert!(matches!( + &reporter.events[0], + HookProgressEvent::Started { + event: HookEvent::PreToolUse, + command, + .. + } if command == "printf 'first'" + )); + assert!(matches!( + &reporter.events[1], + HookProgressEvent::Completed { + event: HookEvent::PreToolUse, + command, + .. + } if command == "printf 'first'" + )); + assert!(matches!( + &reporter.events[2], + HookProgressEvent::Started { + event: HookEvent::PreToolUse, + command, + .. + } if command == "printf 'second'" + )); + assert!(matches!( + &reporter.events[3], + HookProgressEvent::Completed { + event: HookEvent::PreToolUse, + command, + .. + } if command == "printf 'second'" + )); + } + + #[test] + fn stops_running_hooks_after_failure() { + // given + let runner = HookRunner::new(RuntimeHookConfig::new( + vec![ + shell_snippet("printf 'broken'; exit 1"), + shell_snippet("printf 'later'"), + ], + Vec::new(), + Vec::new(), + )); + + // when + let result = runner.run_pre_tool_use("Edit", r#"{"file":"src/lib.rs"}"#); + + // then + assert!(result.is_failed()); + assert!(result + .messages() + .iter() + .any(|message| message.contains("broken"))); + assert!(!result.messages().iter().any(|message| message == "later")); + } + #[test] fn abort_signal_cancels_long_running_hook_and_reports_progress() { let runner = HookRunner::new(RuntimeHookConfig::new(