diff --git a/rust/crates/commands/src/lib.rs b/rust/crates/commands/src/lib.rs index eb04307..0875bf3 100644 --- a/rust/crates/commands/src/lib.rs +++ b/rust/crates/commands/src/lib.rs @@ -201,8 +201,8 @@ const SLASH_COMMAND_SPECS: &[SlashCommandSpec] = &[ resume_supported: false, }, SlashCommandSpec { - name: "plugin", - aliases: &["plugins", "marketplace"], + name: "plugins", + aliases: &["plugin", "marketplace"], summary: "Manage Claude Code plugins", argument_hint: Some( "[list|install |enable |disable |uninstall |update ]", @@ -1229,7 +1229,7 @@ mod tests { assert!(help.contains( "/plugins [list|install |enable |disable |uninstall |update ]" )); - assert_eq!(slash_command_specs().len(), 23); + assert_eq!(slash_command_specs().len(), 25); assert_eq!(resume_supported_slash_commands().len(), 11); } diff --git a/rust/crates/runtime/src/config.rs b/rust/crates/runtime/src/config.rs index dfc4d1a..78a1a99 100644 --- a/rust/crates/runtime/src/config.rs +++ b/rust/crates/runtime/src/config.rs @@ -52,6 +52,7 @@ pub struct RuntimeFeatureConfig { oauth: Option, model: Option, permission_mode: Option, + permission_rules: RuntimePermissionRuleConfig, sandbox: SandboxConfig, } @@ -59,6 +60,14 @@ pub struct RuntimeFeatureConfig { pub struct RuntimeHookConfig { pre_tool_use: Vec, post_tool_use: Vec, + post_tool_use_failure: Vec, +} + +#[derive(Debug, Clone, PartialEq, Eq, Default)] +pub struct RuntimePermissionRuleConfig { + allow: Vec, + deny: Vec, + ask: Vec, } #[derive(Debug, Clone, PartialEq, Eq, Default)] @@ -248,6 +257,7 @@ impl ConfigLoader { oauth: parse_optional_oauth_config(&merged_value, "merged settings.oauth")?, model: parse_optional_model(&merged_value), permission_mode: parse_optional_permission_mode(&merged_value)?, + permission_rules: parse_optional_permission_rules(&merged_value)?, sandbox: parse_optional_sandbox_config(&merged_value)?, }; @@ -324,6 +334,11 @@ impl RuntimeConfig { self.feature_config.permission_mode } + #[must_use] + pub fn permission_rules(&self) -> &RuntimePermissionRuleConfig { + &self.feature_config.permission_rules + } + #[must_use] pub fn sandbox(&self) -> &SandboxConfig { &self.feature_config.sandbox @@ -373,6 +388,11 @@ impl RuntimeFeatureConfig { self.permission_mode } + #[must_use] + pub fn permission_rules(&self) -> &RuntimePermissionRuleConfig { + &self.permission_rules + } + #[must_use] pub fn sandbox(&self) -> &SandboxConfig { &self.sandbox @@ -428,10 +448,15 @@ pub fn default_config_home() -> PathBuf { impl RuntimeHookConfig { #[must_use] - pub fn new(pre_tool_use: Vec, post_tool_use: Vec) -> Self { + pub fn new( + pre_tool_use: Vec, + post_tool_use: Vec, + post_tool_use_failure: Vec, + ) -> Self { Self { pre_tool_use, post_tool_use, + post_tool_use_failure, } } @@ -445,6 +470,11 @@ impl RuntimeHookConfig { &self.post_tool_use } + #[must_use] + pub fn post_tool_use_failure(&self) -> &[String] { + &self.post_tool_use_failure + } + #[must_use] pub fn merged(&self, other: &Self) -> Self { let mut merged = self.clone(); @@ -455,6 +485,32 @@ impl RuntimeHookConfig { pub fn extend(&mut self, other: &Self) { extend_unique(&mut self.pre_tool_use, other.pre_tool_use()); extend_unique(&mut self.post_tool_use, other.post_tool_use()); + extend_unique( + &mut self.post_tool_use_failure, + other.post_tool_use_failure(), + ); + } +} + +impl RuntimePermissionRuleConfig { + #[must_use] + pub fn new(allow: Vec, deny: Vec, ask: Vec) -> Self { + Self { allow, deny, ask } + } + + #[must_use] + pub fn allow(&self) -> &[String] { + &self.allow + } + + #[must_use] + pub fn deny(&self) -> &[String] { + &self.deny + } + + #[must_use] + pub fn ask(&self) -> &[String] { + &self.ask } } @@ -569,6 +625,32 @@ fn parse_optional_hooks_config(root: &JsonValue) -> Result Result { + let Some(object) = root.as_object() else { + return Ok(RuntimePermissionRuleConfig::default()); + }; + let Some(permissions) = object.get("permissions").and_then(JsonValue::as_object) else { + return Ok(RuntimePermissionRuleConfig::default()); + }; + + Ok(RuntimePermissionRuleConfig { + allow: optional_string_array(permissions, "allow", "merged settings.permissions")? + .unwrap_or_default(), + deny: optional_string_array(permissions, "deny", "merged settings.permissions")? + .unwrap_or_default(), + ask: optional_string_array(permissions, "ask", "merged settings.permissions")? + .unwrap_or_default(), }) } @@ -991,7 +1073,7 @@ mod tests { .expect("write user compat config"); fs::write( home.join("settings.json"), - r#"{"model":"sonnet","env":{"A2":"1"},"hooks":{"PreToolUse":["base"]},"permissions":{"defaultMode":"plan"}}"#, + r#"{"model":"sonnet","env":{"A2":"1"},"hooks":{"PreToolUse":["base"]},"permissions":{"defaultMode":"plan","allow":["Read"],"deny":["Bash(rm -rf)"]}}"#, ) .expect("write user settings"); fs::write( @@ -1001,7 +1083,7 @@ mod tests { .expect("write project compat config"); fs::write( cwd.join(".claude").join("settings.json"), - r#"{"env":{"C":"3"},"hooks":{"PostToolUse":["project"]},"mcpServers":{"project":{"command":"uvx","args":["project"]}}}"#, + r#"{"env":{"C":"3"},"hooks":{"PostToolUse":["project"],"PostToolUseFailure":["project-failure"]},"permissions":{"ask":["Edit"]},"mcpServers":{"project":{"command":"uvx","args":["project"]}}}"#, ) .expect("write project settings"); fs::write( @@ -1046,6 +1128,16 @@ mod tests { .contains_key("PostToolUse")); assert_eq!(loaded.hooks().pre_tool_use(), &["base".to_string()]); assert_eq!(loaded.hooks().post_tool_use(), &["project".to_string()]); + assert_eq!( + loaded.hooks().post_tool_use_failure(), + &["project-failure".to_string()] + ); + assert_eq!(loaded.permission_rules().allow(), &["Read".to_string()]); + assert_eq!( + loaded.permission_rules().deny(), + &["Bash(rm -rf)".to_string()] + ); + assert_eq!(loaded.permission_rules().ask(), &["Edit".to_string()]); assert!(loaded.mcp().get("home").is_some()); assert!(loaded.mcp().get("project").is_some()); diff --git a/rust/crates/runtime/src/conversation.rs b/rust/crates/runtime/src/conversation.rs index a73f2f4..bd1d1a4 100644 --- a/rust/crates/runtime/src/conversation.rs +++ b/rust/crates/runtime/src/conversation.rs @@ -7,8 +7,10 @@ use crate::compact::{ compact_session, estimate_session_tokens, CompactionConfig, CompactionResult, }; use crate::config::RuntimeFeatureConfig; -use crate::hooks::HookRunner; -use crate::permissions::{PermissionOutcome, PermissionPolicy, PermissionPrompter}; +use crate::hooks::{HookAbortSignal, HookProgressReporter, HookRunResult, HookRunner}; +use crate::permissions::{ + PermissionContext, PermissionOutcome, PermissionPolicy, PermissionPrompter, +}; use crate::session::{ContentBlock, ConversationMessage, Session}; use crate::usage::{TokenUsage, UsageTracker}; @@ -112,6 +114,8 @@ pub struct ConversationRuntime { plugin_hook_runner: Option, plugin_registry: Option, plugins_shutdown: bool, + hook_abort_signal: HookAbortSignal, + hook_progress_reporter: Option>, } impl ConversationRuntime { @@ -176,6 +180,8 @@ where plugin_hook_runner: None, plugin_registry: None, plugins_shutdown: false, + hook_abort_signal: HookAbortSignal::default(), + hook_progress_reporter: None, } } @@ -221,6 +227,92 @@ where self } + #[must_use] + pub fn with_hook_abort_signal(mut self, hook_abort_signal: HookAbortSignal) -> Self { + self.hook_abort_signal = hook_abort_signal; + self + } + + #[must_use] + pub fn with_hook_progress_reporter( + mut self, + hook_progress_reporter: Box, + ) -> Self { + self.hook_progress_reporter = Some(hook_progress_reporter); + self + } + + fn run_pre_tool_use_hook(&mut self, tool_name: &str, input: &str) -> HookRunResult { + if let Some(reporter) = self.hook_progress_reporter.as_mut() { + self.hook_runner.run_pre_tool_use_with_context( + tool_name, + input, + Some(&self.hook_abort_signal), + Some(reporter.as_mut()), + ) + } else { + self.hook_runner.run_pre_tool_use_with_context( + tool_name, + input, + Some(&self.hook_abort_signal), + None, + ) + } + } + + fn run_post_tool_use_hook( + &mut self, + tool_name: &str, + input: &str, + output: &str, + is_error: bool, + ) -> HookRunResult { + if let Some(reporter) = self.hook_progress_reporter.as_mut() { + self.hook_runner.run_post_tool_use_with_context( + tool_name, + input, + output, + is_error, + Some(&self.hook_abort_signal), + Some(reporter.as_mut()), + ) + } else { + self.hook_runner.run_post_tool_use_with_context( + tool_name, + input, + output, + is_error, + Some(&self.hook_abort_signal), + None, + ) + } + } + + fn run_post_tool_use_failure_hook( + &mut self, + tool_name: &str, + input: &str, + output: &str, + ) -> HookRunResult { + if let Some(reporter) = self.hook_progress_reporter.as_mut() { + self.hook_runner.run_post_tool_use_failure_with_context( + tool_name, + input, + output, + Some(&self.hook_abort_signal), + Some(reporter.as_mut()), + ) + } else { + self.hook_runner.run_post_tool_use_failure_with_context( + tool_name, + input, + output, + Some(&self.hook_abort_signal), + None, + ) + } + } + #[allow(clippy::too_many_lines)] pub fn run_turn( &mut self, @@ -273,94 +365,124 @@ where } for (tool_use_id, tool_name, input) in pending_tool_uses { - let permission_outcome = if let Some(prompt) = prompter.as_mut() { - self.permission_policy - .authorize(&tool_name, &input, Some(*prompt)) + let pre_hook_result = self.run_pre_tool_use_hook(&tool_name, &input); + let effective_input = pre_hook_result + .updated_input() + .map_or_else(|| input.clone(), ToOwned::to_owned); + let permission_context = PermissionContext::new( + pre_hook_result.permission_override(), + pre_hook_result.permission_reason().map(ToOwned::to_owned), + ); + + let permission_outcome = if pre_hook_result.is_cancelled() { + PermissionOutcome::Deny { + reason: format_hook_message( + pre_hook_result.messages(), + &format!("PreToolUse hook cancelled tool `{tool_name}`"), + ), + } + } else if pre_hook_result.is_denied() { + PermissionOutcome::Deny { + reason: format_hook_message( + pre_hook_result.messages(), + &format!("PreToolUse hook denied tool `{tool_name}`"), + ), + } + } else if let Some(prompt) = prompter.as_mut() { + self.permission_policy.authorize_with_context( + &tool_name, + &effective_input, + &permission_context, + Some(*prompt), + ) } else { - self.permission_policy.authorize(&tool_name, &input, None) + self.permission_policy.authorize_with_context( + &tool_name, + &effective_input, + &permission_context, + None, + ) }; let result_message = match permission_outcome { PermissionOutcome::Allow => { - let pre_hook_result = self.hook_runner.run_pre_tool_use(&tool_name, &input); - if pre_hook_result.is_denied() { + let plugin_pre_hook_result = + self.run_plugin_pre_tool_use(&tool_name, &effective_input); + if plugin_pre_hook_result.is_denied() { let deny_message = format!("PreToolUse hook denied tool `{tool_name}`"); + let mut messages = pre_hook_result.messages().to_vec(); + messages.extend(plugin_pre_hook_result.messages().iter().cloned()); ConversationMessage::tool_result( tool_use_id, tool_name, - format_hook_message(pre_hook_result.messages(), &deny_message), + format_hook_message(&messages, &deny_message), true, ) } else { - let plugin_pre_hook_result = - self.run_plugin_pre_tool_use(&tool_name, &input); - if plugin_pre_hook_result.is_denied() { - let deny_message = - format!("PreToolUse hook denied tool `{tool_name}`"); - let mut messages = pre_hook_result.messages().to_vec(); - messages.extend(plugin_pre_hook_result.messages().iter().cloned()); - ConversationMessage::tool_result( - tool_use_id, - tool_name, - format_hook_message(&messages, &deny_message), - true, + let (mut output, mut is_error) = + match self.tool_executor.execute(&tool_name, &effective_input) { + Ok(output) => (output, false), + Err(error) => (error.to_string(), true), + }; + output = merge_hook_feedback(pre_hook_result.messages(), output, false); + output = merge_hook_feedback( + plugin_pre_hook_result.messages(), + output, + false, + ); + + let hook_output = output.clone(); + let post_hook_result = if is_error { + self.run_post_tool_use_failure_hook( + &tool_name, + &effective_input, + &hook_output, ) } else { - let (mut output, mut is_error) = - match self.tool_executor.execute(&tool_name, &input) { - Ok(output) => (output, false), - Err(error) => (error.to_string(), true), - }; - output = - merge_hook_feedback(pre_hook_result.messages(), output, false); - output = merge_hook_feedback( - plugin_pre_hook_result.messages(), - output, + self.run_post_tool_use_hook( + &tool_name, + &effective_input, + &hook_output, false, - ); - - let hook_output = output.clone(); - let post_hook_result = self.hook_runner.run_post_tool_use( - &tool_name, - &input, - &hook_output, - is_error, - ); - let plugin_post_hook_result = self.run_plugin_post_tool_use( - &tool_name, - &input, - &hook_output, - is_error, - ); - if post_hook_result.is_denied() { - is_error = true; - } - if plugin_post_hook_result.is_denied() { - is_error = true; - } - output = merge_hook_feedback( - post_hook_result.messages(), - output, - post_hook_result.is_denied(), - ); - output = merge_hook_feedback( - plugin_post_hook_result.messages(), - output, - plugin_post_hook_result.is_denied(), - ); - - ConversationMessage::tool_result( - tool_use_id, - tool_name, - output, - is_error, ) + }; + let plugin_post_hook_result = self.run_plugin_post_tool_use( + &tool_name, + &effective_input, + &hook_output, + is_error, + ); + if post_hook_result.is_denied() + || post_hook_result.is_cancelled() + || plugin_post_hook_result.is_denied() + { + is_error = true; } + output = merge_hook_feedback( + post_hook_result.messages(), + output, + post_hook_result.is_denied() || post_hook_result.is_cancelled(), + ); + output = merge_hook_feedback( + plugin_post_hook_result.messages(), + output, + plugin_post_hook_result.is_denied(), + ); + + ConversationMessage::tool_result( + tool_use_id, + tool_name, + output, + is_error, + ) } } - PermissionOutcome::Deny { reason } => { - ConversationMessage::tool_result(tool_use_id, tool_name, reason, true) - } + PermissionOutcome::Deny { reason } => ConversationMessage::tool_result( + tool_use_id, + tool_name, + merge_hook_feedback(pre_hook_result.messages(), reason, true), + true, + ), }; self.session.messages.push(result_message.clone()); tool_results.push(result_message); @@ -870,6 +992,7 @@ mod tests { RuntimeFeatureConfig::default().with_hooks(RuntimeHookConfig::new( vec![shell_snippet("printf 'blocked by hook'; exit 2")], Vec::new(), + Vec::new(), )), ); @@ -936,6 +1059,7 @@ mod tests { RuntimeFeatureConfig::default().with_hooks(RuntimeHookConfig::new( vec![shell_snippet("printf 'pre hook ran'")], vec![shell_snippet("printf 'post hook ran'")], + Vec::new(), )), ); diff --git a/rust/crates/runtime/src/hooks.rs b/rust/crates/runtime/src/hooks.rs index 4aff002..3d89a13 100644 --- a/rust/crates/runtime/src/hooks.rs +++ b/rust/crates/runtime/src/hooks.rs @@ -1,30 +1,91 @@ use std::ffi::OsStr; +use std::io::Write; use std::path::Path; -use std::process::Command; +use std::process::{Command, Stdio}; +use std::sync::{ + atomic::{AtomicBool, Ordering}, + Arc, +}; +use std::thread; +use std::time::Duration; -use serde_json::json; +use serde_json::{json, Value}; use crate::config::{RuntimeFeatureConfig, RuntimeHookConfig}; +use crate::permissions::PermissionOverride; + +pub type HookPermissionDecision = PermissionOverride; #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum HookEvent { PreToolUse, PostToolUse, + PostToolUseFailure, } impl HookEvent { - fn as_str(self) -> &'static str { + #[must_use] + pub fn as_str(self) -> &'static str { match self { Self::PreToolUse => "PreToolUse", Self::PostToolUse => "PostToolUse", + Self::PostToolUseFailure => "PostToolUseFailure", } } } +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum HookProgressEvent { + Started { + event: HookEvent, + tool_name: String, + command: String, + }, + Completed { + event: HookEvent, + tool_name: String, + command: String, + }, + Cancelled { + event: HookEvent, + tool_name: String, + command: String, + }, +} + +pub trait HookProgressReporter { + fn on_event(&mut self, event: &HookProgressEvent); +} + +#[derive(Debug, Clone, Default)] +pub struct HookAbortSignal { + aborted: Arc, +} + +impl HookAbortSignal { + #[must_use] + pub fn new() -> Self { + Self::default() + } + + pub fn abort(&self) { + self.aborted.store(true, Ordering::SeqCst); + } + + #[must_use] + pub fn is_aborted(&self) -> bool { + self.aborted.load(Ordering::SeqCst) + } +} + #[derive(Debug, Clone, PartialEq, Eq)] pub struct HookRunResult { denied: bool, + cancelled: bool, messages: Vec, + permission_override: Option, + permission_reason: Option, + updated_input: Option, } impl HookRunResult { @@ -32,7 +93,11 @@ impl HookRunResult { pub fn allow(messages: Vec) -> Self { Self { denied: false, + cancelled: false, messages, + permission_override: None, + permission_reason: None, + updated_input: None, } } @@ -41,10 +106,40 @@ impl HookRunResult { self.denied } + #[must_use] + pub fn is_cancelled(&self) -> bool { + self.cancelled + } + #[must_use] pub fn messages(&self) -> &[String] { &self.messages } + + #[must_use] + pub fn permission_override(&self) -> Option { + self.permission_override + } + + #[must_use] + pub fn permission_decision(&self) -> Option { + self.permission_override + } + + #[must_use] + pub fn permission_reason(&self) -> Option<&str> { + self.permission_reason.as_deref() + } + + #[must_use] + pub fn updated_input(&self) -> Option<&str> { + self.updated_input.as_deref() + } + + #[must_use] + pub fn updated_input_json(&self) -> Option<&str> { + self.updated_input() + } } #[derive(Debug, Clone, PartialEq, Eq, Default)] @@ -65,16 +160,39 @@ impl HookRunner { #[must_use] pub fn run_pre_tool_use(&self, tool_name: &str, tool_input: &str) -> HookRunResult { - self.run_commands( + self.run_pre_tool_use_with_context(tool_name, tool_input, None, None) + } + + #[must_use] + pub fn run_pre_tool_use_with_context( + &self, + tool_name: &str, + tool_input: &str, + abort_signal: Option<&HookAbortSignal>, + reporter: Option<&mut dyn HookProgressReporter>, + ) -> HookRunResult { + Self::run_commands( HookEvent::PreToolUse, self.config.pre_tool_use(), tool_name, tool_input, None, false, + abort_signal, + reporter, ) } + #[must_use] + pub fn run_pre_tool_use_with_signal( + &self, + tool_name: &str, + tool_input: &str, + abort_signal: Option<&HookAbortSignal>, + ) -> HookRunResult { + self.run_pre_tool_use_with_context(tool_name, tool_input, abort_signal, None) + } + #[must_use] pub fn run_post_tool_use( &self, @@ -83,43 +201,147 @@ impl HookRunner { tool_output: &str, is_error: bool, ) -> HookRunResult { - self.run_commands( + self.run_post_tool_use_with_context( + tool_name, + tool_input, + tool_output, + is_error, + None, + None, + ) + } + + #[must_use] + pub fn run_post_tool_use_with_context( + &self, + tool_name: &str, + tool_input: &str, + tool_output: &str, + is_error: bool, + abort_signal: Option<&HookAbortSignal>, + reporter: Option<&mut dyn HookProgressReporter>, + ) -> HookRunResult { + Self::run_commands( HookEvent::PostToolUse, self.config.post_tool_use(), tool_name, tool_input, Some(tool_output), is_error, + abort_signal, + reporter, ) } - fn run_commands( + #[must_use] + pub fn run_post_tool_use_with_signal( &self, + tool_name: &str, + tool_input: &str, + tool_output: &str, + is_error: bool, + abort_signal: Option<&HookAbortSignal>, + ) -> HookRunResult { + self.run_post_tool_use_with_context( + tool_name, + tool_input, + tool_output, + is_error, + abort_signal, + None, + ) + } + + #[must_use] + pub fn run_post_tool_use_failure( + &self, + tool_name: &str, + tool_input: &str, + tool_error: &str, + ) -> HookRunResult { + self.run_post_tool_use_failure_with_context(tool_name, tool_input, tool_error, None, None) + } + + #[must_use] + pub fn run_post_tool_use_failure_with_context( + &self, + tool_name: &str, + tool_input: &str, + tool_error: &str, + abort_signal: Option<&HookAbortSignal>, + reporter: Option<&mut dyn HookProgressReporter>, + ) -> HookRunResult { + Self::run_commands( + HookEvent::PostToolUseFailure, + self.config.post_tool_use_failure(), + tool_name, + tool_input, + Some(tool_error), + true, + abort_signal, + reporter, + ) + } + + #[must_use] + pub fn run_post_tool_use_failure_with_signal( + &self, + tool_name: &str, + tool_input: &str, + tool_error: &str, + abort_signal: Option<&HookAbortSignal>, + ) -> HookRunResult { + self.run_post_tool_use_failure_with_context( + tool_name, + tool_input, + tool_error, + abort_signal, + None, + ) + } + + #[allow(clippy::too_many_arguments)] + fn run_commands( event: HookEvent, commands: &[String], tool_name: &str, tool_input: &str, tool_output: Option<&str>, is_error: bool, + abort_signal: Option<&HookAbortSignal>, + mut reporter: Option<&mut dyn HookProgressReporter>, ) -> HookRunResult { if commands.is_empty() { return HookRunResult::allow(Vec::new()); } - let payload = json!({ - "hook_event_name": event.as_str(), - "tool_name": tool_name, - "tool_input": parse_tool_input(tool_input), - "tool_input_json": tool_input, - "tool_output": tool_output, - "tool_result_is_error": is_error, - }) - .to_string(); + if abort_signal.is_some_and(HookAbortSignal::is_aborted) { + return HookRunResult { + denied: false, + cancelled: true, + messages: vec![format!( + "{} hook cancelled before execution", + event.as_str() + )], + permission_override: None, + permission_reason: None, + updated_input: None, + }; + } - let mut messages = Vec::new(); + let payload = hook_payload(event, tool_name, tool_input, tool_output, is_error).to_string(); + let mut result = HookRunResult::allow(Vec::new()); for command in commands { - match self.run_command( + if let Some(reporter) = reporter.as_deref_mut() { + reporter.on_event(&HookProgressEvent::Started { + event, + tool_name: tool_name.to_string(), + command: command.clone(), + }); + } + + match Self::run_command( command, event, tool_name, @@ -127,32 +349,60 @@ impl HookRunner { tool_output, is_error, &payload, + abort_signal, ) { - HookCommandOutcome::Allow { message } => { - if let Some(message) = message { - messages.push(message); + HookCommandOutcome::Allow { parsed } => { + if let Some(reporter) = reporter.as_deref_mut() { + reporter.on_event(&HookProgressEvent::Completed { + event, + tool_name: tool_name.to_string(), + command: command.clone(), + }); } + merge_parsed_hook_output(&mut result, parsed); } - HookCommandOutcome::Deny { message } => { - let message = message.unwrap_or_else(|| { - format!("{} hook denied tool `{tool_name}`", event.as_str()) - }); - messages.push(message); - return HookRunResult { - denied: true, - messages, - }; + HookCommandOutcome::Deny { parsed } => { + if let Some(reporter) = reporter.as_deref_mut() { + reporter.on_event(&HookProgressEvent::Completed { + event, + tool_name: tool_name.to_string(), + command: command.clone(), + }); + } + merge_parsed_hook_output(&mut result, parsed); + result.denied = true; + return result; + } + HookCommandOutcome::Warn { message } => { + if let Some(reporter) = reporter.as_deref_mut() { + reporter.on_event(&HookProgressEvent::Completed { + event, + tool_name: tool_name.to_string(), + command: command.clone(), + }); + } + result.messages.push(message); + } + HookCommandOutcome::Cancelled { message } => { + if let Some(reporter) = reporter.as_deref_mut() { + reporter.on_event(&HookProgressEvent::Cancelled { + event, + tool_name: tool_name.to_string(), + command: command.clone(), + }); + } + result.cancelled = true; + result.messages.push(message); + return result; } - HookCommandOutcome::Warn { message } => messages.push(message), } } - HookRunResult::allow(messages) + result } - #[allow(clippy::too_many_arguments, clippy::unused_self)] + #[allow(clippy::too_many_arguments)] fn run_command( - &self, command: &str, event: HookEvent, tool_name: &str, @@ -160,11 +410,12 @@ impl HookRunner { tool_output: Option<&str>, is_error: bool, payload: &str, + abort_signal: Option<&HookAbortSignal>, ) -> HookCommandOutcome { let mut child = shell_command(command); - child.stdin(std::process::Stdio::piped()); - child.stdout(std::process::Stdio::piped()); - child.stderr(std::process::Stdio::piped()); + child.stdin(Stdio::piped()); + child.stdout(Stdio::piped()); + child.stderr(Stdio::piped()); child.env("HOOK_EVENT", event.as_str()); child.env("HOOK_TOOL_NAME", tool_name); child.env("HOOK_TOOL_INPUT", tool_input); @@ -173,19 +424,30 @@ impl HookRunner { child.env("HOOK_TOOL_OUTPUT", tool_output); } - match child.output_with_stdin(payload.as_bytes()) { - Ok(output) => { + match child.output_with_stdin(payload.as_bytes(), abort_signal) { + Ok(CommandExecution::Finished(output)) => { let stdout = String::from_utf8_lossy(&output.stdout).trim().to_string(); let stderr = String::from_utf8_lossy(&output.stderr).trim().to_string(); - let message = (!stdout.is_empty()).then_some(stdout); + let parsed = parse_hook_output(&stdout); match output.status.code() { - Some(0) => HookCommandOutcome::Allow { message }, - Some(2) => HookCommandOutcome::Deny { message }, + Some(0) => { + if parsed.deny { + HookCommandOutcome::Deny { parsed } + } else { + HookCommandOutcome::Allow { parsed } + } + } + Some(2) => HookCommandOutcome::Deny { + parsed: parsed.with_fallback_message(format!( + "{} hook denied tool `{tool_name}`", + event.as_str() + )), + }, Some(code) => HookCommandOutcome::Warn { message: format_hook_warning( command, code, - message.as_deref(), + parsed.primary_message(), stderr.as_str(), ), }, @@ -197,6 +459,12 @@ impl HookRunner { }, } } + Ok(CommandExecution::Cancelled) => HookCommandOutcome::Cancelled { + message: format!( + "{} hook `{command}` cancelled while handling `{tool_name}`", + event.as_str() + ), + }, Err(error) => HookCommandOutcome::Warn { message: format!( "{} hook `{command}` failed to start for `{tool_name}`: {error}", @@ -208,12 +476,131 @@ impl HookRunner { } enum HookCommandOutcome { - Allow { message: Option }, - Deny { message: Option }, + Allow { parsed: ParsedHookOutput }, + Deny { parsed: ParsedHookOutput }, Warn { message: String }, + Cancelled { message: String }, } -fn parse_tool_input(tool_input: &str) -> serde_json::Value { +#[derive(Debug, Clone, PartialEq, Eq, Default)] +struct ParsedHookOutput { + messages: Vec, + deny: bool, + permission_override: Option, + permission_reason: Option, + updated_input: Option, +} + +impl ParsedHookOutput { + fn with_fallback_message(mut self, fallback: String) -> Self { + if self.messages.is_empty() { + self.messages.push(fallback); + } + self + } + + fn primary_message(&self) -> Option<&str> { + self.messages.first().map(String::as_str) + } +} + +fn merge_parsed_hook_output(target: &mut HookRunResult, parsed: ParsedHookOutput) { + target.messages.extend(parsed.messages); + if parsed.permission_override.is_some() { + target.permission_override = parsed.permission_override; + } + if parsed.permission_reason.is_some() { + target.permission_reason = parsed.permission_reason; + } + if parsed.updated_input.is_some() { + target.updated_input = parsed.updated_input; + } +} + +fn parse_hook_output(stdout: &str) -> ParsedHookOutput { + if stdout.is_empty() { + return ParsedHookOutput::default(); + } + + let Ok(Value::Object(root)) = serde_json::from_str::(stdout) else { + return ParsedHookOutput { + messages: vec![stdout.to_string()], + ..ParsedHookOutput::default() + }; + }; + + let mut parsed = ParsedHookOutput::default(); + + if let Some(message) = root.get("systemMessage").and_then(Value::as_str) { + parsed.messages.push(message.to_string()); + } + if let Some(message) = root.get("reason").and_then(Value::as_str) { + parsed.messages.push(message.to_string()); + } + if root.get("continue").and_then(Value::as_bool) == Some(false) + || root.get("decision").and_then(Value::as_str) == Some("block") + { + parsed.deny = true; + } + + if let Some(Value::Object(specific)) = root.get("hookSpecificOutput") { + if let Some(Value::String(additional_context)) = specific.get("additionalContext") { + parsed.messages.push(additional_context.clone()); + } + if let Some(decision) = specific.get("permissionDecision").and_then(Value::as_str) { + parsed.permission_override = match decision { + "allow" => Some(PermissionOverride::Allow), + "deny" => Some(PermissionOverride::Deny), + "ask" => Some(PermissionOverride::Ask), + _ => None, + }; + } + if let Some(reason) = specific + .get("permissionDecisionReason") + .and_then(Value::as_str) + { + parsed.permission_reason = Some(reason.to_string()); + } + if let Some(updated_input) = specific.get("updatedInput") { + parsed.updated_input = serde_json::to_string(updated_input).ok(); + } + } + + if parsed.messages.is_empty() { + parsed.messages.push(stdout.to_string()); + } + + parsed +} + +fn hook_payload( + event: HookEvent, + tool_name: &str, + tool_input: &str, + tool_output: Option<&str>, + is_error: bool, +) -> Value { + match event { + HookEvent::PostToolUseFailure => json!({ + "hook_event_name": event.as_str(), + "tool_name": tool_name, + "tool_input": parse_tool_input(tool_input), + "tool_input_json": tool_input, + "tool_error": tool_output, + "tool_result_is_error": true, + }), + _ => json!({ + "hook_event_name": event.as_str(), + "tool_name": tool_name, + "tool_input": parse_tool_input(tool_input), + "tool_input_json": tool_input, + "tool_output": tool_output, + "tool_result_is_error": is_error, + }), + } +} + +fn parse_tool_input(tool_input: &str) -> Value { serde_json::from_str(tool_input).unwrap_or_else(|_| json!({ "raw": tool_input })) } @@ -261,17 +648,17 @@ impl CommandWithStdin { Self { command } } - fn stdin(&mut self, cfg: std::process::Stdio) -> &mut Self { + fn stdin(&mut self, cfg: Stdio) -> &mut Self { self.command.stdin(cfg); self } - fn stdout(&mut self, cfg: std::process::Stdio) -> &mut Self { + fn stdout(&mut self, cfg: Stdio) -> &mut Self { self.command.stdout(cfg); self } - fn stderr(&mut self, cfg: std::process::Stdio) -> &mut Self { + fn stderr(&mut self, cfg: Stdio) -> &mut Self { self.command.stderr(cfg); self } @@ -285,26 +672,64 @@ impl CommandWithStdin { self } - fn output_with_stdin(&mut self, stdin: &[u8]) -> std::io::Result { + fn output_with_stdin( + &mut self, + stdin: &[u8], + abort_signal: Option<&HookAbortSignal>, + ) -> std::io::Result { let mut child = self.command.spawn()?; if let Some(mut child_stdin) = child.stdin.take() { - use std::io::Write; child_stdin.write_all(stdin)?; } - child.wait_with_output() + + loop { + if abort_signal.is_some_and(HookAbortSignal::is_aborted) { + let _ = child.kill(); + let _ = child.wait_with_output(); + return Ok(CommandExecution::Cancelled); + } + + match child.try_wait()? { + Some(_) => return child.wait_with_output().map(CommandExecution::Finished), + None => thread::sleep(Duration::from_millis(20)), + } + } } } +enum CommandExecution { + Finished(std::process::Output), + Cancelled, +} + #[cfg(test)] mod tests { - use super::{HookRunResult, HookRunner}; + use std::thread; + use std::time::Duration; + + use super::{ + HookAbortSignal, HookEvent, HookProgressEvent, HookProgressReporter, HookRunResult, + HookRunner, + }; use crate::config::{RuntimeFeatureConfig, RuntimeHookConfig}; + use crate::permissions::PermissionOverride; + + struct RecordingReporter { + events: Vec, + } + + impl HookProgressReporter for RecordingReporter { + fn on_event(&mut self, event: &HookProgressEvent) { + self.events.push(event.clone()); + } + } #[test] fn allows_exit_code_zero_and_captures_stdout() { let runner = HookRunner::new(RuntimeHookConfig::new( vec![shell_snippet("printf 'pre ok'")], Vec::new(), + Vec::new(), )); let result = runner.run_pre_tool_use("Read", r#"{"path":"README.md"}"#); @@ -317,6 +742,7 @@ mod tests { let runner = HookRunner::new(RuntimeHookConfig::new( vec![shell_snippet("printf 'blocked by hook'; exit 2")], Vec::new(), + Vec::new(), )); let result = runner.run_pre_tool_use("Bash", r#"{"command":"pwd"}"#); @@ -331,6 +757,7 @@ mod tests { RuntimeHookConfig::new( vec![shell_snippet("printf 'warning hook'; exit 1")], Vec::new(), + Vec::new(), ), )); @@ -343,6 +770,82 @@ mod tests { .any(|message| message.contains("allowing tool execution to continue"))); } + #[test] + fn parses_pre_hook_permission_override_and_updated_input() { + let runner = HookRunner::new(RuntimeHookConfig::new( + vec![shell_snippet( + r#"printf '%s' '{"systemMessage":"updated","hookSpecificOutput":{"permissionDecision":"allow","permissionDecisionReason":"hook ok","updatedInput":{"command":"git status"}}}'"#, + )], + Vec::new(), + Vec::new(), + )); + + let result = runner.run_pre_tool_use("bash", r#"{"command":"pwd"}"#); + + assert_eq!( + result.permission_override(), + Some(PermissionOverride::Allow) + ); + assert_eq!(result.permission_reason(), Some("hook ok")); + assert_eq!(result.updated_input(), Some(r#"{"command":"git status"}"#)); + assert!(result.messages().iter().any(|message| message == "updated")); + } + + #[test] + fn runs_post_tool_use_failure_hooks() { + let runner = HookRunner::new(RuntimeHookConfig::new( + Vec::new(), + Vec::new(), + vec![shell_snippet("printf 'failure hook ran'")], + )); + + let result = + runner.run_post_tool_use_failure("bash", r#"{"command":"false"}"#, "command failed"); + + assert!(!result.is_denied()); + assert_eq!(result.messages(), &["failure hook ran".to_string()]); + } + + #[test] + fn abort_signal_cancels_long_running_hook_and_reports_progress() { + let runner = HookRunner::new(RuntimeHookConfig::new( + vec![shell_snippet("sleep 5")], + Vec::new(), + Vec::new(), + )); + let abort_signal = HookAbortSignal::new(); + let abort_signal_for_thread = abort_signal.clone(); + let mut reporter = RecordingReporter { events: Vec::new() }; + + thread::spawn(move || { + thread::sleep(Duration::from_millis(100)); + abort_signal_for_thread.abort(); + }); + + let result = runner.run_pre_tool_use_with_context( + "bash", + r#"{"command":"sleep 5"}"#, + Some(&abort_signal), + Some(&mut reporter), + ); + + assert!(result.is_cancelled()); + assert!(reporter.events.iter().any(|event| matches!( + event, + HookProgressEvent::Started { + event: HookEvent::PreToolUse, + .. + } + ))); + assert!(reporter.events.iter().any(|event| matches!( + event, + HookProgressEvent::Cancelled { + event: HookEvent::PreToolUse, + .. + } + ))); + } + #[cfg(windows)] fn shell_snippet(script: &str) -> String { script.replace('\'', "\"") diff --git a/rust/crates/runtime/src/lib.rs b/rust/crates/runtime/src/lib.rs index edac666..98c27bb 100644 --- a/rust/crates/runtime/src/lib.rs +++ b/rust/crates/runtime/src/lib.rs @@ -28,7 +28,8 @@ pub use config::{ McpConfigCollection, McpOAuthConfig, McpRemoteServerConfig, McpSdkServerConfig, McpServerConfig, McpStdioServerConfig, McpTransport, McpWebSocketServerConfig, OAuthConfig, ResolvedPermissionMode, RuntimeConfig, RuntimeFeatureConfig, RuntimeHookConfig, - RuntimePluginConfig, ScopedMcpServerConfig, CLAUDE_CODE_SETTINGS_SCHEMA_NAME, + RuntimePermissionRuleConfig, RuntimePluginConfig, ScopedMcpServerConfig, + CLAUDE_CODE_SETTINGS_SCHEMA_NAME, }; pub use conversation::{ auto_compaction_threshold_from_env, ApiClient, ApiRequest, AssistantEvent, AutoCompactionEvent, @@ -39,7 +40,9 @@ pub use file_ops::{ GrepSearchInput, GrepSearchOutput, ReadFileOutput, StructuredPatchHunk, TextFilePayload, WriteFileOutput, }; -pub use hooks::{HookEvent, HookRunResult, HookRunner}; +pub use hooks::{ + HookAbortSignal, HookEvent, HookProgressEvent, HookProgressReporter, HookRunResult, HookRunner, +}; pub use mcp::{ mcp_server_signature, mcp_tool_name, mcp_tool_prefix, normalize_name_for_mcp, scoped_mcp_config_hash, unwrap_ccr_proxy_url, @@ -64,8 +67,8 @@ pub use oauth::{ PkceChallengeMethod, PkceCodePair, }; pub use permissions::{ - PermissionMode, PermissionOutcome, PermissionPolicy, PermissionPromptDecision, - PermissionPrompter, PermissionRequest, + PermissionContext, PermissionMode, PermissionOutcome, PermissionOverride, PermissionPolicy, + PermissionPromptDecision, PermissionPrompter, PermissionRequest, }; pub use prompt::{ load_system_prompt, prepend_bullets, ContextFile, ProjectContext, PromptBuildError, diff --git a/rust/crates/runtime/src/permissions.rs b/rust/crates/runtime/src/permissions.rs index bed2eab..3acf5c1 100644 --- a/rust/crates/runtime/src/permissions.rs +++ b/rust/crates/runtime/src/permissions.rs @@ -1,5 +1,9 @@ use std::collections::BTreeMap; +use serde_json::Value; + +use crate::config::RuntimePermissionRuleConfig; + #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] pub enum PermissionMode { ReadOnly, @@ -22,12 +26,49 @@ impl PermissionMode { } } +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum PermissionOverride { + Allow, + Deny, + Ask, +} + +#[derive(Debug, Clone, PartialEq, Eq, Default)] +pub struct PermissionContext { + override_decision: Option, + override_reason: Option, +} + +impl PermissionContext { + #[must_use] + pub fn new( + override_decision: Option, + override_reason: Option, + ) -> Self { + Self { + override_decision, + override_reason, + } + } + + #[must_use] + pub fn override_decision(&self) -> Option { + self.override_decision + } + + #[must_use] + pub fn override_reason(&self) -> Option<&str> { + self.override_reason.as_deref() + } +} + #[derive(Debug, Clone, PartialEq, Eq)] pub struct PermissionRequest { pub tool_name: String, pub input: String, pub current_mode: PermissionMode, pub required_mode: PermissionMode, + pub reason: Option, } #[derive(Debug, Clone, PartialEq, Eq)] @@ -50,6 +91,9 @@ pub enum PermissionOutcome { pub struct PermissionPolicy { active_mode: PermissionMode, tool_requirements: BTreeMap, + allow_rules: Vec, + deny_rules: Vec, + ask_rules: Vec, } impl PermissionPolicy { @@ -58,6 +102,9 @@ impl PermissionPolicy { Self { active_mode, tool_requirements: BTreeMap::new(), + allow_rules: Vec::new(), + deny_rules: Vec::new(), + ask_rules: Vec::new(), } } @@ -72,6 +119,26 @@ impl PermissionPolicy { self } + #[must_use] + pub fn with_permission_rules(mut self, config: &RuntimePermissionRuleConfig) -> Self { + self.allow_rules = config + .allow() + .iter() + .map(|rule| PermissionRule::parse(rule)) + .collect(); + self.deny_rules = config + .deny() + .iter() + .map(|rule| PermissionRule::parse(rule)) + .collect(); + self.ask_rules = config + .ask() + .iter() + .map(|rule| PermissionRule::parse(rule)) + .collect(); + self + } + #[must_use] pub fn active_mode(&self) -> PermissionMode { self.active_mode @@ -90,38 +157,121 @@ impl PermissionPolicy { &self, tool_name: &str, input: &str, - mut prompter: Option<&mut dyn PermissionPrompter>, + prompter: Option<&mut dyn PermissionPrompter>, ) -> PermissionOutcome { - let current_mode = self.active_mode(); - let required_mode = self.required_mode_for(tool_name); - if current_mode == PermissionMode::Allow || current_mode >= required_mode { - return PermissionOutcome::Allow; + self.authorize_with_context(tool_name, input, &PermissionContext::default(), prompter) + } + + #[must_use] + #[allow(clippy::too_many_lines)] + pub fn authorize_with_context( + &self, + tool_name: &str, + input: &str, + context: &PermissionContext, + prompter: Option<&mut dyn PermissionPrompter>, + ) -> PermissionOutcome { + if let Some(rule) = Self::find_matching_rule(&self.deny_rules, tool_name, input) { + return PermissionOutcome::Deny { + reason: format!( + "Permission to use {tool_name} has been denied by rule '{}'", + rule.raw + ), + }; } - let request = PermissionRequest { - tool_name: tool_name.to_string(), - input: input.to_string(), - current_mode, - required_mode, - }; + let current_mode = self.active_mode(); + let required_mode = self.required_mode_for(tool_name); + let ask_rule = Self::find_matching_rule(&self.ask_rules, tool_name, input); + let allow_rule = Self::find_matching_rule(&self.allow_rules, tool_name, input); + + match context.override_decision() { + Some(PermissionOverride::Deny) => { + return PermissionOutcome::Deny { + reason: context.override_reason().map_or_else( + || format!("tool '{tool_name}' denied by hook"), + ToOwned::to_owned, + ), + }; + } + Some(PermissionOverride::Ask) => { + let reason = context.override_reason().map_or_else( + || format!("tool '{tool_name}' requires approval due to hook guidance"), + ToOwned::to_owned, + ); + return Self::prompt_or_deny( + tool_name, + input, + current_mode, + required_mode, + Some(reason), + prompter, + ); + } + Some(PermissionOverride::Allow) => { + if let Some(rule) = ask_rule { + let reason = format!( + "tool '{tool_name}' requires approval due to ask rule '{}'", + rule.raw + ); + return Self::prompt_or_deny( + tool_name, + input, + current_mode, + required_mode, + Some(reason), + prompter, + ); + } + if allow_rule.is_some() + || current_mode == PermissionMode::Allow + || current_mode >= required_mode + { + return PermissionOutcome::Allow; + } + } + None => {} + } + + if let Some(rule) = ask_rule { + let reason = format!( + "tool '{tool_name}' requires approval due to ask rule '{}'", + rule.raw + ); + return Self::prompt_or_deny( + tool_name, + input, + current_mode, + required_mode, + Some(reason), + prompter, + ); + } + + if allow_rule.is_some() + || current_mode == PermissionMode::Allow + || current_mode >= required_mode + { + return PermissionOutcome::Allow; + } if current_mode == PermissionMode::Prompt || (current_mode == PermissionMode::WorkspaceWrite && required_mode == PermissionMode::DangerFullAccess) { - return match prompter.as_mut() { - Some(prompter) => match prompter.decide(&request) { - PermissionPromptDecision::Allow => PermissionOutcome::Allow, - PermissionPromptDecision::Deny { reason } => PermissionOutcome::Deny { reason }, - }, - None => PermissionOutcome::Deny { - reason: format!( - "tool '{tool_name}' requires approval to escalate from {} to {}", - current_mode.as_str(), - required_mode.as_str() - ), - }, - }; + let reason = Some(format!( + "tool '{tool_name}' requires approval to escalate from {} to {}", + current_mode.as_str(), + required_mode.as_str() + )); + return Self::prompt_or_deny( + tool_name, + input, + current_mode, + required_mode, + reason, + prompter, + ); } PermissionOutcome::Deny { @@ -132,14 +282,191 @@ impl PermissionPolicy { ), } } + + fn prompt_or_deny( + tool_name: &str, + input: &str, + current_mode: PermissionMode, + required_mode: PermissionMode, + reason: Option, + mut prompter: Option<&mut dyn PermissionPrompter>, + ) -> PermissionOutcome { + let request = PermissionRequest { + tool_name: tool_name.to_string(), + input: input.to_string(), + current_mode, + required_mode, + reason: reason.clone(), + }; + + match prompter.as_mut() { + Some(prompter) => match prompter.decide(&request) { + PermissionPromptDecision::Allow => PermissionOutcome::Allow, + PermissionPromptDecision::Deny { reason } => PermissionOutcome::Deny { reason }, + }, + None => PermissionOutcome::Deny { + reason: reason.unwrap_or_else(|| { + format!( + "tool '{tool_name}' requires approval to run while mode is {}", + current_mode.as_str() + ) + }), + }, + } + } + + fn find_matching_rule<'a>( + rules: &'a [PermissionRule], + tool_name: &str, + input: &str, + ) -> Option<&'a PermissionRule> { + rules.iter().find(|rule| rule.matches(tool_name, input)) + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +struct PermissionRule { + raw: String, + tool_name: String, + matcher: PermissionRuleMatcher, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +enum PermissionRuleMatcher { + Any, + Exact(String), + Prefix(String), +} + +impl PermissionRule { + fn parse(raw: &str) -> Self { + let trimmed = raw.trim(); + let open = find_first_unescaped(trimmed, '('); + let close = find_last_unescaped(trimmed, ')'); + + if let (Some(open), Some(close)) = (open, close) { + if close == trimmed.len() - 1 && open < close { + let tool_name = trimmed[..open].trim(); + let content = &trimmed[open + 1..close]; + if !tool_name.is_empty() { + let matcher = parse_rule_matcher(content); + return Self { + raw: trimmed.to_string(), + tool_name: tool_name.to_string(), + matcher, + }; + } + } + } + + Self { + raw: trimmed.to_string(), + tool_name: trimmed.to_string(), + matcher: PermissionRuleMatcher::Any, + } + } + + fn matches(&self, tool_name: &str, input: &str) -> bool { + if self.tool_name != tool_name { + return false; + } + + match &self.matcher { + PermissionRuleMatcher::Any => true, + PermissionRuleMatcher::Exact(expected) => { + extract_permission_subject(input).is_some_and(|candidate| candidate == *expected) + } + PermissionRuleMatcher::Prefix(prefix) => extract_permission_subject(input) + .is_some_and(|candidate| candidate.starts_with(prefix)), + } + } +} + +fn parse_rule_matcher(content: &str) -> PermissionRuleMatcher { + let unescaped = unescape_rule_content(content.trim()); + if unescaped.is_empty() || unescaped == "*" { + PermissionRuleMatcher::Any + } else if let Some(prefix) = unescaped.strip_suffix(":*") { + PermissionRuleMatcher::Prefix(prefix.to_string()) + } else { + PermissionRuleMatcher::Exact(unescaped) + } +} + +fn unescape_rule_content(content: &str) -> String { + content + .replace(r"\(", "(") + .replace(r"\)", ")") + .replace(r"\\", r"\") +} + +fn find_first_unescaped(value: &str, needle: char) -> Option { + let mut escaped = false; + for (idx, ch) in value.char_indices() { + if ch == '\\' { + escaped = !escaped; + continue; + } + if ch == needle && !escaped { + return Some(idx); + } + escaped = false; + } + None +} + +fn find_last_unescaped(value: &str, needle: char) -> Option { + let chars = value.char_indices().collect::>(); + for (pos, (idx, ch)) in chars.iter().enumerate().rev() { + if *ch != needle { + continue; + } + let mut backslashes = 0; + for (_, prev) in chars[..pos].iter().rev() { + if *prev == '\\' { + backslashes += 1; + } else { + break; + } + } + if backslashes % 2 == 0 { + return Some(*idx); + } + } + None +} + +fn extract_permission_subject(input: &str) -> Option { + let parsed = serde_json::from_str::(input).ok(); + if let Some(Value::Object(object)) = parsed { + for key in [ + "command", + "path", + "file_path", + "filePath", + "notebook_path", + "notebookPath", + "url", + "pattern", + "code", + "message", + ] { + if let Some(value) = object.get(key).and_then(Value::as_str) { + return Some(value.to_string()); + } + } + } + + (!input.trim().is_empty()).then(|| input.to_string()) } #[cfg(test)] mod tests { use super::{ - PermissionMode, PermissionOutcome, PermissionPolicy, PermissionPromptDecision, - PermissionPrompter, PermissionRequest, + PermissionContext, PermissionMode, PermissionOutcome, PermissionOverride, PermissionPolicy, + PermissionPromptDecision, PermissionPrompter, PermissionRequest, }; + use crate::config::RuntimePermissionRuleConfig; struct RecordingPrompter { seen: Vec, @@ -229,4 +556,120 @@ mod tests { PermissionOutcome::Deny { reason } if reason == "not now" )); } + + #[test] + fn applies_rule_based_denials_and_allows() { + let rules = RuntimePermissionRuleConfig::new( + vec!["bash(git:*)".to_string()], + vec!["bash(rm -rf:*)".to_string()], + Vec::new(), + ); + let policy = PermissionPolicy::new(PermissionMode::ReadOnly) + .with_tool_requirement("bash", PermissionMode::DangerFullAccess) + .with_permission_rules(&rules); + + assert_eq!( + policy.authorize("bash", r#"{"command":"git status"}"#, None), + PermissionOutcome::Allow + ); + assert!(matches!( + policy.authorize("bash", r#"{"command":"rm -rf /tmp/x"}"#, None), + PermissionOutcome::Deny { reason } if reason.contains("denied by rule") + )); + } + + #[test] + fn ask_rules_force_prompt_even_when_mode_allows() { + let rules = RuntimePermissionRuleConfig::new( + Vec::new(), + Vec::new(), + vec!["bash(git:*)".to_string()], + ); + let policy = PermissionPolicy::new(PermissionMode::DangerFullAccess) + .with_tool_requirement("bash", PermissionMode::DangerFullAccess) + .with_permission_rules(&rules); + let mut prompter = RecordingPrompter { + seen: Vec::new(), + allow: true, + }; + + let outcome = policy.authorize("bash", r#"{"command":"git status"}"#, Some(&mut prompter)); + + assert_eq!(outcome, PermissionOutcome::Allow); + assert_eq!(prompter.seen.len(), 1); + assert!(prompter.seen[0] + .reason + .as_deref() + .is_some_and(|reason| reason.contains("ask rule"))); + } + + #[test] + fn hook_allow_still_respects_ask_rules() { + let rules = RuntimePermissionRuleConfig::new( + Vec::new(), + Vec::new(), + vec!["bash(git:*)".to_string()], + ); + let policy = PermissionPolicy::new(PermissionMode::ReadOnly) + .with_tool_requirement("bash", PermissionMode::DangerFullAccess) + .with_permission_rules(&rules); + let context = PermissionContext::new( + Some(PermissionOverride::Allow), + Some("hook approved".to_string()), + ); + let mut prompter = RecordingPrompter { + seen: Vec::new(), + allow: true, + }; + + let outcome = policy.authorize_with_context( + "bash", + r#"{"command":"git status"}"#, + &context, + Some(&mut prompter), + ); + + assert_eq!(outcome, PermissionOutcome::Allow); + assert_eq!(prompter.seen.len(), 1); + } + + #[test] + fn hook_deny_short_circuits_permission_flow() { + let policy = PermissionPolicy::new(PermissionMode::DangerFullAccess) + .with_tool_requirement("bash", PermissionMode::DangerFullAccess); + let context = PermissionContext::new( + Some(PermissionOverride::Deny), + Some("blocked by hook".to_string()), + ); + + assert_eq!( + policy.authorize_with_context("bash", "{}", &context, None), + PermissionOutcome::Deny { + reason: "blocked by hook".to_string(), + } + ); + } + + #[test] + fn hook_ask_forces_prompt() { + let policy = PermissionPolicy::new(PermissionMode::DangerFullAccess) + .with_tool_requirement("bash", PermissionMode::DangerFullAccess); + let context = PermissionContext::new( + Some(PermissionOverride::Ask), + Some("hook requested confirmation".to_string()), + ); + let mut prompter = RecordingPrompter { + seen: Vec::new(), + allow: true, + }; + + let outcome = policy.authorize_with_context("bash", "{}", &context, Some(&mut prompter)); + + assert_eq!(outcome, PermissionOutcome::Allow); + assert_eq!(prompter.seen.len(), 1); + assert_eq!( + prompter.seen[0].reason.as_deref(), + Some("hook requested confirmation") + ); + } } diff --git a/rust/crates/rusty-claude-cli/Cargo.toml b/rust/crates/rusty-claude-cli/Cargo.toml index 242ec0f..07744bb 100644 --- a/rust/crates/rusty-claude-cli/Cargo.toml +++ b/rust/crates/rusty-claude-cli/Cargo.toml @@ -20,7 +20,7 @@ runtime = { path = "../runtime" } plugins = { path = "../plugins" } serde_json = "1" syntect = "5" -tokio = { version = "1", features = ["rt-multi-thread", "time"] } +tokio = { version = "1", features = ["rt-multi-thread", "signal", "time"] } tools = { path = "../tools" } [lints] diff --git a/rust/crates/rusty-claude-cli/src/main.rs b/rust/crates/rusty-claude-cli/src/main.rs index fc51cb2..0b1ac0e 100644 --- a/rust/crates/rusty-claude-cli/src/main.rs +++ b/rust/crates/rusty-claude-cli/src/main.rs @@ -10,9 +10,9 @@ use std::io::{self, Read, Write}; use std::net::TcpListener; use std::path::{Path, PathBuf}; use std::process::Command; -use std::sync::mpsc::{self, RecvTimeoutError}; +use std::sync::mpsc::{self, Receiver, RecvTimeoutError, Sender}; use std::sync::{Arc, Mutex}; -use std::thread; +use std::thread::{self, JoinHandle}; use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH}; use api::{ @@ -972,6 +972,61 @@ struct LiveCli { session: SessionHandle, } +struct HookAbortMonitor { + stop_tx: Option>, + join_handle: Option>, +} + +impl HookAbortMonitor { + fn spawn(abort_signal: runtime::HookAbortSignal) -> Self { + Self::spawn_with_waiter(abort_signal, move |stop_rx, abort_signal| { + let Ok(runtime) = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + else { + return; + }; + + runtime.block_on(async move { + let wait_for_stop = tokio::task::spawn_blocking(move || { + let _ = stop_rx.recv(); + }); + + tokio::select! { + result = tokio::signal::ctrl_c() => { + if result.is_ok() { + abort_signal.abort(); + } + } + _ = wait_for_stop => {} + } + }); + }) + } + + fn spawn_with_waiter(abort_signal: runtime::HookAbortSignal, wait_for_interrupt: F) -> Self + where + F: FnOnce(Receiver<()>, runtime::HookAbortSignal) + Send + 'static, + { + let (stop_tx, stop_rx) = mpsc::channel(); + let join_handle = thread::spawn(move || wait_for_interrupt(stop_rx, abort_signal)); + + Self { + stop_tx: Some(stop_tx), + join_handle: Some(join_handle), + } + } + + fn stop(mut self) { + if let Some(stop_tx) = self.stop_tx.take() { + let _ = stop_tx.send(()); + } + if let Some(join_handle) = self.join_handle.take() { + let _ = join_handle.join(); + } + } +} + impl LiveCli { fn new( model: String, @@ -1028,7 +1083,35 @@ impl LiveCli { ) } + fn prepare_turn_runtime( + &self, + emit_output: bool, + ) -> Result< + ( + ConversationRuntime, + HookAbortMonitor, + ), + Box, + > { + let hook_abort_signal = runtime::HookAbortSignal::new(); + let runtime = build_runtime( + self.runtime.session().clone(), + self.model.clone(), + self.system_prompt.clone(), + true, + emit_output, + self.allowed_tools.clone(), + self.permission_mode, + None, + )? + .with_hook_abort_signal(hook_abort_signal.clone()); + let hook_abort_monitor = HookAbortMonitor::spawn(hook_abort_signal); + + Ok((runtime, hook_abort_monitor)) + } + fn run_turn(&mut self, input: &str) -> Result<(), Box> { + let (mut runtime, hook_abort_monitor) = self.prepare_turn_runtime(true)?; let mut spinner = Spinner::new(); let mut stdout = io::stdout(); spinner.tick( @@ -1037,7 +1120,9 @@ impl LiveCli { &mut stdout, )?; let mut permission_prompter = CliPermissionPrompter::new(self.permission_mode); - let result = self.runtime.run_turn(input, Some(&mut permission_prompter)); + let result = runtime.run_turn(input, Some(&mut permission_prompter)); + hook_abort_monitor.stop(); + self.runtime = runtime; match result { Ok(summary) => { spinner.finish( @@ -1078,19 +1163,11 @@ impl LiveCli { } fn run_prompt_json(&mut self, input: &str) -> Result<(), Box> { - let session = self.runtime.session().clone(); - let mut runtime = build_runtime( - session, - self.model.clone(), - self.system_prompt.clone(), - true, - false, - self.allowed_tools.clone(), - self.permission_mode, - None, - )?; + let (mut runtime, hook_abort_monitor) = self.prepare_turn_runtime(false)?; let mut permission_prompter = CliPermissionPrompter::new(self.permission_mode); - let summary = runtime.run_turn(input, Some(&mut permission_prompter))?; + let result = runtime.run_turn(input, Some(&mut permission_prompter)); + hook_abort_monitor.stop(); + let summary = result?; self.runtime = runtime; self.persist_session()?; println!( @@ -2756,7 +2833,7 @@ fn build_runtime( ) -> Result, Box> { let (feature_config, plugin_registry, tool_registry) = build_runtime_plugin_state()?; - Ok(ConversationRuntime::new_with_plugins( + let mut runtime = ConversationRuntime::new_with_plugins( session, AnthropicRuntimeClient::new( model, @@ -2767,11 +2844,48 @@ fn build_runtime( progress_reporter, )?, CliToolExecutor::new(allowed_tools.clone(), emit_output, tool_registry.clone()), - permission_policy(permission_mode, &tool_registry), + permission_policy(permission_mode, &feature_config, &tool_registry), system_prompt, feature_config, plugin_registry, - )?) + )?; + if emit_output { + runtime = runtime.with_hook_progress_reporter(Box::new(CliHookProgressReporter)); + } + Ok(runtime) +} + +struct CliHookProgressReporter; + +impl runtime::HookProgressReporter for CliHookProgressReporter { + fn on_event(&mut self, event: &runtime::HookProgressEvent) { + match event { + runtime::HookProgressEvent::Started { + event, + tool_name, + command, + } => eprintln!( + "[hook {event_name}] {tool_name}: {command}", + event_name = event.as_str() + ), + runtime::HookProgressEvent::Completed { + event, + tool_name, + command, + } => eprintln!( + "[hook done {event_name}] {tool_name}: {command}", + event_name = event.as_str() + ), + runtime::HookProgressEvent::Cancelled { + event, + tool_name, + command, + } => eprintln!( + "[hook cancelled {event_name}] {tool_name}: {command}", + event_name = event.as_str() + ), + } + } } struct CliPermissionPrompter { @@ -3621,9 +3735,13 @@ impl ToolExecutor for CliToolExecutor { } } -fn permission_policy(mode: PermissionMode, tool_registry: &GlobalToolRegistry) -> PermissionPolicy { +fn permission_policy( + mode: PermissionMode, + feature_config: &runtime::RuntimeFeatureConfig, + tool_registry: &GlobalToolRegistry, +) -> PermissionPolicy { tool_registry.permission_specs(None).into_iter().fold( - PermissionPolicy::new(mode), + PermissionPolicy::new(mode).with_permission_rules(feature_config.permission_rules()), |policy, (name, required_permission)| { policy.with_tool_requirement(name, required_permission) }, @@ -3773,14 +3891,18 @@ mod tests { normalize_permission_mode, parse_args, parse_git_status_metadata, permission_policy, print_help_to, push_output_block, render_config_report, render_memory_report, render_repl_help, resolve_model_alias, response_to_events, resume_supported_slash_commands, - status_context, CliAction, CliOutputFormat, InternalPromptProgressEvent, + status_context, CliAction, CliOutputFormat, HookAbortMonitor, InternalPromptProgressEvent, InternalPromptProgressState, SlashCommand, StatusUsage, DEFAULT_MODEL, }; use api::{MessageResponse, OutputContentBlock, Usage}; use plugins::{PluginTool, PluginToolDefinition, PluginToolPermission}; - use runtime::{AssistantEvent, ContentBlock, ConversationMessage, MessageRole, PermissionMode}; + use runtime::{ + AssistantEvent, ContentBlock, ConversationMessage, HookAbortSignal, MessageRole, + PermissionMode, + }; use serde_json::json; use std::path::PathBuf; + use std::sync::mpsc; use std::time::Duration; use tools::GlobalToolRegistry; @@ -4041,7 +4163,11 @@ mod tests { #[test] fn permission_policy_uses_plugin_tool_permissions() { - let policy = permission_policy(PermissionMode::ReadOnly, ®istry_with_plugin_tool()); + let policy = permission_policy( + PermissionMode::ReadOnly, + &runtime::RuntimeFeatureConfig::default(), + ®istry_with_plugin_tool(), + ); let required = policy.required_mode_for("plugin_echo"); assert_eq!(required, PermissionMode::WorkspaceWrite); } @@ -4678,4 +4804,43 @@ mod tests { )); assert!(!String::from_utf8(out).expect("utf8").contains("step 1")); } + + #[test] + fn hook_abort_monitor_stops_without_aborting() { + let abort_signal = HookAbortSignal::new(); + let (ready_tx, ready_rx) = mpsc::channel(); + let monitor = HookAbortMonitor::spawn_with_waiter( + abort_signal.clone(), + move |stop_rx, abort_signal| { + ready_tx.send(()).expect("ready signal"); + let _ = stop_rx.recv(); + assert!(!abort_signal.is_aborted()); + }, + ); + + ready_rx.recv().expect("waiter should be ready"); + monitor.stop(); + + assert!(!abort_signal.is_aborted()); + } + + #[test] + fn hook_abort_monitor_propagates_interrupt() { + let abort_signal = HookAbortSignal::new(); + let (done_tx, done_rx) = mpsc::channel(); + let monitor = HookAbortMonitor::spawn_with_waiter( + abort_signal.clone(), + move |_stop_rx, abort_signal| { + abort_signal.abort(); + done_tx.send(()).expect("done signal"); + }, + ); + + done_rx + .recv_timeout(Duration::from_secs(1)) + .expect("interrupt should complete"); + monitor.stop(); + + assert!(abort_signal.is_aborted()); + } } diff --git a/rust/crates/rusty-claude-cli/src/render.rs b/rust/crates/rusty-claude-cli/src/render.rs index 01751fd..d8d8796 100644 --- a/rust/crates/rusty-claude-cli/src/render.rs +++ b/rust/crates/rusty-claude-cli/src/render.rs @@ -286,7 +286,7 @@ impl TerminalRenderer { ) { match event { Event::Start(Tag::Heading { level, .. }) => { - self.start_heading(state, level as u8, output); + Self::start_heading(state, level as u8, output); } Event::End(TagEnd::Paragraph) => output.push_str("\n\n"), Event::Start(Tag::BlockQuote(..)) => self.start_quote(state, output), @@ -426,8 +426,7 @@ impl TerminalRenderer { } } - #[allow(clippy::unused_self)] - fn start_heading(&self, state: &mut RenderState, level: u8, output: &mut String) { + fn start_heading(state: &mut RenderState, level: u8, output: &mut String) { state.heading_level = Some(level); if !output.is_empty() { output.push('\n');