diff --git a/rust/crates/tools/src/lib.rs b/rust/crates/tools/src/lib.rs index 479438c..a9f65ad 100644 --- a/rust/crates/tools/src/lib.rs +++ b/rust/crates/tools/src/lib.rs @@ -11,14 +11,22 @@ use api::{ use plugins::PluginTool; use reqwest::blocking::Client; use runtime::{ - edit_file, execute_bash, glob_search, grep_search, load_system_prompt, read_file, write_file, - ApiClient, ApiRequest, AssistantEvent, BashCommandInput, ContentBlock, ConversationMessage, - ConversationRuntime, GrepSearchInput, MessageRole, PermissionMode, PermissionPolicy, - PromptCacheEvent, RuntimeError, Session, ToolError, ToolExecutor, + edit_file, execute_bash, glob_search, grep_search, load_system_prompt, read_file, + task_registry::TaskRegistry, write_file, ApiClient, ApiRequest, AssistantEvent, + BashCommandInput, ContentBlock, ConversationMessage, ConversationRuntime, GrepSearchInput, + MessageRole, PermissionMode, PermissionPolicy, PromptCacheEvent, RuntimeError, Session, + ToolError, ToolExecutor, }; use serde::{Deserialize, Serialize}; use serde_json::{json, Value}; +/// Global task registry shared across tool invocations within a session. +fn global_task_registry() -> &'static TaskRegistry { + use std::sync::OnceLock; + static REGISTRY: OnceLock = OnceLock::new(); + REGISTRY.get_or_init(TaskRegistry::new) +} + #[derive(Debug, Clone, PartialEq, Eq)] pub struct ToolManifestEntry { pub name: String, @@ -905,60 +913,96 @@ fn run_ask_user_question(input: AskUserQuestionInput) -> Result #[allow(clippy::needless_pass_by_value)] fn run_task_create(input: TaskCreateInput) -> Result { - let secs = std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .unwrap_or_default() - .as_secs(); - let task_id = format!("task_{secs:08x}"); + let registry = global_task_registry(); + let task = registry.create(&input.prompt, input.description.as_deref()); to_pretty_json(json!({ - "task_id": task_id, - "status": "created", - "prompt": input.prompt, - "description": input.description + "task_id": task.task_id, + "status": task.status, + "prompt": task.prompt, + "description": task.description, + "created_at": task.created_at })) } #[allow(clippy::needless_pass_by_value)] fn run_task_get(input: TaskIdInput) -> Result { - to_pretty_json(json!({ - "task_id": input.task_id, - "status": "unknown", - "message": "Task runtime not yet implemented" - })) + let registry = global_task_registry(); + match registry.get(&input.task_id) { + Some(task) => to_pretty_json(json!({ + "task_id": task.task_id, + "status": task.status, + "prompt": task.prompt, + "description": task.description, + "created_at": task.created_at, + "updated_at": task.updated_at, + "messages": task.messages, + "team_id": task.team_id + })), + None => Err(format!("task not found: {}", input.task_id)), + } } fn run_task_list(_input: Value) -> Result { + let registry = global_task_registry(); + let tasks: Vec<_> = registry + .list(None) + .into_iter() + .map(|t| { + json!({ + "task_id": t.task_id, + "status": t.status, + "prompt": t.prompt, + "description": t.description, + "created_at": t.created_at, + "updated_at": t.updated_at, + "team_id": t.team_id + }) + }) + .collect(); to_pretty_json(json!({ - "tasks": [], - "message": "No tasks found" + "tasks": tasks, + "count": tasks.len() })) } #[allow(clippy::needless_pass_by_value)] fn run_task_stop(input: TaskIdInput) -> Result { - to_pretty_json(json!({ - "task_id": input.task_id, - "status": "stopped", - "message": "Task stop requested" - })) + let registry = global_task_registry(); + match registry.stop(&input.task_id) { + Ok(task) => to_pretty_json(json!({ + "task_id": task.task_id, + "status": task.status, + "message": "Task stopped" + })), + Err(e) => Err(e), + } } #[allow(clippy::needless_pass_by_value)] fn run_task_update(input: TaskUpdateInput) -> Result { - to_pretty_json(json!({ - "task_id": input.task_id, - "status": "updated", - "message": input.message - })) + let registry = global_task_registry(); + match registry.update(&input.task_id, &input.message) { + Ok(task) => to_pretty_json(json!({ + "task_id": task.task_id, + "status": task.status, + "message_count": task.messages.len(), + "last_message": input.message + })), + Err(e) => Err(e), + } } #[allow(clippy::needless_pass_by_value)] fn run_task_output(input: TaskIdInput) -> Result { - to_pretty_json(json!({ - "task_id": input.task_id, - "output": "", - "message": "No output available" - })) + let registry = global_task_registry(); + match registry.output(&input.task_id) { + Ok(output) => to_pretty_json(json!({ + "task_id": input.task_id, + "output": output, + "has_output": !output.is_empty() + })), + Err(e) => Err(e), + } } #[allow(clippy::needless_pass_by_value)]