diff --git a/rust/crates/runtime/src/lib.rs b/rust/crates/runtime/src/lib.rs index 62e78d8..75c1b9f 100644 --- a/rust/crates/runtime/src/lib.rs +++ b/rust/crates/runtime/src/lib.rs @@ -16,6 +16,7 @@ mod remote; pub mod sandbox; mod session; mod sse; +pub mod task_registry; mod usage; pub use bash::{execute_bash, BashCommandInput, BashCommandOutput}; diff --git a/rust/crates/runtime/src/task_registry.rs b/rust/crates/runtime/src/task_registry.rs new file mode 100644 index 0000000..861f9bc --- /dev/null +++ b/rust/crates/runtime/src/task_registry.rs @@ -0,0 +1,335 @@ +//! In-memory task registry for sub-agent task lifecycle management. +//! +//! Provides create, get, list, stop, update, and output operations +//! matching the upstream TaskCreate/TaskGet/TaskList/TaskStop/TaskUpdate/TaskOutput +//! tool surface. + +use std::collections::HashMap; +use std::sync::{Arc, Mutex}; +use std::time::{SystemTime, UNIX_EPOCH}; + +use serde::{Deserialize, Serialize}; + +/// Current status of a managed task. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum TaskStatus { + Created, + Running, + Completed, + Failed, + Stopped, +} + +impl std::fmt::Display for TaskStatus { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Created => write!(f, "created"), + Self::Running => write!(f, "running"), + Self::Completed => write!(f, "completed"), + Self::Failed => write!(f, "failed"), + Self::Stopped => write!(f, "stopped"), + } + } +} + +/// A single managed task entry. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Task { + pub task_id: String, + pub prompt: String, + pub description: Option, + pub status: TaskStatus, + pub created_at: u64, + pub updated_at: u64, + pub messages: Vec, + pub output: String, + pub team_id: Option, +} + +/// A message exchanged with a running task. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TaskMessage { + pub role: String, + pub content: String, + pub timestamp: u64, +} + +/// Thread-safe task registry. +#[derive(Debug, Clone, Default)] +pub struct TaskRegistry { + inner: Arc>, +} + +#[derive(Debug, Default)] +struct RegistryInner { + tasks: HashMap, + counter: u64, +} + +fn now_secs() -> u64 { + SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or_default() + .as_secs() +} + +impl TaskRegistry { + /// Create a new empty registry. + #[must_use] + pub fn new() -> Self { + Self::default() + } + + /// Create a new task and return its ID. + pub fn create(&self, prompt: &str, description: Option<&str>) -> Task { + let mut inner = self.inner.lock().expect("registry lock poisoned"); + inner.counter += 1; + let ts = now_secs(); + let task_id = format!("task_{:08x}_{}", ts, inner.counter); + let task = Task { + task_id: task_id.clone(), + prompt: prompt.to_owned(), + description: description.map(str::to_owned), + status: TaskStatus::Created, + created_at: ts, + updated_at: ts, + messages: Vec::new(), + output: String::new(), + team_id: None, + }; + inner.tasks.insert(task_id, task.clone()); + task + } + + /// Look up a task by ID. + pub fn get(&self, task_id: &str) -> Option { + let inner = self.inner.lock().expect("registry lock poisoned"); + inner.tasks.get(task_id).cloned() + } + + /// List all tasks, optionally filtered by status. + pub fn list(&self, status_filter: Option) -> Vec { + let inner = self.inner.lock().expect("registry lock poisoned"); + inner + .tasks + .values() + .filter(|t| status_filter.map_or(true, |s| t.status == s)) + .cloned() + .collect() + } + + /// Mark a task as stopped. + pub fn stop(&self, task_id: &str) -> Result { + let mut inner = self.inner.lock().expect("registry lock poisoned"); + let task = inner + .tasks + .get_mut(task_id) + .ok_or_else(|| format!("task not found: {task_id}"))?; + + match task.status { + TaskStatus::Completed | TaskStatus::Failed | TaskStatus::Stopped => { + return Err(format!( + "task {task_id} is already in terminal state: {}", + task.status + )); + } + _ => {} + } + + task.status = TaskStatus::Stopped; + task.updated_at = now_secs(); + Ok(task.clone()) + } + + /// Send a message to a task, updating its state. + pub fn update(&self, task_id: &str, message: &str) -> Result { + let mut inner = self.inner.lock().expect("registry lock poisoned"); + let task = inner + .tasks + .get_mut(task_id) + .ok_or_else(|| format!("task not found: {task_id}"))?; + + task.messages.push(TaskMessage { + role: String::from("user"), + content: message.to_owned(), + timestamp: now_secs(), + }); + task.updated_at = now_secs(); + Ok(task.clone()) + } + + /// Get the accumulated output of a task. + pub fn output(&self, task_id: &str) -> Result { + let inner = self.inner.lock().expect("registry lock poisoned"); + let task = inner + .tasks + .get(task_id) + .ok_or_else(|| format!("task not found: {task_id}"))?; + Ok(task.output.clone()) + } + + /// Append output to a task (used by the task executor). + pub fn append_output(&self, task_id: &str, output: &str) -> Result<(), String> { + let mut inner = self.inner.lock().expect("registry lock poisoned"); + let task = inner + .tasks + .get_mut(task_id) + .ok_or_else(|| format!("task not found: {task_id}"))?; + task.output.push_str(output); + task.updated_at = now_secs(); + Ok(()) + } + + /// Transition a task to a new status. + pub fn set_status(&self, task_id: &str, status: TaskStatus) -> Result<(), String> { + let mut inner = self.inner.lock().expect("registry lock poisoned"); + let task = inner + .tasks + .get_mut(task_id) + .ok_or_else(|| format!("task not found: {task_id}"))?; + task.status = status; + task.updated_at = now_secs(); + Ok(()) + } + + /// Assign a task to a team. + pub fn assign_team(&self, task_id: &str, team_id: &str) -> Result<(), String> { + let mut inner = self.inner.lock().expect("registry lock poisoned"); + let task = inner + .tasks + .get_mut(task_id) + .ok_or_else(|| format!("task not found: {task_id}"))?; + task.team_id = Some(team_id.to_owned()); + task.updated_at = now_secs(); + Ok(()) + } + + /// Remove a task from the registry. + pub fn remove(&self, task_id: &str) -> Option { + let mut inner = self.inner.lock().expect("registry lock poisoned"); + inner.tasks.remove(task_id) + } + + /// Number of tasks in the registry. + #[must_use] + pub fn len(&self) -> usize { + let inner = self.inner.lock().expect("registry lock poisoned"); + inner.tasks.len() + } + + /// Whether the registry has no tasks. + #[must_use] + pub fn is_empty(&self) -> bool { + self.len() == 0 + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn creates_and_retrieves_tasks() { + let registry = TaskRegistry::new(); + let task = registry.create("Do something", Some("A test task")); + assert_eq!(task.status, TaskStatus::Created); + assert_eq!(task.prompt, "Do something"); + assert_eq!(task.description.as_deref(), Some("A test task")); + + let fetched = registry.get(&task.task_id).expect("task should exist"); + assert_eq!(fetched.task_id, task.task_id); + } + + #[test] + fn lists_tasks_with_optional_filter() { + let registry = TaskRegistry::new(); + registry.create("Task A", None); + let task_b = registry.create("Task B", None); + registry + .set_status(&task_b.task_id, TaskStatus::Running) + .expect("set status should succeed"); + + let all = registry.list(None); + assert_eq!(all.len(), 2); + + let running = registry.list(Some(TaskStatus::Running)); + assert_eq!(running.len(), 1); + assert_eq!(running[0].task_id, task_b.task_id); + + let created = registry.list(Some(TaskStatus::Created)); + assert_eq!(created.len(), 1); + } + + #[test] + fn stops_running_task() { + let registry = TaskRegistry::new(); + let task = registry.create("Stoppable", None); + registry + .set_status(&task.task_id, TaskStatus::Running) + .unwrap(); + + let stopped = registry.stop(&task.task_id).expect("stop should succeed"); + assert_eq!(stopped.status, TaskStatus::Stopped); + + // Stopping again should fail + let result = registry.stop(&task.task_id); + assert!(result.is_err()); + } + + #[test] + fn updates_task_with_messages() { + let registry = TaskRegistry::new(); + let task = registry.create("Messageable", None); + let updated = registry + .update(&task.task_id, "Here's more context") + .expect("update should succeed"); + assert_eq!(updated.messages.len(), 1); + assert_eq!(updated.messages[0].content, "Here's more context"); + assert_eq!(updated.messages[0].role, "user"); + } + + #[test] + fn appends_and_retrieves_output() { + let registry = TaskRegistry::new(); + let task = registry.create("Output task", None); + registry + .append_output(&task.task_id, "line 1\n") + .expect("append should succeed"); + registry + .append_output(&task.task_id, "line 2\n") + .expect("append should succeed"); + + let output = registry.output(&task.task_id).expect("output should exist"); + assert_eq!(output, "line 1\nline 2\n"); + } + + #[test] + fn assigns_team_and_removes_task() { + let registry = TaskRegistry::new(); + let task = registry.create("Team task", None); + registry + .assign_team(&task.task_id, "team_abc") + .expect("assign should succeed"); + + let fetched = registry.get(&task.task_id).unwrap(); + assert_eq!(fetched.team_id.as_deref(), Some("team_abc")); + + let removed = registry.remove(&task.task_id); + assert!(removed.is_some()); + assert!(registry.get(&task.task_id).is_none()); + assert!(registry.is_empty()); + } + + #[test] + fn rejects_operations_on_missing_task() { + let registry = TaskRegistry::new(); + assert!(registry.stop("nonexistent").is_err()); + assert!(registry.update("nonexistent", "msg").is_err()); + assert!(registry.output("nonexistent").is_err()); + assert!(registry.append_output("nonexistent", "data").is_err()); + assert!(registry + .set_status("nonexistent", TaskStatus::Running) + .is_err()); + } +}