mirror of https://github.com/instructkr/claw-code.git
synced 2026-04-10 18:14:50 +08:00

Compare commits: `feat/batch` ... `feat/provi` (64 commits)
| SHA1 |
|---|
| 5851f2dee8 |
| 8c6dfe57e6 |
| eed57212bb |
| 3ac97e635e |
| 006f7d7ee6 |
| 82baaf3f22 |
| c7b3296ef6 |
| 000aed4188 |
| 523ce7474a |
| b513d6e462 |
| c667d47c70 |
| 7546c1903d |
| 0530c509a3 |
| eff0765167 |
| aee5263aef |
| 9461522af5 |
| c08f060ca1 |
| cae11413dd |
| 60410b6c92 |
| aa37dc6936 |
| 6ddfa78b7c |
| bcdc52d72c |
| dd97c49e6b |
| 5dfb1d7c2b |
| fcb5d0c16a |
| 314f0c99fd |
| 469ae0179e |
| 092d8b6e21 |
| b3ccd92d24 |
| d71d109522 |
| 0f2f02af2d |
| e51566c745 |
| 20f3a5932a |
| 28e6cc0965 |
| f03b8dce17 |
| ecdca49552 |
| 8cddbc6615 |
| 5c276c8e14 |
| 1f968b359f |
| 18d3c1918b |
| 8a4b613c39 |
| 82f2e8e92b |
| 8f4651a096 |
| dab16c230a |
| a46711779c |
| ef0b870890 |
| 4557a81d2f |
| 86c3667836 |
| 260bac321f |
| 133ed4581e |
| 8663751650 |
| 90f2461f75 |
| 0d8fd51a6c |
| 5bcbc86a2b |
| d509f16b5a |
| d089d1a9cc |
| 6a6c5acb02 |
| 9105e0c656 |
| b8f76442e2 |
| b216f9ce05 |
| 4be4b46bd9 |
| 506ff55e53 |
| 65f4c3ad82 |
| 700534de41 |
ROADMAP.md (76 changed lines)
@@ -308,7 +308,7 @@ Priority order: P0 = blocks CI/green state, P1 = blocks integration wiring, P2 =
19. **Subcommand help falls through into runtime/API path** — **done**: `claw doctor --help`, `claw status --help`, `claw sandbox --help`, and nested `mcp`/`skills` help are now intercepted locally without runtime/provider startup, with regression tests covering the direct CLI paths.

20. **Session state classification gap (working vs blocked vs finished vs truly stale)** — **done**: agent manifests now derive machine states such as `working`, `blocked_background_job`, `blocked_merge_conflict`, `degraded_mcp`, `interrupted_transport`, `finished_pending_report`, and `finished_cleanable`, and terminal-state persistence records commit provenance plus derived state so downstream monitoring can distinguish quiet progress from truly idle sessions.

21. **Resumed `/status` JSON parity gap** — dogfooding shows fresh `claw status --output-format json` now emits structured JSON, but resumed slash-command status still leaks through a text-shaped path in at least one dispatch path. Local CI-equivalent repro fails `rust/crates/rusty-claude-cli/tests/resume_slash_commands.rs::resumed_status_command_emits_structured_json_when_requested` with `expected value at line 1 column 1`, so resumed automation can receive text where JSON was explicitly requested. **Action:** unify fresh vs resumed `/status` rendering through one output-format contract and add regression coverage so resumed JSON output is guaranteed valid.

22. **Opaque failure surface for session/runtime crashes** — repeated dogfood-facing failures can currently collapse to generic wrappers like `Something went wrong while processing your request. Please try again, or use /new to start a fresh session.` without exposing whether the fault was provider auth, session corruption, slash-command dispatch, render failure, or transport/runtime panic. This blocks fast self-recovery and turns actionable clawability bugs into blind retries. **Action:** preserve a short user-safe failure class (`provider_auth`, `session_load`, `command_dispatch`, `render`, `runtime_panic`, etc.), attach a local trace/session id, and ensure operators can jump from the chat-visible error to the exact failure log quickly.

22. **Opaque failure surface for session/runtime crashes** — **done**: `safe_failure_class()` in `error.rs` classifies all API errors into 8 user-safe classes (`provider_auth`, `provider_internal`, `provider_retry_exhausted`, `provider_rate_limit`, `provider_transport`, `provider_error`, `context_window`, `runtime_io`). `format_user_visible_api_error` in `main.rs` attaches session ID + request trace ID to every user-visible error. Coverage in `opaque_provider_wrapper_surfaces_failure_class_session_and_trace` and 3 related tests.

23. **`doctor --output-format json` check-level structure gap** — **done**: `claw doctor --output-format json` now keeps the human-readable `message`/`report` while also emitting structured per-check diagnostics (`name`, `status`, `summary`, `details`, plus typed fields like workspace paths and sandbox fallback data), with regression coverage in `output_format_contract.rs`.

24. **Plugin lifecycle init/shutdown test flakes under workspace-parallel execution** — dogfooding surfaced that `build_runtime_runs_plugin_lifecycle_init_and_shutdown` can fail under `cargo test --workspace` while passing in isolation because sibling tests race on tempdir-backed shell init script paths. This is test brittleness rather than a code-path regression, but it still destabilizes CI confidence and wastes diagnosis cycles. **Action:** isolate temp resources per test robustly (unique dirs + no shared cwd assumptions), audit cleanup timing, and add a regression guard so the plugin lifecycle test remains stable under parallel workspace execution.

26. **Resumed local-command JSON parity gap** — **done**: direct `claw --output-format json` already had structured renderers for `sandbox`, `mcp`, `skills`, `version`, and `init`, but resumed `claw --output-format json --resume <session> /…` paths still fell back to prose because resumed slash dispatch only emitted JSON for `/status`. Resumed `/sandbox`, `/mcp`, `/skills`, `/version`, and `/init` now reuse the same JSON envelopes as their direct CLI counterparts, with regression coverage in `rust/crates/rusty-claude-cli/tests/resume_slash_commands.rs` and `rust/crates/rusty-claude-cli/tests/output_format_contract.rs`.
@@ -385,3 +385,77 @@ to:
- a **claw-native execution runtime**
- an **event-native orchestration substrate**
- a **plugin/hook-first autonomous coding harness**
## Deployment Architecture Gap (filed from dogfood 2026-04-08)

### WorkerState is in the runtime; /state is NOT in opencode serve

**Root cause discovered during batch 8 dogfood.**

`worker_boot.rs` has a solid `WorkerStatus` state machine (`Spawning → TrustRequired → ReadyForPrompt → Running → Finished/Failed`). It is exported from `runtime/src/lib.rs` as a public API. But claw-code is a **plugin** loaded inside the `opencode` binary — it cannot add HTTP routes to `opencode serve`. The HTTP server is 100% owned by the upstream opencode process (v1.3.15).

**Impact:** There is no way to `curl localhost:4710/state` and get back a JSON `WorkerStatus`. Any such endpoint would require either:

1. Upstreaming a `/state` route into opencode's HTTP server (requires a PR to sst/opencode), or
2. Writing a sidecar HTTP process that queries the `WorkerRegistry` in-process (possible but fragile), or
3. Writing `WorkerStatus` to a well-known file path (`.claw/worker-state.json`) that an external observer can poll.

**Recommended path:** Option 3 — emit `WorkerStatus` transitions to `.claw/worker-state.json` on every state change. This is purely within claw-code's plugin scope, requires no upstream changes, and gives clawhip a file it can poll to distinguish a truly stalled worker from a quiet-but-progressing one.

**Action item:** Wire `WorkerRegistry::transition()` to atomically write `.claw/worker-state.json` on every state transition. Add a `claw state` CLI subcommand that reads and prints this file. Add regression test.

**Prior session note:** A previous session summary claimed commit `0984cca` landed a `/state` HTTP endpoint via axum. This was incorrect — no such commit exists on main, axum is not a dependency, and the HTTP server is not ours. The actual work that exists: `worker_boot.rs` with `WorkerStatus` enum + `WorkerRegistry`, fully wired into `runtime/src/lib.rs` as public exports.
## Startup Friction Gap: No Default trusted_roots in Settings (filed 2026-04-08)

### Every lane starts with manual trust babysitting unless caller explicitly passes roots

**Root cause discovered during direct dogfood of WorkerCreate tool.**

`WorkerCreate` accepts a `trusted_roots: Vec<String>` parameter. If the caller omits it (or passes `[]`), every new worker immediately enters `TrustRequired` and stalls — requiring manual intervention to advance to `ReadyForPrompt`. There is no mechanism to configure a default allowlist in `settings.json` or `.claw/settings.json`.

**Impact:** Batch tooling (clawhip, lane orchestrators) must pass `trusted_roots` explicitly on every `WorkerCreate` call. If a batch script forgets the field, all workers in that batch stall silently at `trust_required`. This was the root cause of several "batch 8 lanes not advancing" incidents.

**Recommended fix:**

1. Add a `trusted_roots` field to `RuntimeConfig` (or a nested `[trust]` table), loaded via `ConfigLoader`.
2. In `WorkerRegistry::spawn_worker()`, merge config-level `trusted_roots` with any per-call overrides.
3. Default: empty list (safest). Users opt in by adding their repo paths to settings.
4. Update `config_validate` schema with the new field.

**Action item:** Wire `RuntimeConfig::trusted_roots()` → `WorkerRegistry::spawn_worker()` default. Cover with test: config with `trusted_roots = ["/tmp"]` → spawning worker in `/tmp/x` auto-resolves trust without caller passing the field.
## Observability Transport Decision (filed 2026-04-08)

### Canonical state surface: CLI/file-based. HTTP endpoint deferred.

**Decision:** `claw state` reading `.claw/worker-state.json` is the **blessed observability contract** for clawhip and downstream tooling. This is not a stepping-stone — it is the supported surface. Build against it.

**Rationale:**

- claw-code is a plugin running inside the opencode binary. It cannot add HTTP routes to `opencode serve` — that server belongs to upstream sst/opencode.
- The file-based surface is fully within plugin scope: `emit_state_file()` in `worker_boot.rs` writes atomically on every `WorkerStatus` transition.
- `claw state --output-format json` gives clawhip everything it needs: `status`, `is_ready`, `seconds_since_update`, `trust_gate_cleared`, `last_event`, `updated_at`.
- Polling a local file has lower latency and fewer failure modes than an HTTP round-trip to a sidecar.
- An HTTP state endpoint would require either (a) upstreaming a route to sst/opencode — a multi-week PR cycle with no guarantee of acceptance — or (b) a sidecar process that queries `WorkerRegistry` in-process, which is fragile and adds an extra failure domain.

**What downstream tooling (clawhip) should do:**

1. After `WorkerCreate`, poll `.claw/worker-state.json` (or run `claw state --output-format json`) in the worker's CWD at whatever interval makes sense (e.g. 5s).
2. Treat `seconds_since_update > 60` while in `trust_required` status as the stall signal.
3. Call the `WorkerResolveTrust` tool to unblock, or `WorkerRestart` to reset.
**HTTP endpoint tracking:** Not scheduled. If a concrete use case emerges that file polling cannot serve (e.g. remote workers over a network boundary), open a new issue to upstream a `/worker/state` route to sst/opencode at that time. Until then: file/CLI is canonical.
## Provider Routing: Model-Name Prefix Must Win Over Env-Var Presence (fixed 2026-04-08, `0530c50`)

### `openai/gpt-4.1-mini` was silently misrouted to Anthropic when ANTHROPIC_API_KEY was set

**Root cause:** `metadata_for_model` returned `None` for any model not matching a `claude` or `grok` prefix. `detect_provider_kind` then fell through to auth-sniffer order: first `has_auth_from_env_or_saved()` (Anthropic), then `OPENAI_API_KEY`, then `XAI_API_KEY`.

If `ANTHROPIC_API_KEY` was present in the environment (e.g. the user has both Anthropic and OpenRouter configured), any unknown model — including explicitly namespaced ones like `openai/gpt-4.1-mini` — was silently routed to the Anthropic client, which then failed with `missing Anthropic credentials` or a confusing 402/auth error rather than routing to OpenAI-compatible.

**Fix:** Added explicit prefix checks in `metadata_for_model`:

- `openai/` prefix → `ProviderKind::OpenAi`
- `gpt-` prefix → `ProviderKind::OpenAi`

Model name prefix now wins unconditionally over env-var presence. Regression test locked in: `providers::tests::openai_namespaced_model_routes_to_openai_not_anthropic`.

**Lesson:** Auth-sniffer fallback order is fragile. Any new provider added in the future should be registered in `metadata_for_model` via a model-name prefix, not left to env-var order. This is the canonical extension point.
USAGE.md (127 changed lines)
@@ -153,6 +153,133 @@ cd rust
./target/debug/claw --model "openai/gpt-4.1-mini" prompt "summarize this repository in one sentence"
```

### Alibaba DashScope (Qwen)

For Qwen models via Alibaba's native DashScope API (higher rate limits than OpenRouter):

```bash
export DASHSCOPE_API_KEY="sk-..."

cd rust
./target/debug/claw --model "qwen/qwen-max" prompt "hello"
# or bare:
./target/debug/claw --model "qwen-plus" prompt "hello"
```

Model names starting with `qwen/` or `qwen-` are automatically routed to the DashScope compatible-mode endpoint (`https://dashscope.aliyuncs.com/compatible-mode/v1`). You do **not** need to set `OPENAI_BASE_URL` or unset `ANTHROPIC_API_KEY` — the model prefix wins over the ambient credential sniffer.

Reasoning variants (`qwen-qwq-*`, `qwq-*`, `*-thinking`) automatically strip `temperature`/`top_p`/`frequency_penalty`/`presence_penalty` before the request hits the wire (these params are rejected by reasoning models).
## Supported Providers & Models

`claw` has three built-in provider backends. The provider is selected automatically based on the model name, falling back to whichever credential is present in the environment.

### Provider matrix

| Provider | Protocol | Auth env var(s) | Base URL env var | Default base URL |
|---|---|---|---|---|
| **Anthropic** (direct) | Anthropic Messages API | `ANTHROPIC_API_KEY` or `ANTHROPIC_AUTH_TOKEN` or OAuth (`claw login`) | `ANTHROPIC_BASE_URL` | `https://api.anthropic.com` |
| **xAI** | OpenAI-compatible | `XAI_API_KEY` | `XAI_BASE_URL` | `https://api.x.ai/v1` |
| **OpenAI-compatible** | OpenAI Chat Completions | `OPENAI_API_KEY` | `OPENAI_BASE_URL` | `https://api.openai.com/v1` |
| **DashScope** (Alibaba) | OpenAI-compatible | `DASHSCOPE_API_KEY` | `DASHSCOPE_BASE_URL` | `https://dashscope.aliyuncs.com/compatible-mode/v1` |

The OpenAI-compatible backend also serves as the gateway for **OpenRouter**, **Ollama**, and any other service that speaks the OpenAI `/v1/chat/completions` wire format — just point `OPENAI_BASE_URL` at the service.

**Model-name prefix routing:** If a model name starts with `openai/`, `gpt-`, `qwen/`, or `qwen-`, the provider is selected by the prefix regardless of which env vars are set. This prevents accidental misrouting to Anthropic when multiple credentials exist in the environment.
### Tested models and aliases

These are the models registered in the built-in alias table with known token limits:

| Alias | Resolved model name | Provider | Max output tokens | Context window |
|---|---|---|---|---|
| `opus` | `claude-opus-4-6` | Anthropic | 32 000 | 200 000 |
| `sonnet` | `claude-sonnet-4-6` | Anthropic | 64 000 | 200 000 |
| `haiku` | `claude-haiku-4-5-20251213` | Anthropic | 64 000 | 200 000 |
| `grok` / `grok-3` | `grok-3` | xAI | 64 000 | 131 072 |
| `grok-mini` / `grok-3-mini` | `grok-3-mini` | xAI | 64 000 | 131 072 |
| `grok-2` | `grok-2` | xAI | — | — |

Any model name that does not match an alias is passed through verbatim. This is how you use OpenRouter model slugs (`openai/gpt-4.1-mini`), Ollama tags (`llama3.2`), or full Anthropic model IDs (`claude-sonnet-4-20250514`).
### User-defined aliases

You can add custom aliases in any settings file (`~/.claw/settings.json`, `.claw/settings.json`, or `.claw/settings.local.json`):

```json
{
  "aliases": {
    "fast": "claude-haiku-4-5-20251213",
    "smart": "claude-opus-4-6",
    "cheap": "grok-3-mini"
  }
}
```

Local project settings override user-level settings. Aliases resolve through the built-in table, so `"fast": "haiku"` also works.
### How provider detection works

1. If the resolved model name starts with `claude` → Anthropic.
2. If it starts with `grok` → xAI.
3. If it starts with `openai/` or `gpt-` → OpenAI-compatible; if it starts with `qwen/` or `qwen-` → DashScope. These prefixes win over any credentials in the environment.
4. Otherwise, `claw` checks which credential is set: `ANTHROPIC_API_KEY`/`ANTHROPIC_AUTH_TOKEN` first, then `OPENAI_API_KEY`, then `XAI_API_KEY`.
5. If nothing matches, it defaults to Anthropic.
## FAQ

### What about Codex?

The name "codex" appears in the Claw Code ecosystem but it does **not** refer to OpenAI Codex (the code-generation model). Here is what it means in this project:

- **`oh-my-codex` (OmX)** is the workflow and plugin layer that sits on top of `claw`. It provides planning modes, parallel multi-agent execution, notification routing, and other automation features. See [PHILOSOPHY.md](./PHILOSOPHY.md) and the [oh-my-codex repo](https://github.com/Yeachan-Heo/oh-my-codex).
- **`.codex/` directories** (e.g. `.codex/skills`, `.codex/agents`, `.codex/commands`) are legacy lookup paths that `claw` still scans alongside the primary `.claw/` directories.
- **`CODEX_HOME`** is an optional environment variable that points to a custom root for user-level skill and command lookups.

`claw` does **not** support OpenAI Codex sessions, the Codex CLI, or Codex session import/export. If you need to use OpenAI models (like GPT-4.1), configure the OpenAI-compatible provider as shown above in the [OpenAI-compatible endpoint](#openai-compatible-endpoint) and [OpenRouter](#openrouter) sections.
## HTTP proxy support

`claw` honours the standard `HTTP_PROXY`, `HTTPS_PROXY`, and `NO_PROXY` environment variables (both upper- and lower-case spellings are accepted) when issuing outbound requests to Anthropic, OpenAI-, and xAI-compatible endpoints. Set them before launching the CLI and the underlying `reqwest` client will be configured automatically.

### Environment variables

```bash
export HTTPS_PROXY="http://proxy.corp.example:3128"
export HTTP_PROXY="http://proxy.corp.example:3128"
export NO_PROXY="localhost,127.0.0.1,.corp.example"

cd rust
./target/debug/claw prompt "hello via the corporate proxy"
```
### Programmatic `proxy_url` config option

As an alternative to per-scheme environment variables, the `ProxyConfig` type exposes a `proxy_url` field that acts as a single catch-all proxy for both HTTP and HTTPS traffic. When `proxy_url` is set it takes precedence over the separate `http_proxy` and `https_proxy` fields.

```rust
use api::{build_http_client_with, ProxyConfig};

// From a single unified URL (config file, CLI flag, etc.)
let config = ProxyConfig::from_proxy_url("http://proxy.corp.example:3128");
let client = build_http_client_with(&config).expect("proxy client");

// Or set the field directly alongside NO_PROXY
let config = ProxyConfig {
    proxy_url: Some("http://proxy.corp.example:3128".to_string()),
    no_proxy: Some("localhost,127.0.0.1".to_string()),
    ..ProxyConfig::default()
};
let client = build_http_client_with(&config).expect("proxy client");
```
### Notes

- When both `HTTPS_PROXY` and `HTTP_PROXY` are set, the secure proxy applies to `https://` URLs and the plain proxy applies to `http://` URLs.
- `proxy_url` is a unified alternative: when set, it applies to both `http://` and `https://` destinations, overriding the per-scheme fields.
- `NO_PROXY` accepts a comma-separated list of host suffixes (for example `.corp.example`) and IP literals.
- Empty values are treated as unset, so leaving `HTTPS_PROXY=""` in your shell will not enable a proxy.
- If a proxy URL cannot be parsed, `claw` falls back to a direct (no-proxy) client so existing workflows keep working; double-check the URL if you expected the request to be tunnelled.

## Common operational commands

```bash
install.sh (new executable file, 394 lines)

@@ -0,0 +1,394 @@
#!/usr/bin/env bash
# Claw Code installer
#
# Detects the host OS, verifies the Rust toolchain (rustc + cargo),
# builds the `claw` binary from the `rust/` workspace, and runs a
# post-install verification step. Supports Linux, macOS, and WSL.
#
# Usage:
#   ./install.sh              # debug build (fast, default)
#   ./install.sh --release    # optimized release build
#   ./install.sh --no-verify  # skip post-install verification
#   ./install.sh --help       # print usage
#
# Environment overrides:
#   CLAW_BUILD_PROFILE=debug|release   same as --release toggle
#   CLAW_SKIP_VERIFY=1                 same as --no-verify

set -euo pipefail

# ---------------------------------------------------------------------------
# Pretty printing
# ---------------------------------------------------------------------------

if [ -t 1 ] && command -v tput >/dev/null 2>&1 && [ "$(tput colors 2>/dev/null || echo 0)" -ge 8 ]; then
  COLOR_RESET="$(tput sgr0)"
  COLOR_BOLD="$(tput bold)"
  COLOR_DIM="$(tput dim)"
  COLOR_RED="$(tput setaf 1)"
  COLOR_GREEN="$(tput setaf 2)"
  COLOR_YELLOW="$(tput setaf 3)"
  COLOR_BLUE="$(tput setaf 4)"
  COLOR_CYAN="$(tput setaf 6)"
else
  COLOR_RESET=""
  COLOR_BOLD=""
  COLOR_DIM=""
  COLOR_RED=""
  COLOR_GREEN=""
  COLOR_YELLOW=""
  COLOR_BLUE=""
  COLOR_CYAN=""
fi
CURRENT_STEP=0
TOTAL_STEPS=6

step() {
  CURRENT_STEP=$((CURRENT_STEP + 1))
  printf '\n%s[%d/%d]%s %s%s%s\n' \
    "${COLOR_BLUE}" "${CURRENT_STEP}" "${TOTAL_STEPS}" "${COLOR_RESET}" \
    "${COLOR_BOLD}" "$1" "${COLOR_RESET}"
}

info()  { printf '%s  ->%s %s\n' "${COLOR_CYAN}" "${COLOR_RESET}" "$1"; }
ok()    { printf '%s  ok%s %s\n' "${COLOR_GREEN}" "${COLOR_RESET}" "$1"; }
warn()  { printf '%s warn%s %s\n' "${COLOR_YELLOW}" "${COLOR_RESET}" "$1"; }
error() { printf '%s error%s %s\n' "${COLOR_RED}" "${COLOR_RESET}" "$1" 1>&2; }

print_banner() {
  printf '%s' "${COLOR_BOLD}"
  cat <<'EOF'
  ____  _                    ____            _
 / ___|| |  __ _ __      __ / ___|  ___   __| |  ___
| |    | | / _` |\ \ /\ / /| |     / _ \ / _` | / _ \
| |___ | || (_| | \ V  V / | |___ | (_) | (_| ||  __/
 \____||_| \__,_|  \_/\_/   \____| \___/ \__,_| \___|
EOF
  printf '%s\n' "${COLOR_RESET}"
  printf '%sClaw Code installer%s\n' "${COLOR_DIM}" "${COLOR_RESET}"
}

print_usage() {
  cat <<'EOF'
Usage: ./install.sh [options]

Options:
  --release      Build the optimized release profile (slower, smaller binary).
  --debug        Build the debug profile (default, faster compile).
  --no-verify    Skip the post-install verification step.
  -h, --help     Show this help text and exit.

Environment overrides:
  CLAW_BUILD_PROFILE   debug | release
  CLAW_SKIP_VERIFY     set to 1 to skip verification
EOF
}
# ---------------------------------------------------------------------------
# Argument parsing
# ---------------------------------------------------------------------------

BUILD_PROFILE="${CLAW_BUILD_PROFILE:-debug}"
SKIP_VERIFY="${CLAW_SKIP_VERIFY:-0}"

while [ "$#" -gt 0 ]; do
  case "$1" in
    --release)
      BUILD_PROFILE="release"
      ;;
    --debug)
      BUILD_PROFILE="debug"
      ;;
    --no-verify)
      SKIP_VERIFY="1"
      ;;
    -h|--help)
      print_usage
      exit 0
      ;;
    *)
      error "unknown argument: $1"
      print_usage
      exit 2
      ;;
  esac
  shift
done

case "${BUILD_PROFILE}" in
  debug|release) ;;
  *)
    error "invalid build profile: ${BUILD_PROFILE} (expected debug or release)"
    exit 2
    ;;
esac
# ---------------------------------------------------------------------------
# Troubleshooting hints
# ---------------------------------------------------------------------------

print_troubleshooting() {
  cat <<EOF

${COLOR_BOLD}Troubleshooting${COLOR_RESET}
${COLOR_DIM}---------------${COLOR_RESET}

${COLOR_BOLD}1. Rust toolchain missing${COLOR_RESET}
   Install Rust via rustup:
     curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh
   Then reload your shell or run:
     source "\$HOME/.cargo/env"

${COLOR_BOLD}2. Linux: missing system packages${COLOR_RESET}
   The build needs git, pkg-config, and OpenSSL headers.
   Debian/Ubuntu:
     sudo apt-get update && sudo apt-get install -y \\
       git pkg-config libssl-dev ca-certificates build-essential
   Fedora/RHEL:
     sudo dnf install -y git pkgconf-pkg-config openssl-devel gcc
   Arch:
     sudo pacman -S --needed git pkgconf openssl base-devel

${COLOR_BOLD}3. macOS: missing Xcode CLT${COLOR_RESET}
   Install the command line tools:
     xcode-select --install

${COLOR_BOLD}4. Windows users${COLOR_RESET}
   Run this script from inside a WSL distro (Ubuntu/Debian recommended).
   Native Windows builds are not supported by this installer.

${COLOR_BOLD}5. Build fails partway through${COLOR_RESET}
   Try a clean build:
     cd rust && cargo clean && cargo build --workspace
   If the failure mentions ring/openssl, double check step 2.

${COLOR_BOLD}6. 'claw' not found after install${COLOR_RESET}
   The binary lives at:
     rust/target/${BUILD_PROFILE}/claw
   Add it to your PATH or invoke it with the full path.

EOF
}

trap 'rc=$?; if [ "$rc" -ne 0 ]; then error "installation failed (exit ${rc})"; print_troubleshooting; fi' EXIT
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

require_cmd() {
  command -v "$1" >/dev/null 2>&1
}

# ---------------------------------------------------------------------------
# Step 1: detect OS / arch / WSL
# ---------------------------------------------------------------------------

print_banner
step "Detecting host environment"

UNAME_S="$(uname -s 2>/dev/null || echo unknown)"
UNAME_M="$(uname -m 2>/dev/null || echo unknown)"
OS_FAMILY="unknown"
IS_WSL="0"

case "${UNAME_S}" in
  Linux*)
    OS_FAMILY="linux"
    if grep -qiE 'microsoft|wsl' /proc/version 2>/dev/null; then
      IS_WSL="1"
    fi
    ;;
  Darwin*)
    OS_FAMILY="macos"
    ;;
  MINGW*|MSYS*|CYGWIN*)
    OS_FAMILY="windows-shell"
    ;;
esac

info "uname: ${UNAME_S} ${UNAME_M}"
info "os family: ${OS_FAMILY}"
if [ "${IS_WSL}" = "1" ]; then
  info "wsl: yes"
fi

case "${OS_FAMILY}" in
  linux|macos)
    ok "supported platform detected"
    ;;
  windows-shell)
    error "Detected a native Windows shell (MSYS/Cygwin/MinGW)."
    error "Please re-run this script from inside a WSL distribution."
    exit 1
    ;;
  *)
    error "Unsupported or unknown OS: ${UNAME_S}"
    error "Supported: Linux, macOS, and Windows via WSL."
    exit 1
    ;;
esac
# ---------------------------------------------------------------------------
# Step 2: locate the Rust workspace
# ---------------------------------------------------------------------------

step "Locating the Rust workspace"

SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
RUST_DIR="${SCRIPT_DIR}/rust"

if [ ! -d "${RUST_DIR}" ]; then
  error "Could not find rust/ workspace next to install.sh"
  error "Expected: ${RUST_DIR}"
  exit 1
fi

if [ ! -f "${RUST_DIR}/Cargo.toml" ]; then
  error "Missing ${RUST_DIR}/Cargo.toml — repository layout looks unexpected."
  exit 1
fi

ok "workspace at ${RUST_DIR}"
# ---------------------------------------------------------------------------
# Step 3: prerequisite checks
# ---------------------------------------------------------------------------

step "Checking prerequisites"

MISSING_PREREQS=0

if require_cmd rustc; then
  RUSTC_VERSION="$(rustc --version 2>/dev/null || echo 'unknown')"
  ok "rustc found: ${RUSTC_VERSION}"
else
  error "rustc not found in PATH"
  MISSING_PREREQS=1
fi

if require_cmd cargo; then
  CARGO_VERSION="$(cargo --version 2>/dev/null || echo 'unknown')"
  ok "cargo found: ${CARGO_VERSION}"
else
  error "cargo not found in PATH"
  MISSING_PREREQS=1
fi

if require_cmd git; then
  ok "git found: $(git --version 2>/dev/null || echo 'unknown')"
else
  warn "git not found — some workflows (login, session export) may degrade"
fi

if [ "${OS_FAMILY}" = "linux" ]; then
  if require_cmd pkg-config; then
    ok "pkg-config found"
  else
    warn "pkg-config not found — may be required for OpenSSL-linked crates"
  fi
fi

if [ "${OS_FAMILY}" = "macos" ]; then
  if ! require_cmd cc && ! xcode-select -p >/dev/null 2>&1; then
    warn "Xcode command line tools not detected — run: xcode-select --install"
  fi
fi

if [ "${MISSING_PREREQS}" -ne 0 ]; then
  error "Missing required tools. See troubleshooting below."
  exit 1
fi
# ---------------------------------------------------------------------------
# Step 4: build the workspace
# ---------------------------------------------------------------------------

step "Building the claw workspace (${BUILD_PROFILE})"

CARGO_FLAGS=("build" "--workspace")
if [ "${BUILD_PROFILE}" = "release" ]; then
  CARGO_FLAGS+=("--release")
fi

info "running: cargo ${CARGO_FLAGS[*]}"
info "this may take a few minutes on the first build"

(
  cd "${RUST_DIR}"
  CARGO_TERM_COLOR="${CARGO_TERM_COLOR:-always}" cargo "${CARGO_FLAGS[@]}"
)

CLAW_BIN="${RUST_DIR}/target/${BUILD_PROFILE}/claw"

if [ ! -x "${CLAW_BIN}" ]; then
  error "Expected binary not found at ${CLAW_BIN}"
  error "The build reported success but the binary is missing — check cargo output above."
  exit 1
fi

ok "built ${CLAW_BIN}"
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Step 5: post-install verification
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
step "Verifying the installed binary"
|
||||
|
||||
if [ "${SKIP_VERIFY}" = "1" ]; then
|
||||
warn "verification skipped (--no-verify or CLAW_SKIP_VERIFY=1)"
|
||||
else
|
||||
info "running: claw --version"
|
||||
if VERSION_OUT="$("${CLAW_BIN}" --version 2>&1)"; then
|
||||
ok "claw --version -> ${VERSION_OUT}"
|
||||
else
|
||||
error "claw --version failed:"
|
||||
printf '%s\n' "${VERSION_OUT}" 1>&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
info "running: claw --help (smoke test)"
|
||||
if "${CLAW_BIN}" --help >/dev/null 2>&1; then
|
||||
ok "claw --help responded"
|
||||
else
|
||||
error "claw --help failed"
|
||||
exit 1
|
||||
fi
|
||||
fi
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Step 6: next steps
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
step "Next steps"
|
||||
|
||||
cat <<EOF
|
||||
${COLOR_GREEN}Claw Code is built and ready.${COLOR_RESET}
|
||||
|
||||
Binary: ${COLOR_BOLD}${CLAW_BIN}${COLOR_RESET}
|
||||
Profile: ${BUILD_PROFILE}
|
||||
|
||||
Try it out:
|
||||
|
||||
${COLOR_DIM}# interactive REPL${COLOR_RESET}
|
||||
${CLAW_BIN}
|
||||
|
||||
${COLOR_DIM}# one-shot prompt${COLOR_RESET}
|
||||
${CLAW_BIN} prompt "summarize this repository"
|
||||
|
||||
${COLOR_DIM}# health check (run /doctor inside the REPL)${COLOR_RESET}
|
||||
${CLAW_BIN}
|
||||
/doctor
|
||||
|
||||
Authentication:
|
||||
|
||||
export ANTHROPIC_API_KEY="sk-ant-..."
|
||||
${COLOR_DIM}# or use OAuth:${COLOR_RESET}
|
||||
${CLAW_BIN} login
|
||||
|
||||
For deeper docs, see USAGE.md and rust/README.md.
|
||||
EOF
|
||||
|
||||
# clear the failure trap on clean exit
|
||||
trap - EXIT
|
||||
rust/Cargo.lock | 1 (generated)
@@ -1580,6 +1580,7 @@ version = "0.1.0"
 dependencies = [
  "api",
  "commands",
  "flate2",
  "plugins",
  "reqwest",
  "runtime",
rust/crates/api/src/http_client.rs | 344 (new file)
@@ -0,0 +1,344 @@
use crate::error::ApiError;

const HTTP_PROXY_KEYS: [&str; 2] = ["HTTP_PROXY", "http_proxy"];
const HTTPS_PROXY_KEYS: [&str; 2] = ["HTTPS_PROXY", "https_proxy"];
const NO_PROXY_KEYS: [&str; 2] = ["NO_PROXY", "no_proxy"];

/// Snapshot of the proxy-related environment variables that influence the
/// outbound HTTP client. Captured up front so callers can inspect, log, and
/// test the resolved configuration without re-reading the process environment.
///
/// When `proxy_url` is set it acts as a single catch-all proxy for both
/// HTTP and HTTPS traffic, taking precedence over the per-scheme fields.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ProxyConfig {
    pub http_proxy: Option<String>,
    pub https_proxy: Option<String>,
    pub no_proxy: Option<String>,
    /// Optional unified proxy URL that applies to both HTTP and HTTPS.
    /// When set, this takes precedence over `http_proxy` and `https_proxy`.
    pub proxy_url: Option<String>,
}

impl ProxyConfig {
    /// Read proxy settings from the live process environment, honouring both
    /// the upper- and lower-case spellings used by curl, git, and friends.
    #[must_use]
    pub fn from_env() -> Self {
        Self::from_lookup(|key| std::env::var(key).ok())
    }

    /// Create a proxy configuration from a single URL that applies to both
    /// HTTP and HTTPS traffic. This is the config-file alternative to setting
    /// `HTTP_PROXY` and `HTTPS_PROXY` environment variables separately.
    #[must_use]
    pub fn from_proxy_url(url: impl Into<String>) -> Self {
        Self {
            proxy_url: Some(url.into()),
            ..Self::default()
        }
    }

    fn from_lookup<F>(mut lookup: F) -> Self
    where
        F: FnMut(&str) -> Option<String>,
    {
        Self {
            http_proxy: first_non_empty(&HTTP_PROXY_KEYS, &mut lookup),
            https_proxy: first_non_empty(&HTTPS_PROXY_KEYS, &mut lookup),
            no_proxy: first_non_empty(&NO_PROXY_KEYS, &mut lookup),
            proxy_url: None,
        }
    }

    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.proxy_url.is_none() && self.http_proxy.is_none() && self.https_proxy.is_none()
    }
}

/// Build a `reqwest::Client` that honours the standard `HTTP_PROXY`,
/// `HTTPS_PROXY`, and `NO_PROXY` environment variables. When no proxy is
/// configured the client behaves identically to `reqwest::Client::new()`.
pub fn build_http_client() -> Result<reqwest::Client, ApiError> {
    build_http_client_with(&ProxyConfig::from_env())
}

/// Infallible counterpart to [`build_http_client`] for constructors that
/// historically returned `Self` rather than `Result<Self, _>`. When the proxy
/// configuration is malformed we fall back to a default client so that
/// callers retain the previous behaviour and the failure surfaces on the
/// first outbound request instead of at construction time.
#[must_use]
pub fn build_http_client_or_default() -> reqwest::Client {
    build_http_client().unwrap_or_else(|_| reqwest::Client::new())
}

/// Build a `reqwest::Client` from an explicit [`ProxyConfig`]. Used by tests
/// and by callers that want to override process-level environment lookups.
///
/// When `config.proxy_url` is set it overrides the per-scheme `http_proxy`
/// and `https_proxy` fields and is registered as both an HTTP and HTTPS
/// proxy so a single value can route every outbound request.
pub fn build_http_client_with(config: &ProxyConfig) -> Result<reqwest::Client, ApiError> {
    let mut builder = reqwest::Client::builder().no_proxy();

    let no_proxy = config
        .no_proxy
        .as_deref()
        .and_then(reqwest::NoProxy::from_string);

    let (http_proxy_url, https_proxy_url) = match config.proxy_url.as_deref() {
        Some(unified) => (Some(unified), Some(unified)),
        None => (config.http_proxy.as_deref(), config.https_proxy.as_deref()),
    };

    if let Some(url) = https_proxy_url {
        let mut proxy = reqwest::Proxy::https(url)?;
        if let Some(filter) = no_proxy.clone() {
            proxy = proxy.no_proxy(Some(filter));
        }
        builder = builder.proxy(proxy);
    }

    if let Some(url) = http_proxy_url {
        let mut proxy = reqwest::Proxy::http(url)?;
        if let Some(filter) = no_proxy.clone() {
            proxy = proxy.no_proxy(Some(filter));
        }
        builder = builder.proxy(proxy);
    }

    Ok(builder.build()?)
}

fn first_non_empty<F>(keys: &[&str], lookup: &mut F) -> Option<String>
where
    F: FnMut(&str) -> Option<String>,
{
    keys.iter()
        .find_map(|key| lookup(key).filter(|value| !value.is_empty()))
}

#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use super::{build_http_client_with, ProxyConfig};

    fn config_from_map(pairs: &[(&str, &str)]) -> ProxyConfig {
        let map: HashMap<String, String> = pairs
            .iter()
            .map(|(key, value)| ((*key).to_string(), (*value).to_string()))
            .collect();
        ProxyConfig::from_lookup(|key| map.get(key).cloned())
    }

    #[test]
    fn proxy_config_is_empty_when_no_env_vars_are_set() {
        // given
        let config = config_from_map(&[]);

        // when
        let empty = config.is_empty();

        // then
        assert!(empty);
        assert_eq!(config, ProxyConfig::default());
    }

    #[test]
    fn proxy_config_reads_uppercase_http_https_and_no_proxy() {
        // given
        let pairs = [
            ("HTTP_PROXY", "http://proxy.internal:3128"),
            ("HTTPS_PROXY", "http://secure.internal:3129"),
            ("NO_PROXY", "localhost,127.0.0.1,.corp"),
        ];

        // when
        let config = config_from_map(&pairs);

        // then
        assert_eq!(
            config.http_proxy.as_deref(),
            Some("http://proxy.internal:3128")
        );
        assert_eq!(
            config.https_proxy.as_deref(),
            Some("http://secure.internal:3129")
        );
        assert_eq!(
            config.no_proxy.as_deref(),
            Some("localhost,127.0.0.1,.corp")
        );
        assert!(!config.is_empty());
    }

    #[test]
    fn proxy_config_falls_back_to_lowercase_keys() {
        // given
        let pairs = [
            ("http_proxy", "http://lower.internal:3128"),
            ("https_proxy", "http://lower-secure.internal:3129"),
            ("no_proxy", ".lower"),
        ];

        // when
        let config = config_from_map(&pairs);

        // then
        assert_eq!(
            config.http_proxy.as_deref(),
            Some("http://lower.internal:3128")
        );
        assert_eq!(
            config.https_proxy.as_deref(),
            Some("http://lower-secure.internal:3129")
        );
        assert_eq!(config.no_proxy.as_deref(), Some(".lower"));
    }

    #[test]
    fn proxy_config_prefers_uppercase_over_lowercase_when_both_set() {
        // given
        let pairs = [
            ("HTTP_PROXY", "http://upper.internal:3128"),
            ("http_proxy", "http://lower.internal:3128"),
        ];

        // when
        let config = config_from_map(&pairs);

        // then
        assert_eq!(
            config.http_proxy.as_deref(),
            Some("http://upper.internal:3128")
        );
    }

    #[test]
    fn proxy_config_treats_empty_strings_as_unset() {
        // given
        let pairs = [("HTTP_PROXY", ""), ("http_proxy", "")];

        // when
        let config = config_from_map(&pairs);

        // then
        assert!(config.http_proxy.is_none());
    }

    #[test]
    fn build_http_client_succeeds_when_no_proxy_is_configured() {
        // given
        let config = ProxyConfig::default();

        // when
        let result = build_http_client_with(&config);

        // then
        assert!(result.is_ok());
    }

    #[test]
    fn build_http_client_succeeds_with_valid_http_and_https_proxies() {
        // given
        let config = ProxyConfig {
            http_proxy: Some("http://proxy.internal:3128".to_string()),
            https_proxy: Some("http://secure.internal:3129".to_string()),
            no_proxy: Some("localhost,127.0.0.1".to_string()),
            proxy_url: None,
        };

        // when
        let result = build_http_client_with(&config);

        // then
        assert!(result.is_ok());
    }

    #[test]
    fn build_http_client_returns_http_error_for_invalid_proxy_url() {
        // given
        let config = ProxyConfig {
            http_proxy: None,
            https_proxy: Some("not a url".to_string()),
            no_proxy: None,
            proxy_url: None,
        };

        // when
        let result = build_http_client_with(&config);

        // then
        let error = result.expect_err("invalid proxy URL must be reported as a build failure");
        assert!(
            matches!(error, crate::error::ApiError::Http(_)),
            "expected ApiError::Http for invalid proxy URL, got: {error:?}"
        );
    }

    #[test]
    fn from_proxy_url_sets_unified_field_and_leaves_per_scheme_empty() {
        // given / when
        let config = ProxyConfig::from_proxy_url("http://unified.internal:3128");

        // then
        assert_eq!(
            config.proxy_url.as_deref(),
            Some("http://unified.internal:3128")
        );
        assert!(config.http_proxy.is_none());
        assert!(config.https_proxy.is_none());
        assert!(!config.is_empty());
    }

    #[test]
    fn build_http_client_succeeds_with_unified_proxy_url() {
        // given
        let config = ProxyConfig {
            proxy_url: Some("http://unified.internal:3128".to_string()),
            no_proxy: Some("localhost".to_string()),
            ..ProxyConfig::default()
        };

        // when
        let result = build_http_client_with(&config);

        // then
        assert!(result.is_ok());
    }

    #[test]
    fn proxy_url_takes_precedence_over_per_scheme_fields() {
        // given – both per-scheme and unified are set
        let config = ProxyConfig {
            http_proxy: Some("http://per-scheme.internal:1111".to_string()),
            https_proxy: Some("http://per-scheme.internal:2222".to_string()),
            no_proxy: None,
            proxy_url: Some("http://unified.internal:3128".to_string()),
        };

        // when – building succeeds (the unified URL is valid)
        let result = build_http_client_with(&config);

        // then
        assert!(result.is_ok());
    }

    #[test]
    fn build_http_client_returns_error_for_invalid_unified_proxy_url() {
        // given
        let config = ProxyConfig::from_proxy_url("not a url");

        // when
        let result = build_http_client_with(&config);

        // then
        assert!(
            matches!(result, Err(crate::error::ApiError::Http(_))),
            "invalid unified proxy URL should fail: {result:?}"
        );
    }
}
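The lookup precedence that `from_lookup` and `first_non_empty` implement (uppercase key probed before lowercase, empty string treated as unset) can be sketched standalone. The helper below is a simplified, hypothetical re-implementation over a `HashMap` instead of the process environment:

```rust
use std::collections::HashMap;

// First non-empty value wins; keys are probed in the order given,
// which is why listing "HTTP_PROXY" before "http_proxy" makes the
// uppercase spelling take precedence. Empty strings count as unset.
fn first_non_empty(keys: &[&str], env: &HashMap<String, String>) -> Option<String> {
    keys.iter()
        .find_map(|key| env.get(*key).cloned().filter(|v| !v.is_empty()))
}

fn main() {
    let mut env = HashMap::new();
    env.insert("http_proxy".to_string(), "http://lower:3128".to_string());
    env.insert("HTTP_PROXY".to_string(), "http://upper:3128".to_string());

    // Uppercase wins because it is probed first.
    assert_eq!(
        first_non_empty(&["HTTP_PROXY", "http_proxy"], &env),
        Some("http://upper:3128".to_string())
    );

    // An empty value is treated as unset, so lookup falls through.
    env.insert("HTTPS_PROXY".to_string(), String::new());
    env.insert("https_proxy".to_string(), "http://fallback:3129".to_string());
    assert_eq!(
        first_non_empty(&["HTTPS_PROXY", "https_proxy"], &env),
        Some("http://fallback:3129".to_string())
    );
    println!("precedence ok");
}
```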
@@ -1,5 +1,6 @@
 mod client;
 mod error;
+mod http_client;
 mod prompt_cache;
 mod providers;
 mod sse;
@@ -10,6 +11,9 @@ pub use client::{
     resolve_startup_auth_source, MessageStream, OAuthTokenSet, ProviderClient,
 };
 pub use error::ApiError;
+pub use http_client::{
+    build_http_client, build_http_client_or_default, build_http_client_with, ProxyConfig,
+};
 pub use prompt_cache::{
     CacheBreakEvent, PromptCache, PromptCacheConfig, PromptCachePaths, PromptCacheRecord,
     PromptCacheStats,
@@ -17,7 +21,8 @@ pub use prompt_cache::{
 pub use providers::anthropic::{AnthropicClient, AnthropicClient as ApiClient, AuthSource};
 pub use providers::openai_compat::{OpenAiCompatClient, OpenAiCompatConfig};
 pub use providers::{
-    detect_provider_kind, max_tokens_for_model, resolve_model_alias, ProviderKind,
+    detect_provider_kind, max_tokens_for_model, max_tokens_for_model_with_override,
+    resolve_model_alias, ProviderKind,
 };
 pub use sse::{parse_frame, SseParser};
 pub use types::{
@@ -704,6 +704,7 @@ mod tests {
             tools: None,
             tool_choice: None,
             stream: false,
+            ..Default::default()
         }
     }
@@ -1,4 +1,5 @@
 use std::collections::VecDeque;
+use std::sync::atomic::{AtomicU64, Ordering};
 use std::sync::{Arc, Mutex};
 use std::time::{Duration, SystemTime, UNIX_EPOCH};

@@ -12,6 +13,7 @@ use serde_json::{Map, Value};
 use telemetry::{AnalyticsEvent, AnthropicRequestProfile, ClientIdentity, SessionTracer};

 use crate::error::ApiError;
+use crate::http_client::build_http_client_or_default;
 use crate::prompt_cache::{PromptCache, PromptCacheRecord, PromptCacheStats};

 use super::{model_token_limit, resolve_model_alias, Provider, ProviderFuture};
@@ -21,9 +23,9 @@ use crate::types::{MessageDeltaEvent, MessageRequest, MessageResponse, StreamEve
 pub const DEFAULT_BASE_URL: &str = "https://api.anthropic.com";
 const REQUEST_ID_HEADER: &str = "request-id";
 const ALT_REQUEST_ID_HEADER: &str = "x-request-id";
-const DEFAULT_INITIAL_BACKOFF: Duration = Duration::from_millis(200);
-const DEFAULT_MAX_BACKOFF: Duration = Duration::from_secs(2);
-const DEFAULT_MAX_RETRIES: u32 = 2;
+const DEFAULT_INITIAL_BACKOFF: Duration = Duration::from_secs(1);
+const DEFAULT_MAX_BACKOFF: Duration = Duration::from_secs(128);
+const DEFAULT_MAX_RETRIES: u32 = 8;

 #[derive(Debug, Clone, PartialEq, Eq)]
 pub enum AuthSource {
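The new defaults above imply a doubling schedule: 1s, 2s, 4s, ... clamped at 128s across 8 retries. A minimal sketch of that schedule, assuming exponential doubling from the initial delay with a saturating cap (function names here are illustrative, not the crate's API):

```rust
use std::time::Duration;

// Attempt 1 sleeps for `initial`; each further attempt doubles the
// delay; overflow or anything past `max` is clamped to `max`.
fn backoff_for_attempt(attempt: u32, initial: Duration, max: Duration) -> Duration {
    let multiplier = 2u32.saturating_pow(attempt.saturating_sub(1));
    initial
        .checked_mul(multiplier)
        .map_or(max, |delay| delay.min(max))
}

fn main() {
    let initial = Duration::from_secs(1);
    let max = Duration::from_secs(128);
    let schedule: Vec<u64> = (1..=8)
        .map(|attempt| backoff_for_attempt(attempt, initial, max).as_secs())
        .collect();
    // 8 retries: 1 + 2 + 4 + ... + 128 = up to ~255s of total backoff.
    assert_eq!(schedule, vec![1, 2, 4, 8, 16, 32, 64, 128]);
    println!("{schedule:?}");
}
```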
@@ -127,7 +129,7 @@ impl AnthropicClient {
     #[must_use]
     pub fn new(api_key: impl Into<String>) -> Self {
         Self {
-            http: reqwest::Client::new(),
+            http: build_http_client_or_default(),
             auth: AuthSource::ApiKey(api_key.into()),
             base_url: DEFAULT_BASE_URL.to_string(),
             max_retries: DEFAULT_MAX_RETRIES,
@@ -143,7 +145,7 @@ impl AnthropicClient {
     #[must_use]
     pub fn from_auth(auth: AuthSource) -> Self {
         Self {
-            http: reqwest::Client::new(),
+            http: build_http_client_or_default(),
            auth,
            base_url: DEFAULT_BASE_URL.to_string(),
            max_retries: DEFAULT_MAX_RETRIES,
@@ -452,7 +454,7 @@ impl AnthropicClient {
                 break;
             }

-            tokio::time::sleep(self.backoff_for_attempt(attempts)?).await;
+            tokio::time::sleep(self.jittered_backoff_for_attempt(attempts)?).await;
         }

         Err(ApiError::RetriesExhausted {
@@ -485,10 +487,21 @@ impl AnthropicClient {
     }

     async fn preflight_message_request(&self, request: &MessageRequest) -> Result<(), ApiError> {
+        // Always run the local byte-estimate guard first. This catches
+        // oversized requests even if the remote count_tokens endpoint is
+        // unreachable, misconfigured, or unimplemented (e.g., third-party
+        // Anthropic-compatible gateways). If byte estimation already flags
+        // the request as oversized, reject immediately without a network
+        // round trip.
+        super::preflight_message_request(request)?;
+
         let Some(limit) = model_token_limit(&request.model) else {
             return Ok(());
         };

+        // Best-effort refinement using the Anthropic count_tokens endpoint.
+        // On any failure (network, parse, auth), fall back to the local
+        // byte-estimate result which already passed above.
         let counted_input_tokens = match self.count_tokens(request).await {
             Ok(count) => count,
             Err(_) => return Ok(()),
@@ -513,7 +526,10 @@ impl AnthropicClient {
             input_tokens: u32,
         }

-        let request_url = format!("{}/v1/messages/count_tokens", self.base_url.trim_end_matches('/'));
+        let request_url = format!(
+            "{}/v1/messages/count_tokens",
+            self.base_url.trim_end_matches('/')
+        );
         let mut request_body = self.request_profile.render_json_body(request)?;
         strip_unsupported_beta_body_fields(&mut request_body);
         let response = self
@@ -526,12 +542,7 @@ impl AnthropicClient {
         let response = expect_success(response).await?;
         let body = response.text().await.map_err(ApiError::from)?;
         let parsed = serde_json::from_str::<CountTokensResponse>(&body).map_err(|error| {
-            ApiError::json_deserialize(
-                "Anthropic count_tokens",
-                &request.model,
-                &body,
-                error,
-            )
+            ApiError::json_deserialize("Anthropic count_tokens", &request.model, &body, error)
         })?;
         Ok(parsed.input_tokens)
     }
@@ -568,6 +579,42 @@ impl AnthropicClient {
             .checked_mul(multiplier)
             .map_or(self.max_backoff, |delay| delay.min(self.max_backoff)))
     }
+
+    fn jittered_backoff_for_attempt(&self, attempt: u32) -> Result<Duration, ApiError> {
+        let base = self.backoff_for_attempt(attempt)?;
+        Ok(base + jitter_for_base(base))
+    }
 }
+
+/// Process-wide counter that guarantees distinct jitter samples even when
+/// the system clock resolution is coarser than consecutive retry sleeps.
+static JITTER_COUNTER: AtomicU64 = AtomicU64::new(0);
+
+/// Returns a random additive jitter in `[0, base]` to decorrelate retries
+/// from multiple concurrent clients. Entropy is drawn from the nanosecond
+/// wall clock mixed with a monotonic counter and run through a splitmix64
+/// finalizer; adequate for retry jitter (no cryptographic requirement).
+fn jitter_for_base(base: Duration) -> Duration {
+    let base_nanos = u64::try_from(base.as_nanos()).unwrap_or(u64::MAX);
+    if base_nanos == 0 {
+        return Duration::ZERO;
+    }
+    let raw_nanos = SystemTime::now()
+        .duration_since(UNIX_EPOCH)
+        .map(|elapsed| u64::try_from(elapsed.as_nanos()).unwrap_or(u64::MAX))
+        .unwrap_or(0);
+    let tick = JITTER_COUNTER.fetch_add(1, Ordering::Relaxed);
+    // splitmix64 finalizer — mixes the low bits so large bases still see
+    // jitter across their full range instead of being clamped to subsec nanos.
+    let mut mixed = raw_nanos
+        .wrapping_add(tick)
+        .wrapping_add(0x9E37_79B9_7F4A_7C15);
+    mixed = (mixed ^ (mixed >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
+    mixed = (mixed ^ (mixed >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
+    mixed ^= mixed >> 31;
+    // Inclusive upper bound: jitter may equal `base`, matching "up to base".
+    let jitter_nanos = mixed % base_nanos.saturating_add(1);
+    Duration::from_nanos(jitter_nanos)
+}

 impl AuthSource {
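The jitter helper in the hunk above can be exercised in isolation. This sketch passes the counter explicitly instead of using a process-wide static (a simplification for demonstration), but keeps the same splitmix64 finalizer and the same inclusive `[0, base]` bound:

```rust
use std::time::{Duration, SystemTime, UNIX_EPOCH};

// Additive jitter in [0, base]: wall-clock nanos mixed with a caller-
// supplied tick, run through a splitmix64 finalizer, then reduced
// modulo (base + 1) so the jitter never exceeds the base delay.
fn jitter_for_base(base: Duration, tick: u64) -> Duration {
    let base_nanos = u64::try_from(base.as_nanos()).unwrap_or(u64::MAX);
    if base_nanos == 0 {
        return Duration::ZERO;
    }
    let raw_nanos = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map(|elapsed| u64::try_from(elapsed.as_nanos()).unwrap_or(u64::MAX))
        .unwrap_or(0);
    let mut mixed = raw_nanos.wrapping_add(tick).wrapping_add(0x9E37_79B9_7F4A_7C15);
    mixed = (mixed ^ (mixed >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
    mixed = (mixed ^ (mixed >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
    mixed ^= mixed >> 31;
    Duration::from_nanos(mixed % base_nanos.saturating_add(1))
}

fn main() {
    let base = Duration::from_secs(4);
    for tick in 0..64 {
        let jitter = jitter_for_base(base, tick);
        // Jitter is bounded, so a jittered sleep is at most base * 2.
        assert!(jitter <= base, "jitter must not exceed the base delay");
    }
    // Zero base means no jitter at all.
    assert_eq!(jitter_for_base(Duration::ZERO, 0), Duration::ZERO);
    println!("jitter bounded by base");
}
```

Bounding the jitter additively keeps the worst-case sleep at twice the base delay, so the retry budget stays predictable while concurrent clients still decorrelate.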
@@ -894,6 +941,15 @@ const fn is_retryable_status(status: reqwest::StatusCode) -> bool {
 fn strip_unsupported_beta_body_fields(body: &mut Value) {
     if let Some(object) = body.as_object_mut() {
         object.remove("betas");
+        // These fields are OpenAI-compatible only; Anthropic rejects them.
+        object.remove("frequency_penalty");
+        object.remove("presence_penalty");
+        // Anthropic uses "stop_sequences" not "stop". Convert if present.
+        if let Some(stop_val) = object.remove("stop") {
+            if stop_val.as_array().map_or(false, |a| !a.is_empty()) {
+                object.insert("stop_sequences".to_string(), stop_val);
+            }
+        }
     }
 }

@@ -1223,6 +1279,7 @@ mod tests {
             tools: None,
             tool_choice: None,
             stream: false,
+            ..Default::default()
         };

         assert!(request.with_streaming().stream);
@@ -1249,6 +1306,58 @@ mod tests {
         );
     }

+    #[test]
+    fn jittered_backoff_stays_within_additive_bounds_and_varies() {
+        let client = AnthropicClient::new("test-key").with_retry_policy(
+            8,
+            Duration::from_secs(1),
+            Duration::from_secs(128),
+        );
+        let mut samples = Vec::with_capacity(64);
+        for _ in 0..64 {
+            let base = client.backoff_for_attempt(3).expect("base attempt 3");
+            let jittered = client
+                .jittered_backoff_for_attempt(3)
+                .expect("jittered attempt 3");
+            assert!(
+                jittered >= base,
+                "jittered delay {jittered:?} must be at least the base {base:?}"
+            );
+            assert!(
+                jittered <= base * 2,
+                "jittered delay {jittered:?} must not exceed base*2 {:?}",
+                base * 2
+            );
+            samples.push(jittered);
+        }
+        let distinct: std::collections::HashSet<_> = samples.iter().collect();
+        assert!(
+            distinct.len() > 1,
+            "jitter should produce varied delays across samples, got {samples:?}"
+        );
+    }
+
+    #[test]
+    fn default_retry_policy_matches_exponential_schedule() {
+        let client = AnthropicClient::new("test-key");
+        assert_eq!(
+            client.backoff_for_attempt(1).expect("attempt 1"),
+            Duration::from_secs(1)
+        );
+        assert_eq!(
+            client.backoff_for_attempt(2).expect("attempt 2"),
+            Duration::from_secs(2)
+        );
+        assert_eq!(
+            client.backoff_for_attempt(3).expect("attempt 3"),
+            Duration::from_secs(4)
+        );
+        assert_eq!(
+            client.backoff_for_attempt(8).expect("attempt 8"),
+            Duration::from_secs(128)
+        );
+    }
+
     #[test]
     fn retryable_statuses_are_detected() {
         assert!(super::is_retryable_status(
@@ -1350,6 +1459,52 @@ mod tests {
         assert_eq!(body, original);
     }

+    #[test]
+    fn strip_removes_openai_only_fields_and_converts_stop() {
+        let mut body = serde_json::json!({
+            "model": "claude-sonnet-4-6",
+            "max_tokens": 1024,
+            "temperature": 0.7,
+            "frequency_penalty": 0.5,
+            "presence_penalty": 0.3,
+            "stop": ["\n"],
+        });
+
+        super::strip_unsupported_beta_body_fields(&mut body);
+
+        // temperature is kept (Anthropic supports it)
+        assert_eq!(body["temperature"], serde_json::json!(0.7));
+        // frequency_penalty and presence_penalty are removed
+        assert!(
+            body.get("frequency_penalty").is_none(),
+            "frequency_penalty must be stripped for Anthropic"
+        );
+        assert!(
+            body.get("presence_penalty").is_none(),
+            "presence_penalty must be stripped for Anthropic"
+        );
+        // stop is renamed to stop_sequences
+        assert!(body.get("stop").is_none(), "stop must be renamed");
+        assert_eq!(body["stop_sequences"], serde_json::json!(["\n"]));
+    }
+
+    #[test]
+    fn strip_does_not_add_empty_stop_sequences() {
+        let mut body = serde_json::json!({
+            "model": "claude-sonnet-4-6",
+            "max_tokens": 1024,
+            "stop": [],
+        });
+
+        super::strip_unsupported_beta_body_fields(&mut body);
+
+        assert!(body.get("stop").is_none());
+        assert!(
+            body.get("stop_sequences").is_none(),
+            "empty stop should not produce stop_sequences"
+        );
+    }
+
     #[test]
     fn rendered_request_body_strips_betas_for_standard_messages_endpoint() {
         let client = AnthropicClient::new("test-key").with_beta("tools-2026-04-01");
@@ -1361,6 +1516,7 @@ mod tests {
             tools: None,
             tool_choice: None,
             stream: false,
+            ..Default::default()
         };

         let mut rendered = client
@@ -169,6 +169,31 @@ pub fn metadata_for_model(model: &str) -> Option<ProviderMetadata> {
             default_base_url: openai_compat::DEFAULT_XAI_BASE_URL,
         });
     }
+    // Explicit provider-namespaced models (e.g. "openai/gpt-4.1-mini") must
+    // route to the correct provider regardless of which auth env vars are set.
+    // Without this, detect_provider_kind falls through to the auth-sniffer
+    // order and misroutes to Anthropic if ANTHROPIC_API_KEY is present.
+    if canonical.starts_with("openai/") || canonical.starts_with("gpt-") {
+        return Some(ProviderMetadata {
+            provider: ProviderKind::OpenAi,
+            auth_env: "OPENAI_API_KEY",
+            base_url_env: "OPENAI_BASE_URL",
+            default_base_url: openai_compat::DEFAULT_OPENAI_BASE_URL,
+        });
+    }
+    // Alibaba DashScope compatible-mode endpoint. Routes qwen/* and bare
+    // qwen-* model names (qwen-max, qwen-plus, qwen-turbo, qwen-qwq, etc.)
+    // to the OpenAI-compat client pointed at DashScope's /compatible-mode/v1.
+    // Uses the OpenAi provider kind because DashScope speaks the OpenAI REST
+    // shape — only the base URL and auth env var differ.
+    if canonical.starts_with("qwen/") || canonical.starts_with("qwen-") {
+        return Some(ProviderMetadata {
+            provider: ProviderKind::OpenAi,
+            auth_env: "DASHSCOPE_API_KEY",
+            base_url_env: "DASHSCOPE_BASE_URL",
+            default_base_url: openai_compat::DEFAULT_DASHSCOPE_BASE_URL,
+        });
+    }
     None
 }
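The routing rule above (explicit model-name prefix beats auth env-var sniffing) can be illustrated with a simplified stand-in; the enum, function, and env-var pairings below are illustrative, not the crate's real types:

```rust
// A prefix match resolves the provider and its auth env var before any
// fallback to environment sniffing; None means "fall through to the
// auth-sniffer order".
#[derive(Debug, PartialEq)]
enum Provider {
    Anthropic,
    OpenAi,
}

fn route_by_prefix(model: &str) -> Option<(Provider, &'static str)> {
    if model.starts_with("openai/") || model.starts_with("gpt-") {
        return Some((Provider::OpenAi, "OPENAI_API_KEY"));
    }
    if model.starts_with("qwen/") || model.starts_with("qwen-") {
        // DashScope speaks the OpenAI REST shape, so it reuses the
        // OpenAI-compat client with a different key and base URL.
        return Some((Provider::OpenAi, "DASHSCOPE_API_KEY"));
    }
    if model.starts_with("claude-") {
        return Some((Provider::Anthropic, "ANTHROPIC_API_KEY"));
    }
    None // unknown prefix: defer to auth-sniffer order
}

fn main() {
    assert_eq!(
        route_by_prefix("openai/gpt-4.1-mini"),
        Some((Provider::OpenAi, "OPENAI_API_KEY"))
    );
    assert_eq!(
        route_by_prefix("qwen-plus"),
        Some((Provider::OpenAi, "DASHSCOPE_API_KEY"))
    );
    assert_eq!(route_by_prefix("mystery-model"), None);
    println!("prefix routing ok");
}
```

Because the prefix check runs first, setting `ANTHROPIC_API_KEY` in the environment can no longer hijack an explicitly namespaced model.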
@@ -204,6 +229,14 @@ pub fn max_tokens_for_model(model: &str) -> u32 {
     )
 }

+/// Returns the effective max output tokens for a model, preferring a plugin
+/// override when present. Falls back to [`max_tokens_for_model`] when the
+/// override is `None`.
+#[must_use]
+pub fn max_tokens_for_model_with_override(model: &str, plugin_override: Option<u32>) -> u32 {
+    plugin_override.unwrap_or_else(|| max_tokens_for_model(model))
+}
+
 #[must_use]
 pub fn model_token_limit(model: &str) -> Option<ModelTokenLimit> {
     let canonical = resolve_model_alias(model);
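The override rule itself is a plain `unwrap_or_else`: a plugin-configured cap wins, otherwise the per-model default applies. A minimal sketch with a hypothetical stand-in for the per-model heuristic:

```rust
// Stand-in for the real per-model heuristic; the 32_000 figure for
// opus-class models matches the tests later in this diff.
fn max_tokens_for_model(model: &str) -> u32 {
    if model.contains("opus") {
        32_000
    } else {
        64_000
    }
}

fn max_tokens_with_override(model: &str, plugin_override: Option<u32>) -> u32 {
    plugin_override.unwrap_or_else(|| max_tokens_for_model(model))
}

fn main() {
    // A configured plugin cap always wins.
    assert_eq!(max_tokens_with_override("claude-opus-4-6", Some(12_345)), 12_345);
    // With no override, the model default applies.
    assert_eq!(max_tokens_with_override("claude-opus-4-6", None), 32_000);
    println!("override ok");
}
```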
@@ -323,8 +356,9 @@ mod tests {
     };

     use super::{
-        detect_provider_kind, load_dotenv_file, max_tokens_for_model, model_token_limit,
-        parse_dotenv, preflight_message_request, resolve_model_alias, ProviderKind,
+        detect_provider_kind, load_dotenv_file, max_tokens_for_model,
+        max_tokens_for_model_with_override, model_token_limit, parse_dotenv,
+        preflight_message_request, resolve_model_alias, ProviderKind,
     };

     #[test]
@@ -343,12 +377,114 @@ mod tests {
         );
     }

+    #[test]
+    fn openai_namespaced_model_routes_to_openai_not_anthropic() {
+        // Regression: "openai/gpt-4.1-mini" was misrouted to Anthropic when
+        // ANTHROPIC_API_KEY was set because metadata_for_model returned None
+        // and detect_provider_kind fell through to auth-sniffer order.
+        // The model prefix must win over env-var presence.
+        let kind = super::metadata_for_model("openai/gpt-4.1-mini")
+            .map(|m| m.provider)
+            .unwrap_or_else(|| detect_provider_kind("openai/gpt-4.1-mini"));
+        assert_eq!(
+            kind,
+            ProviderKind::OpenAi,
+            "openai/ prefix must route to OpenAi regardless of ANTHROPIC_API_KEY"
+        );
+
+        // Also cover bare gpt- prefix
+        let kind2 = super::metadata_for_model("gpt-4o")
+            .map(|m| m.provider)
+            .unwrap_or_else(|| detect_provider_kind("gpt-4o"));
+        assert_eq!(kind2, ProviderKind::OpenAi);
+    }
+
+    #[test]
+    fn qwen_prefix_routes_to_dashscope_not_anthropic() {
+        // User request from Discord #clawcode-get-help: web3g wants to use
+        // Qwen 3.6 Plus via native Alibaba DashScope API (not OpenRouter,
+        // which has lower rate limits). metadata_for_model must route
+        // qwen/* and bare qwen-* to the OpenAi provider kind pointed at
+        // the DashScope compatible-mode endpoint, regardless of whether
+        // ANTHROPIC_API_KEY is present in the environment.
+        let meta = super::metadata_for_model("qwen/qwen-max")
+            .expect("qwen/ prefix must resolve to DashScope metadata");
+        assert_eq!(meta.provider, ProviderKind::OpenAi);
+        assert_eq!(meta.auth_env, "DASHSCOPE_API_KEY");
+        assert_eq!(meta.base_url_env, "DASHSCOPE_BASE_URL");
+        assert!(meta.default_base_url.contains("dashscope.aliyuncs.com"));
+
+        // Bare qwen- prefix also routes
+        let meta2 = super::metadata_for_model("qwen-plus")
+            .expect("qwen- prefix must resolve to DashScope metadata");
+        assert_eq!(meta2.provider, ProviderKind::OpenAi);
+        assert_eq!(meta2.auth_env, "DASHSCOPE_API_KEY");
+
+        // detect_provider_kind must agree even if ANTHROPIC_API_KEY is set
+        let kind = detect_provider_kind("qwen/qwen3-coder");
+        assert_eq!(
+            kind,
+            ProviderKind::OpenAi,
+            "qwen/ prefix must win over auth-sniffer order"
+        );
+    }
+
     #[test]
     fn keeps_existing_max_token_heuristic() {
         assert_eq!(max_tokens_for_model("opus"), 32_000);
         assert_eq!(max_tokens_for_model("grok-3"), 64_000);
     }

+    #[test]
+    fn plugin_config_max_output_tokens_overrides_model_default() {
+        // given
+        let nanos = std::time::SystemTime::now()
+            .duration_since(std::time::UNIX_EPOCH)
+            .expect("time should be after epoch")
+            .as_nanos();
+        let root = std::env::temp_dir().join(format!("api-plugin-max-tokens-{nanos}"));
+        let cwd = root.join("project");
+        let home = root.join("home").join(".claw");
+        std::fs::create_dir_all(cwd.join(".claw")).expect("project config dir");
+        std::fs::create_dir_all(&home).expect("home config dir");
+        std::fs::write(
+            home.join("settings.json"),
+            r#"{
+  "plugins": {
+    "maxOutputTokens": 12345
+  }
+}"#,
+        )
+        .expect("write plugin settings");
+
+        // when
+        let loaded = runtime::ConfigLoader::new(&cwd, &home)
+            .load()
+            .expect("config should load");
+        let plugin_override = loaded.plugins().max_output_tokens();
+        let effective = max_tokens_for_model_with_override("claude-opus-4-6", plugin_override);
+
+        // then
+        assert_eq!(plugin_override, Some(12345));
+        assert_eq!(effective, 12345);
+        assert_ne!(effective, max_tokens_for_model("claude-opus-4-6"));
+
+        std::fs::remove_dir_all(root).expect("cleanup temp dir");
+    }
+
+    #[test]
+    fn max_tokens_for_model_with_override_falls_back_when_plugin_unset() {
+        // given
+        let plugin_override: Option<u32> = None;
+
+        // when
+        let effective = max_tokens_for_model_with_override("claude-opus-4-6", plugin_override);
+
+        // then
+        assert_eq!(effective, max_tokens_for_model("claude-opus-4-6"));
+        assert_eq!(effective, 32_000);
+    }
+
     #[test]
     fn returns_context_window_metadata_for_supported_models() {
         assert_eq!(
@@ -387,6 +523,7 @@ mod tests {
             }]),
             tool_choice: Some(ToolChoice::Auto),
             stream: true,
+            ..Default::default()
         };

         let error = preflight_message_request(&request)
@@ -425,6 +562,7 @@ mod tests {
|
||||
tools: None,
|
||||
tool_choice: None,
|
||||
stream: false,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
preflight_message_request(&request)
|
||||
|
||||
@@ -1,10 +1,12 @@
use std::collections::{BTreeMap, VecDeque};
use std::time::Duration;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{Duration, SystemTime, UNIX_EPOCH};

use serde::Deserialize;
use serde_json::{json, Value};

use crate::error::ApiError;
use crate::http_client::build_http_client_or_default;
use crate::types::{
    ContentBlockDelta, ContentBlockDeltaEvent, ContentBlockStartEvent, ContentBlockStopEvent,
    InputContentBlock, InputMessage, MessageDelta, MessageDeltaEvent, MessageRequest,
@@ -16,11 +18,12 @@ use super::{preflight_message_request, Provider, ProviderFuture};

pub const DEFAULT_XAI_BASE_URL: &str = "https://api.x.ai/v1";
pub const DEFAULT_OPENAI_BASE_URL: &str = "https://api.openai.com/v1";
pub const DEFAULT_DASHSCOPE_BASE_URL: &str = "https://dashscope.aliyuncs.com/compatible-mode/v1";
const REQUEST_ID_HEADER: &str = "request-id";
const ALT_REQUEST_ID_HEADER: &str = "x-request-id";
const DEFAULT_INITIAL_BACKOFF: Duration = Duration::from_millis(200);
const DEFAULT_MAX_BACKOFF: Duration = Duration::from_secs(2);
const DEFAULT_MAX_RETRIES: u32 = 2;
const DEFAULT_INITIAL_BACKOFF: Duration = Duration::from_secs(1);
const DEFAULT_MAX_BACKOFF: Duration = Duration::from_secs(128);
const DEFAULT_MAX_RETRIES: u32 = 8;

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct OpenAiCompatConfig {
@@ -32,6 +35,7 @@ pub struct OpenAiCompatConfig {

const XAI_ENV_VARS: &[&str] = &["XAI_API_KEY"];
const OPENAI_ENV_VARS: &[&str] = &["OPENAI_API_KEY"];
const DASHSCOPE_ENV_VARS: &[&str] = &["DASHSCOPE_API_KEY"];

impl OpenAiCompatConfig {
    #[must_use]
@@ -53,11 +57,27 @@ impl OpenAiCompatConfig {
            default_base_url: DEFAULT_OPENAI_BASE_URL,
        }
    }

    /// Alibaba DashScope compatible-mode endpoint (Qwen family models).
    /// Uses the OpenAI-compatible REST shape at /compatible-mode/v1.
    /// Requested via Discord #clawcode-get-help: native Alibaba API for
    /// higher rate limits than going through OpenRouter.
    #[must_use]
    pub const fn dashscope() -> Self {
        Self {
            provider_name: "DashScope",
            api_key_env: "DASHSCOPE_API_KEY",
            base_url_env: "DASHSCOPE_BASE_URL",
            default_base_url: DEFAULT_DASHSCOPE_BASE_URL,
        }
    }

    #[must_use]
    pub fn credential_env_vars(self) -> &'static [&'static str] {
        match self.provider_name {
            "xAI" => XAI_ENV_VARS,
            "OpenAI" => OPENAI_ENV_VARS,
            "DashScope" => DASHSCOPE_ENV_VARS,
            _ => &[],
        }
    }
@@ -81,7 +101,7 @@ impl OpenAiCompatClient {
    #[must_use]
    pub fn new(api_key: impl Into<String>, config: OpenAiCompatConfig) -> Self {
        Self {
            http: reqwest::Client::new(),
            http: build_http_client_or_default(),
            api_key: api_key.into(),
            config,
            base_url: read_base_url(config),
@@ -133,12 +153,7 @@ impl OpenAiCompatClient {
        let request_id = request_id_from_headers(response.headers());
        let body = response.text().await.map_err(ApiError::from)?;
        let payload = serde_json::from_str::<ChatCompletionResponse>(&body).map_err(|error| {
            ApiError::json_deserialize(
                self.config.provider_name,
                &request.model,
                &body,
                error,
            )
            ApiError::json_deserialize(self.config.provider_name, &request.model, &body, error)
        })?;
        let mut normalized = normalize_response(&request.model, payload)?;
        if normalized.request_id.is_none() {
@@ -158,10 +173,7 @@ impl OpenAiCompatClient {
        Ok(MessageStream {
            request_id: request_id_from_headers(response.headers()),
            response,
            parser: OpenAiSseParser::with_context(
                self.config.provider_name,
                request.model.clone(),
            ),
            parser: OpenAiSseParser::with_context(self.config.provider_name, request.model.clone()),
            pending: VecDeque::new(),
            done: false,
            state: StreamState::new(request.model.clone()),
@@ -190,7 +202,7 @@ impl OpenAiCompatClient {
                break retryable_error;
            }

            tokio::time::sleep(self.backoff_for_attempt(attempts)?).await;
            tokio::time::sleep(self.jittered_backoff_for_attempt(attempts)?).await;
        };

        Err(ApiError::RetriesExhausted {
@@ -226,6 +238,39 @@ impl OpenAiCompatClient {
            .checked_mul(multiplier)
            .map_or(self.max_backoff, |delay| delay.min(self.max_backoff)))
    }

    fn jittered_backoff_for_attempt(&self, attempt: u32) -> Result<Duration, ApiError> {
        let base = self.backoff_for_attempt(attempt)?;
        Ok(base + jitter_for_base(base))
    }
}

/// Process-wide counter that guarantees distinct jitter samples even when
/// the system clock resolution is coarser than consecutive retry sleeps.
static JITTER_COUNTER: AtomicU64 = AtomicU64::new(0);

/// Returns a random additive jitter in `[0, base]` to decorrelate retries
/// from multiple concurrent clients. Entropy is drawn from the nanosecond
/// wall clock mixed with a monotonic counter and run through a splitmix64
/// finalizer; adequate for retry jitter (no cryptographic requirement).
fn jitter_for_base(base: Duration) -> Duration {
    let base_nanos = u64::try_from(base.as_nanos()).unwrap_or(u64::MAX);
    if base_nanos == 0 {
        return Duration::ZERO;
    }
    let raw_nanos = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map(|elapsed| u64::try_from(elapsed.as_nanos()).unwrap_or(u64::MAX))
        .unwrap_or(0);
    let tick = JITTER_COUNTER.fetch_add(1, Ordering::Relaxed);
    let mut mixed = raw_nanos
        .wrapping_add(tick)
        .wrapping_add(0x9E37_79B9_7F4A_7C15);
    mixed = (mixed ^ (mixed >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
    mixed = (mixed ^ (mixed >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
    mixed ^= mixed >> 31;
    let jitter_nanos = mixed % base_nanos.saturating_add(1);
    Duration::from_nanos(jitter_nanos)
}
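The splitmix64-style mix in `jitter_for_base` can be checked in isolation. A minimal standalone sketch, with the constants copied from the function above; the `mix` and `bounded_jitter` names and the fixed entropy inputs are hypothetical, not part of the crate:

```rust
/// Standalone port of the splitmix64 finalizer used for retry jitter.
fn mix(mut x: u64) -> u64 {
    x = x.wrapping_add(0x9E37_79B9_7F4A_7C15);
    x = (x ^ (x >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
    x = (x ^ (x >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
    x ^ (x >> 31)
}

/// Jitter bounded to [0, base_nanos], mirroring the modulo step above.
fn bounded_jitter(entropy: u64, base_nanos: u64) -> u64 {
    if base_nanos == 0 {
        return 0;
    }
    mix(entropy) % base_nanos.saturating_add(1)
}

fn main() {
    // Consecutive entropy values always land inside the bound.
    for seed in 0..8u64 {
        let jitter = bounded_jitter(seed, 1_000_000_000);
        assert!(jitter <= 1_000_000_000);
    }
    println!("bounded jitter ok");
}
```

Because the finalizer is a bijection on `u64`, nearby entropy inputs (clock reads one tick apart) still map to well-spread jitter samples, which is the decorrelation property the doc comment claims.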
impl Provider for OpenAiCompatClient {
@@ -657,6 +702,25 @@ struct ErrorBody {
    message: Option<String>,
}

/// Returns true for models known to reject tuning parameters like temperature,
/// top_p, frequency_penalty, and presence_penalty. These are typically
/// reasoning/chain-of-thought models with fixed sampling.
fn is_reasoning_model(model: &str) -> bool {
    let lowered = model.to_ascii_lowercase();
    // Strip any provider/ prefix for the check (e.g. qwen/qwen-qwq -> qwen-qwq)
    let canonical = lowered.rsplit('/').next().unwrap_or(lowered.as_str());
    // OpenAI reasoning models
    canonical.starts_with("o1")
        || canonical.starts_with("o3")
        || canonical.starts_with("o4")
        // xAI reasoning: grok-3-mini always uses reasoning mode
        || canonical == "grok-3-mini"
        // Alibaba DashScope reasoning variants (QwQ + Qwen3-Thinking family)
        || canonical.starts_with("qwen-qwq")
        || canonical.starts_with("qwq")
        || canonical.contains("thinking")
}

fn build_chat_completion_request(request: &MessageRequest, config: OpenAiCompatConfig) -> Value {
    let mut messages = Vec::new();
    if let Some(system) = request.system.as_ref().filter(|value| !value.is_empty()) {
@@ -688,6 +752,30 @@ fn build_chat_completion_request(request: &MessageRequest, config: OpenAiCompatC
        payload["tool_choice"] = openai_tool_choice(tool_choice);
    }

    // OpenAI-compatible tuning parameters — only included when explicitly set.
    // Reasoning models (o1/o3/o4/grok-3-mini) reject these params with 400;
    // silently strip them to avoid cryptic provider errors.
    if !is_reasoning_model(&request.model) {
        if let Some(temperature) = request.temperature {
            payload["temperature"] = json!(temperature);
        }
        if let Some(top_p) = request.top_p {
            payload["top_p"] = json!(top_p);
        }
        if let Some(frequency_penalty) = request.frequency_penalty {
            payload["frequency_penalty"] = json!(frequency_penalty);
        }
        if let Some(presence_penalty) = request.presence_penalty {
            payload["presence_penalty"] = json!(presence_penalty);
        }
    }
    // stop is generally safe for all providers
    if let Some(stop) = &request.stop {
        if !stop.is_empty() {
            payload["stop"] = json!(stop);
        }
    }

    payload
}
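The two-level gate above (a param is sent only when it is both set and accepted by the model class) can be sketched with std types only. The `is_reasoning` and `tuning_payload` names below are hypothetical stand-ins for illustration, not the crate's API, and the reduced pattern list covers only a subset of the real classifier:

```rust
use std::collections::BTreeMap;

// Simplified stand-in for is_reasoning_model: strip a provider/ prefix,
// then match a few reasoning-model patterns.
fn is_reasoning(model: &str) -> bool {
    let lowered = model.to_ascii_lowercase();
    let canonical = lowered.rsplit('/').next().unwrap_or(lowered.as_str());
    canonical.starts_with("o1") || canonical.starts_with("qwq") || canonical.contains("thinking")
}

// A tuning param reaches the payload only when it is Some AND the model
// is not a reasoning model with fixed sampling.
fn tuning_payload(model: &str, temperature: Option<f64>) -> BTreeMap<&'static str, f64> {
    let mut payload = BTreeMap::new();
    if !is_reasoning(model) {
        if let Some(t) = temperature {
            payload.insert("temperature", t);
        }
    }
    payload
}

fn main() {
    assert!(tuning_payload("o1-mini", Some(0.7)).is_empty());
    assert_eq!(tuning_payload("gpt-4o", Some(0.7)).get("temperature"), Some(&0.7));
    println!("param gating ok");
}
```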
@@ -976,8 +1064,9 @@ impl StringExt for String {
#[cfg(test)]
mod tests {
    use super::{
        build_chat_completion_request, chat_completions_endpoint, normalize_finish_reason,
        openai_tool_choice, parse_tool_arguments, OpenAiCompatClient, OpenAiCompatConfig,
        build_chat_completion_request, chat_completions_endpoint, is_reasoning_model,
        normalize_finish_reason, openai_tool_choice, parse_tool_arguments, OpenAiCompatClient,
        OpenAiCompatConfig,
    };
    use crate::error::ApiError;
    use crate::types::{
@@ -1016,6 +1105,7 @@ mod tests {
            }]),
            tool_choice: Some(ToolChoice::Auto),
            stream: false,
            ..Default::default()
        },
        OpenAiCompatConfig::xai(),
    );
@@ -1038,6 +1128,7 @@ mod tests {
            tools: None,
            tool_choice: None,
            stream: true,
            ..Default::default()
        },
        OpenAiCompatConfig::openai(),
    );
@@ -1056,6 +1147,7 @@ mod tests {
            tools: None,
            tool_choice: None,
            stream: true,
            ..Default::default()
        },
        OpenAiCompatConfig::xai(),
    );
@@ -1126,4 +1218,104 @@ mod tests {
        assert_eq!(normalize_finish_reason("stop"), "end_turn");
        assert_eq!(normalize_finish_reason("tool_calls"), "tool_use");
    }

    #[test]
    fn tuning_params_included_in_payload_when_set() {
        let request = MessageRequest {
            model: "gpt-4o".to_string(),
            max_tokens: 1024,
            messages: vec![],
            system: None,
            tools: None,
            tool_choice: None,
            stream: false,
            temperature: Some(0.7),
            top_p: Some(0.9),
            frequency_penalty: Some(0.5),
            presence_penalty: Some(0.3),
            stop: Some(vec!["\n".to_string()]),
        };
        let payload = build_chat_completion_request(&request, OpenAiCompatConfig::openai());
        assert_eq!(payload["temperature"], 0.7);
        assert_eq!(payload["top_p"], 0.9);
        assert_eq!(payload["frequency_penalty"], 0.5);
        assert_eq!(payload["presence_penalty"], 0.3);
        assert_eq!(payload["stop"], json!(["\n"]));
    }

    #[test]
    fn reasoning_model_strips_tuning_params() {
        let request = MessageRequest {
            model: "o1-mini".to_string(),
            max_tokens: 1024,
            messages: vec![],
            stream: false,
            temperature: Some(0.7),
            top_p: Some(0.9),
            frequency_penalty: Some(0.5),
            presence_penalty: Some(0.3),
            stop: Some(vec!["\n".to_string()]),
            ..Default::default()
        };
        let payload = build_chat_completion_request(&request, OpenAiCompatConfig::openai());
        assert!(
            payload.get("temperature").is_none(),
            "reasoning model should strip temperature"
        );
        assert!(
            payload.get("top_p").is_none(),
            "reasoning model should strip top_p"
        );
        assert!(payload.get("frequency_penalty").is_none());
        assert!(payload.get("presence_penalty").is_none());
        // stop is safe for all providers
        assert_eq!(payload["stop"], json!(["\n"]));
    }

    #[test]
    fn grok_3_mini_is_reasoning_model() {
        assert!(is_reasoning_model("grok-3-mini"));
        assert!(is_reasoning_model("o1"));
        assert!(is_reasoning_model("o1-mini"));
        assert!(is_reasoning_model("o3-mini"));
        assert!(!is_reasoning_model("gpt-4o"));
        assert!(!is_reasoning_model("grok-3"));
        assert!(!is_reasoning_model("claude-sonnet-4-6"));
    }

    #[test]
    fn qwen_reasoning_variants_are_detected() {
        // QwQ reasoning model
        assert!(is_reasoning_model("qwen-qwq-32b"));
        assert!(is_reasoning_model("qwen/qwen-qwq-32b"));
        // Qwen3 thinking family
        assert!(is_reasoning_model("qwen3-30b-a3b-thinking"));
        assert!(is_reasoning_model("qwen/qwen3-30b-a3b-thinking"));
        // Bare qwq
        assert!(is_reasoning_model("qwq-plus"));
        // Regular Qwen models must NOT be classified as reasoning
        assert!(!is_reasoning_model("qwen-max"));
        assert!(!is_reasoning_model("qwen/qwen-plus"));
        assert!(!is_reasoning_model("qwen-turbo"));
    }

    #[test]
    fn tuning_params_omitted_from_payload_when_none() {
        let request = MessageRequest {
            model: "gpt-4o".to_string(),
            max_tokens: 1024,
            messages: vec![],
            stream: false,
            ..Default::default()
        };
        let payload = build_chat_completion_request(&request, OpenAiCompatConfig::openai());
        assert!(
            payload.get("temperature").is_none(),
            "temperature should be absent"
        );
        assert!(payload.get("top_p").is_none(), "top_p should be absent");
        assert!(payload.get("frequency_penalty").is_none());
        assert!(payload.get("presence_penalty").is_none());
        assert!(payload.get("stop").is_none());
    }
}
@@ -2,7 +2,7 @@ use runtime::{pricing_for_model, TokenUsage, UsageCostEstimate};
use serde::{Deserialize, Serialize};
use serde_json::Value;

#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
pub struct MessageRequest {
    pub model: String,
    pub max_tokens: u32,
@@ -15,6 +15,17 @@ pub struct MessageRequest {
    pub tool_choice: Option<ToolChoice>,
    #[serde(default, skip_serializing_if = "std::ops::Not::not")]
    pub stream: bool,
    /// OpenAI-compatible tuning parameters. Optional — omitted from payload when None.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub frequency_penalty: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub presence_penalty: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<Vec<String>>,
}

impl MessageRequest {
@@ -127,6 +127,7 @@ async fn send_message_blocks_oversized_requests_before_the_http_call() {
            tools: None,
            tool_choice: None,
            stream: false,
            ..Default::default()
        })
        .await
        .expect_err("oversized request should fail local context-window preflight");
@@ -545,6 +546,71 @@ async fn surfaces_retry_exhaustion_for_persistent_retryable_errors() {
    }
}

#[tokio::test]
async fn retries_multiple_retryable_failures_with_exponential_backoff_and_jitter() {
    let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new()));
    let server = spawn_server(
        state.clone(),
        vec![
            http_response(
                "429 Too Many Requests",
                "application/json",
                "{\"type\":\"error\",\"error\":{\"type\":\"rate_limit_error\",\"message\":\"slow down\"}}",
            ),
            http_response(
                "500 Internal Server Error",
                "application/json",
                "{\"type\":\"error\",\"error\":{\"type\":\"api_error\",\"message\":\"boom\"}}",
            ),
            http_response(
                "503 Service Unavailable",
                "application/json",
                "{\"type\":\"error\",\"error\":{\"type\":\"overloaded_error\",\"message\":\"busy\"}}",
            ),
            http_response(
                "429 Too Many Requests",
                "application/json",
                "{\"type\":\"error\",\"error\":{\"type\":\"rate_limit_error\",\"message\":\"slow down again\"}}",
            ),
            http_response(
                "503 Service Unavailable",
                "application/json",
                "{\"type\":\"error\",\"error\":{\"type\":\"overloaded_error\",\"message\":\"still busy\"}}",
            ),
            http_response(
                "200 OK",
                "application/json",
                "{\"id\":\"msg_exp_retry\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[{\"type\":\"text\",\"text\":\"Recovered after 5\"}],\"model\":\"claude-3-7-sonnet-latest\",\"stop_reason\":\"end_turn\",\"stop_sequence\":null,\"usage\":{\"input_tokens\":3,\"output_tokens\":2}}",
            ),
        ],
    )
    .await;

    let client = ApiClient::new("test-key")
        .with_base_url(server.base_url())
        .with_retry_policy(8, Duration::from_millis(1), Duration::from_millis(4));
    let started_at = std::time::Instant::now();

    let response = client
        .send_message(&sample_request(false))
        .await
        .expect("8-retry policy should absorb 5 retryable failures");

    let elapsed = started_at.elapsed();
    assert_eq!(response.total_tokens(), 5);
    assert_eq!(
        state.lock().await.len(),
        6,
        "client should issue 1 original + 5 retry requests before the 200"
    );
    // Jittered sleeps are bounded by 2 * max_backoff per retry (base + jitter),
    // so 5 sleeps fit comfortably below this upper bound with generous slack.
    assert!(
        elapsed < Duration::from_secs(5),
        "retries should complete promptly, took {elapsed:?}"
    );
}

#[tokio::test]
#[allow(clippy::await_holding_lock)]
async fn send_message_reuses_recent_completion_cache_entries() {
@@ -676,6 +742,7 @@ async fn live_stream_smoke_test() {
            tools: None,
            tool_choice: None,
            stream: false,
            ..Default::default()
        })
        .await
        .expect("live stream should start");
@@ -856,5 +923,6 @@ fn sample_request(stream: bool) -> MessageRequest {
        }]),
        tool_choice: Some(ToolChoice::Auto),
        stream,
        ..Default::default()
    }
}

@@ -88,6 +88,7 @@ async fn send_message_blocks_oversized_xai_requests_before_the_http_call() {
            tools: None,
            tool_choice: None,
            stream: false,
            ..Default::default()
        })
        .await
        .expect_err("oversized request should fail local context-window preflight");
@@ -496,6 +497,7 @@ fn sample_request(stream: bool) -> MessageRequest {
        }]),
        tool_choice: Some(ToolChoice::Auto),
        stream,
        ..Default::default()
    }
}

173
rust/crates/api/tests/proxy_integration.rs
Normal file
@@ -0,0 +1,173 @@
use std::ffi::OsString;
use std::sync::{Mutex, OnceLock};

use api::{build_http_client_with, ProxyConfig};

fn env_lock() -> std::sync::MutexGuard<'static, ()> {
    static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
    LOCK.get_or_init(|| Mutex::new(()))
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner)
}

struct EnvVarGuard {
    key: &'static str,
    original: Option<OsString>,
}

impl EnvVarGuard {
    fn set(key: &'static str, value: Option<&str>) -> Self {
        let original = std::env::var_os(key);
        match value {
            Some(value) => std::env::set_var(key, value),
            None => std::env::remove_var(key),
        }
        Self { key, original }
    }
}

impl Drop for EnvVarGuard {
    fn drop(&mut self) {
        match &self.original {
            Some(value) => std::env::set_var(self.key, value),
            None => std::env::remove_var(self.key),
        }
    }
}

#[test]
fn proxy_config_from_env_reads_uppercase_proxy_vars() {
    // given
    let _lock = env_lock();
    let _http = EnvVarGuard::set("HTTP_PROXY", Some("http://proxy.corp:3128"));
    let _https = EnvVarGuard::set("HTTPS_PROXY", Some("http://secure.corp:3129"));
    let _no = EnvVarGuard::set("NO_PROXY", Some("localhost,127.0.0.1"));
    let _http_lower = EnvVarGuard::set("http_proxy", None);
    let _https_lower = EnvVarGuard::set("https_proxy", None);
    let _no_lower = EnvVarGuard::set("no_proxy", None);

    // when
    let config = ProxyConfig::from_env();

    // then
    assert_eq!(config.http_proxy.as_deref(), Some("http://proxy.corp:3128"));
    assert_eq!(
        config.https_proxy.as_deref(),
        Some("http://secure.corp:3129")
    );
    assert_eq!(config.no_proxy.as_deref(), Some("localhost,127.0.0.1"));
    assert!(config.proxy_url.is_none());
    assert!(!config.is_empty());
}

#[test]
fn proxy_config_from_env_reads_lowercase_proxy_vars() {
    // given
    let _lock = env_lock();
    let _http = EnvVarGuard::set("HTTP_PROXY", None);
    let _https = EnvVarGuard::set("HTTPS_PROXY", None);
    let _no = EnvVarGuard::set("NO_PROXY", None);
    let _http_lower = EnvVarGuard::set("http_proxy", Some("http://lower.corp:3128"));
    let _https_lower = EnvVarGuard::set("https_proxy", Some("http://lower-secure.corp:3129"));
    let _no_lower = EnvVarGuard::set("no_proxy", Some(".internal"));

    // when
    let config = ProxyConfig::from_env();

    // then
    assert_eq!(config.http_proxy.as_deref(), Some("http://lower.corp:3128"));
    assert_eq!(
        config.https_proxy.as_deref(),
        Some("http://lower-secure.corp:3129")
    );
    assert_eq!(config.no_proxy.as_deref(), Some(".internal"));
    assert!(!config.is_empty());
}

#[test]
fn proxy_config_from_env_is_empty_when_no_vars_set() {
    // given
    let _lock = env_lock();
    let _http = EnvVarGuard::set("HTTP_PROXY", None);
    let _https = EnvVarGuard::set("HTTPS_PROXY", None);
    let _no = EnvVarGuard::set("NO_PROXY", None);
    let _http_lower = EnvVarGuard::set("http_proxy", None);
    let _https_lower = EnvVarGuard::set("https_proxy", None);
    let _no_lower = EnvVarGuard::set("no_proxy", None);

    // when
    let config = ProxyConfig::from_env();

    // then
    assert!(config.is_empty());
    assert!(config.http_proxy.is_none());
    assert!(config.https_proxy.is_none());
    assert!(config.no_proxy.is_none());
}

#[test]
fn proxy_config_from_env_treats_empty_values_as_unset() {
    // given
    let _lock = env_lock();
    let _http = EnvVarGuard::set("HTTP_PROXY", Some(""));
    let _https = EnvVarGuard::set("HTTPS_PROXY", Some(""));
    let _http_lower = EnvVarGuard::set("http_proxy", Some(""));
    let _https_lower = EnvVarGuard::set("https_proxy", Some(""));
    let _no = EnvVarGuard::set("NO_PROXY", Some(""));
    let _no_lower = EnvVarGuard::set("no_proxy", Some(""));

    // when
    let config = ProxyConfig::from_env();

    // then
    assert!(config.is_empty());
}

#[test]
fn build_client_with_env_proxy_config_succeeds() {
    // given
    let _lock = env_lock();
    let _http = EnvVarGuard::set("HTTP_PROXY", Some("http://proxy.corp:3128"));
    let _https = EnvVarGuard::set("HTTPS_PROXY", Some("http://secure.corp:3129"));
    let _no = EnvVarGuard::set("NO_PROXY", Some("localhost"));
    let _http_lower = EnvVarGuard::set("http_proxy", None);
    let _https_lower = EnvVarGuard::set("https_proxy", None);
    let _no_lower = EnvVarGuard::set("no_proxy", None);
    let config = ProxyConfig::from_env();

    // when
    let result = build_http_client_with(&config);

    // then
    assert!(result.is_ok());
}

#[test]
fn build_client_with_proxy_url_config_succeeds() {
    // given
    let config = ProxyConfig::from_proxy_url("http://unified.corp:3128");

    // when
    let result = build_http_client_with(&config);

    // then
    assert!(result.is_ok());
}

#[test]
fn proxy_config_from_env_prefers_uppercase_over_lowercase() {
    // given
    let _lock = env_lock();
    let _http_upper = EnvVarGuard::set("HTTP_PROXY", Some("http://upper.corp:3128"));
    let _http_lower = EnvVarGuard::set("http_proxy", Some("http://lower.corp:3128"));
    let _https = EnvVarGuard::set("HTTPS_PROXY", None);
    let _https_lower = EnvVarGuard::set("https_proxy", None);
    let _no = EnvVarGuard::set("NO_PROXY", None);
    let _no_lower = EnvVarGuard::set("no_proxy", None);

    // when
    let config = ProxyConfig::from_env();

    // then
    assert_eq!(config.http_proxy.as_deref(), Some("http://upper.corp:3128"));
}
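The proxy tests above pin down one lookup rule: the uppercase variable wins, and an empty value counts as unset. Since `ProxyConfig`'s internals are not shown in this diff, here is a hypothetical sketch of that rule over an in-memory map instead of the real process environment:

```rust
use std::collections::BTreeMap;

// Hypothetical helper: prefer the uppercase key, fall back to the
// lowercase one, and treat empty values as unset.
fn proxy_var(env: &BTreeMap<&str, &str>, upper: &str, lower: &str) -> Option<String> {
    let read = |key: &str| {
        env.get(key)
            .filter(|value| !value.is_empty())
            .map(|value| (*value).to_string())
    };
    read(upper).or_else(|| read(lower))
}

fn main() {
    let mut env = BTreeMap::new();
    env.insert("HTTP_PROXY", "http://upper.example:3128");
    env.insert("http_proxy", "http://lower.example:3128");
    assert_eq!(
        proxy_var(&env, "HTTP_PROXY", "http_proxy").as_deref(),
        Some("http://upper.example:3128")
    );
    // An empty uppercase value falls through to the lowercase one.
    env.insert("HTTP_PROXY", "");
    assert_eq!(
        proxy_var(&env, "HTTP_PROXY", "http_proxy").as_deref(),
        Some("http://lower.example:3128")
    );
    println!("precedence ok");
}
```

Filtering empty strings before the `or_else` fallback is what makes `HTTP_PROXY=""` behave like an unset variable, matching `proxy_config_from_env_treats_empty_values_as_unset`.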
@@ -221,8 +221,10 @@ const SLASH_COMMAND_SPECS: &[SlashCommandSpec] = &[
|
||||
SlashCommandSpec {
|
||||
name: "session",
|
||||
aliases: &[],
|
||||
summary: "List, switch, or fork managed local sessions",
|
||||
argument_hint: Some("[list|switch <session-id>|fork [branch-name]]"),
|
||||
summary: "List, switch, fork, or delete managed local sessions",
|
||||
argument_hint: Some(
|
||||
"[list|switch <session-id>|fork [branch-name]|delete <session-id> [--force]]",
|
||||
),
|
||||
resume_supported: false,
|
||||
},
|
||||
SlashCommandSpec {
|
||||
@@ -1188,6 +1190,9 @@ pub enum SlashCommand {
|
||||
AddDir {
|
||||
path: Option<String>,
|
||||
},
|
||||
History {
|
||||
count: Option<String>,
|
||||
},
|
||||
Unknown(String),
|
||||
}
|
||||
|
||||
@@ -1421,6 +1426,9 @@ pub fn validate_slash_command_input(
|
||||
"tag" => SlashCommand::Tag { label: remainder },
|
||||
"output-style" => SlashCommand::OutputStyle { style: remainder },
|
||||
"add-dir" => SlashCommand::AddDir { path: remainder },
|
||||
"history" => SlashCommand::History {
|
||||
count: optional_single_arg(command, &args, "[count]")?,
|
||||
},
|
||||
other => SlashCommand::Unknown(other.to_string()),
|
||||
}))
|
||||
}
|
||||
@@ -1520,7 +1528,7 @@ fn parse_session_command(args: &[&str]) -> Result<SlashCommand, SlashCommandPars
|
||||
action: Some("list".to_string()),
|
||||
target: None,
|
||||
}),
|
||||
["list", ..] => Err(usage_error("session", "[list|switch <session-id>|fork [branch-name]]")),
|
||||
["list", ..] => Err(usage_error("session", "[list|switch <session-id>|fork [branch-name]|delete <session-id> [--force]]")),
|
||||
["switch"] => Err(usage_error("session switch", "<session-id>")),
|
||||
["switch", target] => Ok(SlashCommand::Session {
|
||||
action: Some("switch".to_string()),
|
||||
@@ -1544,12 +1552,33 @@ fn parse_session_command(args: &[&str]) -> Result<SlashCommand, SlashCommandPars
|
||||
"session",
|
||||
"/session fork [branch-name]",
|
||||
)),
|
||||
[action, ..] => Err(command_error(
|
||||
["delete"] => Err(usage_error("session delete", "<session-id> [--force]")),
|
||||
["delete", target] => Ok(SlashCommand::Session {
|
||||
action: Some("delete".to_string()),
|
||||
target: Some((*target).to_string()),
|
||||
}),
|
||||
["delete", target, "--force"] => Ok(SlashCommand::Session {
|
||||
action: Some("delete-force".to_string()),
|
||||
target: Some((*target).to_string()),
|
||||
}),
|
||||
["delete", _target, unexpected] => Err(command_error(
|
||||
&format!(
|
||||
"Unknown /session action '{action}'. Use list, switch <session-id>, or fork [branch-name]."
|
||||
"Unsupported /session delete flag '{unexpected}'. Use --force to skip confirmation."
|
||||
),
|
||||
"session",
|
||||
"/session [list|switch <session-id>|fork [branch-name]]",
|
||||
"/session delete <session-id> [--force]",
|
||||
)),
|
||||
["delete", ..] => Err(command_error(
|
||||
"Unexpected arguments for /session delete.",
|
||||
"session",
|
||||
"/session delete <session-id> [--force]",
|
||||
)),
|
||||
[action, ..] => Err(command_error(
|
||||
&format!(
|
||||
"Unknown /session action '{action}'. Use list, switch <session-id>, fork [branch-name], or delete <session-id> [--force]."
|
||||
),
|
||||
"session",
|
||||
"/session [list|switch <session-id>|fork [branch-name]|delete <session-id> [--force]]",
|
||||
)),
|
||||
}
|
||||
}
|
||||
@@ -1786,24 +1815,29 @@ pub fn resume_supported_slash_commands() -> Vec<&'static SlashCommandSpec> {

 fn slash_command_category(name: &str) -> &'static str {
     match name {
-        "help" | "status" | "sandbox" | "model" | "permissions" | "cost" | "resume" | "session"
-        | "version" | "login" | "logout" | "usage" | "stats" | "rename" | "privacy-settings" => {
-            "Session & visibility"
-        }
-        "compact" | "clear" | "config" | "memory" | "init" | "diff" | "commit" | "pr" | "issue"
-        | "export" | "plugin" | "branch" | "add-dir" | "files" | "hooks" | "release-notes" => {
-            "Workspace & git"
-        }
-        "agents" | "skills" | "teleport" | "debug-tool-call" | "mcp" | "context" | "tasks"
-        | "doctor" | "ide" | "desktop" => "Discovery & debugging",
-        "bughunter" | "ultraplan" | "review" | "security-review" | "advisor" | "insights" => {
-            "Analysis & automation"
-        }
-        "theme" | "vim" | "voice" | "color" | "effort" | "fast" | "brief" | "output-style"
-        | "keybindings" | "stickers" => "Appearance & input",
-        "copy" | "share" | "feedback" | "summary" | "tag" | "thinkback" | "plan" | "exit"
-        | "upgrade" | "rewind" => "Communication & control",
-        _ => "Other",
+        "help" | "status" | "cost" | "resume" | "session" | "version" | "login" | "logout"
+        | "usage" | "stats" | "rename" | "clear" | "compact" | "history" | "tokens" | "cache"
+        | "exit" | "summary" | "tag" | "thinkback" | "copy" | "share" | "feedback" | "rewind"
+        | "pin" | "unpin" | "bookmarks" | "context" | "files" | "focus" | "unfocus" | "retry"
+        | "stop" | "undo" => "Session",
+        "diff" | "commit" | "pr" | "issue" | "branch" | "blame" | "log" | "git" | "stash"
+        | "init" | "export" | "plan" | "review" | "security-review" | "bughunter" | "ultraplan"
+        | "teleport" | "refactor" | "fix" | "autofix" | "explain" | "docs" | "perf" | "search"
+        | "references" | "definition" | "hover" | "symbols" | "map" | "web" | "image"
+        | "screenshot" | "paste" | "listen" | "speak" | "test" | "lint" | "build" | "run"
+        | "format" | "parallel" | "multi" | "macro" | "alias" | "templates" | "migrate"
+        | "benchmark" | "cron" | "agent" | "subagent" | "agents" | "skills" | "team" | "plugin"
+        | "mcp" | "hooks" | "tasks" | "advisor" | "insights" | "release-notes" | "chat"
+        | "approve" | "deny" | "allowed-tools" | "add-dir" => "Tools",
+        "model" | "permissions" | "config" | "memory" | "theme" | "vim" | "voice" | "color"
+        | "effort" | "fast" | "brief" | "output-style" | "keybindings" | "privacy-settings"
+        | "stickers" | "language" | "profile" | "max-tokens" | "temperature" | "system-prompt"
+        | "api-key" | "terminal-setup" | "notifications" | "telemetry" | "providers" | "env"
+        | "project" | "reasoning" | "budget" | "rate-limit" | "workspace" | "reset" | "ide"
+        | "desktop" | "upgrade" => "Config",
+        "debug-tool-call" | "doctor" | "sandbox" | "diagnostics" | "tool-details" | "changelog"
+        | "metrics" => "Debug",
+        _ => "Tools",
     }
 }
@@ -1912,12 +1946,7 @@ pub fn render_slash_command_help() -> String {
         String::new(),
     ];

-    let categories = [
-        "Session & visibility",
-        "Workspace & git",
-        "Discovery & debugging",
-        "Analysis & automation",
-    ];
+    let categories = ["Session", "Tools", "Config", "Debug"];

     for category in categories {
         lines.push(category.to_string());
@@ -1930,6 +1959,12 @@ pub fn render_slash_command_help() -> String {
         lines.push(String::new());
     }

+    lines.push("Keyboard shortcuts".to_string());
+    lines.push(" Up/Down Navigate prompt history".to_string());
+    lines.push(" Tab Complete commands, modes, and recent sessions".to_string());
+    lines.push(" Ctrl-C Clear input (or exit on empty prompt)".to_string());
+    lines.push(" Shift+Enter/Ctrl+J Insert a newline".to_string());
+
     lines
         .into_iter()
         .rev()
@@ -2314,8 +2349,7 @@ pub fn resolve_skill_invocation(
         .unwrap_or_default();
     if !skill_token.is_empty() {
         if let Err(error) = resolve_skill_path(cwd, skill_token) {
-            let mut message =
-                format!("Unknown skill: {skill_token} ({error})");
+            let mut message = format!("Unknown skill: {skill_token} ({error})");
             let roots = discover_skill_roots(cwd);
             if let Ok(available) = load_skills_from_roots(&roots) {
                 let names: Vec<String> = available
@@ -2324,15 +2358,10 @@ pub fn resolve_skill_invocation(
                     .map(|s| s.name.clone())
                     .collect();
                 if !names.is_empty() {
-                    message.push_str(&format!(
-                        "\n Available skills: {}",
-                        names.join(", ")
-                    ));
+                    message.push_str(&format!("\n Available skills: {}", names.join(", ")));
                 }
             }
-            message.push_str(
-                "\n Usage: /skills [list|install <path>|help|<skill> [args]]",
-            );
+            message.push_str("\n Usage: /skills [list|install <path>|help|<skill> [args]]");
             return Err(message);
         }
     }
@@ -3942,6 +3971,7 @@ pub fn handle_slash_command(
         | SlashCommand::Tag { .. }
         | SlashCommand::OutputStyle { .. }
         | SlashCommand::AddDir { .. }
+        | SlashCommand::History { .. }
         | SlashCommand::Unknown(_) => None,
     }
 }
@@ -4256,6 +4286,47 @@ mod tests {
         );
     }

+    #[test]
+    fn parses_history_command_without_count() {
+        // given
+        let input = "/history";
+
+        // when
+        let parsed = SlashCommand::parse(input);
+
+        // then
+        assert_eq!(parsed, Ok(Some(SlashCommand::History { count: None })));
+    }
+
+    #[test]
+    fn parses_history_command_with_numeric_count() {
+        // given
+        let input = "/history 25";
+
+        // when
+        let parsed = SlashCommand::parse(input);
+
+        // then
+        assert_eq!(
+            parsed,
+            Ok(Some(SlashCommand::History {
+                count: Some("25".to_string())
+            }))
+        );
+    }
+
+    #[test]
+    fn rejects_history_with_extra_arguments() {
+        // given
+        let input = "/history 25 extra";
+
+        // when
+        let error = parse_error_message(input);
+
+        // then
+        assert!(error.contains("Usage: /history [count]"));
+    }
+
     #[test]
     fn rejects_unexpected_arguments_for_no_arg_commands() {
         // given
@@ -4297,7 +4368,7 @@ mod tests {

         // then
         assert!(error.contains("Usage: /teleport <symbol-or-path>"));
-        assert!(error.contains(" Category Discovery & debugging"));
+        assert!(error.contains(" Category Tools"));
     }

     #[test]
@@ -4371,10 +4442,10 @@ mod tests {
         let help = render_slash_command_help();
         assert!(help.contains("Start here /status, /diff, /agents, /skills, /commit"));
         assert!(help.contains("[resume] also works with --resume SESSION.jsonl"));
-        assert!(help.contains("Session & visibility"));
-        assert!(help.contains("Workspace & git"));
-        assert!(help.contains("Discovery & debugging"));
-        assert!(help.contains("Analysis & automation"));
+        assert!(help.contains("Session"));
+        assert!(help.contains("Tools"));
+        assert!(help.contains("Config"));
+        assert!(help.contains("Debug"));
         assert!(help.contains("/help"));
         assert!(help.contains("/status"));
         assert!(help.contains("/sandbox"));
@@ -4398,7 +4469,7 @@ mod tests {
         assert!(help.contains("/diff"));
         assert!(help.contains("/version"));
         assert!(help.contains("/export [file]"));
-        assert!(help.contains("/session [list|switch <session-id>|fork [branch-name]]"));
+        assert!(help.contains("/session"), "help must mention /session");
         assert!(help.contains("/sandbox"));
         assert!(help.contains(
             "/plugin [list|install <path>|enable <name>|disable <name>|uninstall <id>|update <id>]"
@@ -4411,6 +4482,53 @@ mod tests {
         assert!(resume_supported_slash_commands().len() >= 39);
     }

+    #[test]
+    fn renders_help_with_grouped_categories_and_keyboard_shortcuts() {
+        // given
+        let categories = ["Session", "Tools", "Config", "Debug"];
+
+        // when
+        let help = render_slash_command_help();
+
+        // then
+        for category in categories {
+            assert!(
+                help.contains(category),
+                "expected help to contain category {category}"
+            );
+        }
+        let session_index = help.find("Session").expect("Session header should exist");
+        let tools_index = help.find("Tools").expect("Tools header should exist");
+        let config_index = help.find("Config").expect("Config header should exist");
+        let debug_index = help.find("Debug").expect("Debug header should exist");
+        assert!(session_index < tools_index);
+        assert!(tools_index < config_index);
+        assert!(config_index < debug_index);
+
+        assert!(help.contains("Keyboard shortcuts"));
+        assert!(help.contains("Up/Down Navigate prompt history"));
+        assert!(help.contains("Tab Complete commands, modes, and recent sessions"));
+        assert!(help.contains("Ctrl-C Clear input (or exit on empty prompt)"));
+        assert!(help.contains("Shift+Enter/Ctrl+J Insert a newline"));
+
+        // every command should still render with a summary line
+        for spec in slash_command_specs() {
+            let usage = match spec.argument_hint {
+                Some(hint) => format!("/{} {hint}", spec.name),
+                None => format!("/{}", spec.name),
+            };
+            assert!(
+                help.contains(&usage),
+                "expected help to contain command {usage}"
+            );
+            assert!(
+                help.contains(spec.summary),
+                "expected help to contain summary for /{}",
+                spec.name
+            );
+        }
+    }
+
     #[test]
     fn renders_per_command_help_detail() {
         // given
@@ -4423,7 +4541,7 @@ mod tests {
         assert!(help.contains("/plugin"));
         assert!(help.contains("Summary Manage Claw Code plugins"));
         assert!(help.contains("Aliases /plugins, /marketplace"));
-        assert!(help.contains("Category Workspace & git"));
+        assert!(help.contains("Category Tools"));
     }

     #[test]
@@ -4431,7 +4549,7 @@ mod tests {
         let help = render_slash_command_help_detail("mcp").expect("detail help should exist");
         assert!(help.contains("/mcp"));
         assert!(help.contains("Summary Inspect configured MCP servers"));
-        assert!(help.contains("Category Discovery & debugging"));
+        assert!(help.contains("Category Tools"));
         assert!(help.contains("Resume Supported with --resume SESSION.jsonl"));
     }

@@ -13,7 +13,7 @@ regex = "1"
 serde = { version = "1", features = ["derive"] }
 serde_json.workspace = true
 telemetry = { path = "../telemetry" }
-tokio = { version = "1", features = ["io-util", "macros", "process", "rt", "rt-multi-thread", "time"] }
+tokio = { version = "1", features = ["io-std", "io-util", "macros", "process", "rt", "rt-multi-thread", "time"] }
 walkdir = "2"

 [lints]

@@ -48,6 +48,7 @@ pub struct RuntimePluginConfig {
     install_root: Option<String>,
     registry_path: Option<String>,
     bundled_root: Option<String>,
+    max_output_tokens: Option<u32>,
 }

 /// Structured feature configuration consumed by runtime subsystems.
@@ -58,9 +59,21 @@ pub struct RuntimeFeatureConfig {
     mcp: McpConfigCollection,
     oauth: Option<OAuthConfig>,
     model: Option<String>,
+    aliases: BTreeMap<String, String>,
     permission_mode: Option<ResolvedPermissionMode>,
     permission_rules: RuntimePermissionRuleConfig,
     sandbox: SandboxConfig,
+    provider_fallbacks: ProviderFallbackConfig,
+    trusted_roots: Vec<String>,
 }

+/// Ordered chain of fallback model identifiers used when the primary
+/// provider returns a retryable failure (429/500/503/etc.). The chain is
+/// strict: each entry is tried in order until one succeeds.
+#[derive(Debug, Clone, PartialEq, Eq, Default)]
+pub struct ProviderFallbackConfig {
+    primary: Option<String>,
+    fallbacks: Vec<String>,
+}
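
The new fields map onto `settings.json` keys exercised by the tests later in this diff. A minimal sketch of the merged settings shape (key names taken from those tests; the values are illustrative only):

```json
{
  "aliases": { "fast": "claude-haiku-4-5-20251213" },
  "providerFallbacks": {
    "primary": "claude-opus-4-6",
    "fallbacks": ["grok-3", "grok-3-mini"]
  },
  "trustedRoots": ["/tmp/worktrees"]
}
```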

 /// Hook command lists grouped by lifecycle stage.
@@ -259,17 +272,33 @@ impl ConfigLoader {
         let mut merged = BTreeMap::new();
         let mut loaded_entries = Vec::new();
         let mut mcp_servers = BTreeMap::new();
+        let mut all_warnings = Vec::new();

         for entry in self.discover() {
-            let Some(value) = read_optional_json_object(&entry.path)? else {
+            crate::config_validate::check_unsupported_format(&entry.path)?;
+            let Some(parsed) = read_optional_json_object(&entry.path)? else {
                 continue;
             };
-            validate_optional_hooks_config(&value, &entry.path)?;
-            merge_mcp_servers(&mut mcp_servers, entry.source, &value, &entry.path)?;
-            deep_merge_objects(&mut merged, &value);
+            let validation = crate::config_validate::validate_config_file(
+                &parsed.object,
+                &parsed.source,
+                &entry.path,
+            );
+            if !validation.is_ok() {
+                let first_error = &validation.errors[0];
+                return Err(ConfigError::Parse(first_error.to_string()));
+            }
+            all_warnings.extend(validation.warnings);
+            validate_optional_hooks_config(&parsed.object, &entry.path)?;
+            merge_mcp_servers(&mut mcp_servers, entry.source, &parsed.object, &entry.path)?;
+            deep_merge_objects(&mut merged, &parsed.object);
             loaded_entries.push(entry);
         }

+        for warning in &all_warnings {
+            eprintln!("warning: {warning}");
+        }
+
         let merged_value = JsonValue::Object(merged.clone());

         let feature_config = RuntimeFeatureConfig {
@@ -280,9 +309,12 @@ impl ConfigLoader {
             },
             oauth: parse_optional_oauth_config(&merged_value, "merged settings.oauth")?,
             model: parse_optional_model(&merged_value),
+            aliases: parse_optional_aliases(&merged_value)?,
             permission_mode: parse_optional_permission_mode(&merged_value)?,
             permission_rules: parse_optional_permission_rules(&merged_value)?,
             sandbox: parse_optional_sandbox_config(&merged_value)?,
+            provider_fallbacks: parse_optional_provider_fallbacks(&merged_value)?,
+            trusted_roots: parse_optional_trusted_roots(&merged_value)?,
         };

         Ok(RuntimeConfig {
@@ -353,6 +385,11 @@ impl RuntimeConfig {
         self.feature_config.model.as_deref()
     }

+    #[must_use]
+    pub fn aliases(&self) -> &BTreeMap<String, String> {
+        &self.feature_config.aliases
+    }
+
     #[must_use]
     pub fn permission_mode(&self) -> Option<ResolvedPermissionMode> {
         self.feature_config.permission_mode
@@ -367,6 +404,16 @@ impl RuntimeConfig {
     pub fn sandbox(&self) -> &SandboxConfig {
         &self.feature_config.sandbox
     }

+    #[must_use]
+    pub fn provider_fallbacks(&self) -> &ProviderFallbackConfig {
+        &self.feature_config.provider_fallbacks
+    }
+
+    #[must_use]
+    pub fn trusted_roots(&self) -> &[String] {
+        &self.feature_config.trusted_roots
+    }
 }

 impl RuntimeFeatureConfig {
@@ -407,6 +454,11 @@ impl RuntimeFeatureConfig {
         self.model.as_deref()
     }

+    #[must_use]
+    pub fn aliases(&self) -> &BTreeMap<String, String> {
+        &self.aliases
+    }
+
     #[must_use]
     pub fn permission_mode(&self) -> Option<ResolvedPermissionMode> {
         self.permission_mode
@@ -421,6 +473,38 @@ impl RuntimeFeatureConfig {
     pub fn sandbox(&self) -> &SandboxConfig {
         &self.sandbox
     }

+    #[must_use]
+    pub fn provider_fallbacks(&self) -> &ProviderFallbackConfig {
+        &self.provider_fallbacks
+    }
+
+    #[must_use]
+    pub fn trusted_roots(&self) -> &[String] {
+        &self.trusted_roots
+    }
 }

+impl ProviderFallbackConfig {
+    #[must_use]
+    pub fn new(primary: Option<String>, fallbacks: Vec<String>) -> Self {
+        Self { primary, fallbacks }
+    }
+
+    #[must_use]
+    pub fn primary(&self) -> Option<&str> {
+        self.primary.as_deref()
+    }
+
+    #[must_use]
+    pub fn fallbacks(&self) -> &[String] {
+        &self.fallbacks
+    }
+
+    #[must_use]
+    pub fn is_empty(&self) -> bool {
+        self.fallbacks.is_empty()
+    }
+}
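
A caller walking the strict chain (primary first, then each fallback in order) might look like the following minimal sketch; the `first_available` helper and the `attempt` closure are hypothetical illustrations, not the crate's actual retry code:

```rust
/// Hypothetical helper: try the primary model, then each fallback in order,
/// and return the first identifier the `attempt` closure accepts (e.g. a
/// request that did not come back with a retryable 429/500/503).
fn first_available(
    primary: Option<&str>,
    fallbacks: &[&str],
    attempt: impl Fn(&str) -> bool,
) -> Option<String> {
    primary
        .into_iter()
        .chain(fallbacks.iter().copied())
        .find(|m| attempt(m))
        .map(str::to_string)
}

fn main() {
    // Pretend the primary fails with a retryable error and the first fallback succeeds.
    let chosen = first_available(
        Some("claude-opus-4-6"),
        &["grok-3", "grok-3-mini"],
        |model| model != "claude-opus-4-6",
    );
    assert_eq!(chosen.as_deref(), Some("grok-3"));
}
```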

 impl RuntimePluginConfig {
@@ -449,6 +533,15 @@ impl RuntimePluginConfig {
         self.bundled_root.as_deref()
     }

+    #[must_use]
+    pub fn max_output_tokens(&self) -> Option<u32> {
+        self.max_output_tokens
+    }
+
+    pub fn set_max_output_tokens(&mut self, max_output_tokens: Option<u32>) {
+        self.max_output_tokens = max_output_tokens;
+    }
+
     pub fn set_plugin_state(&mut self, plugin_id: String, enabled: bool) {
         self.enabled_plugins.insert(plugin_id, enabled);
     }
@@ -572,9 +665,13 @@ impl McpServerConfig {
     }
 }

-fn read_optional_json_object(
-    path: &Path,
-) -> Result<Option<BTreeMap<String, JsonValue>>, ConfigError> {
+/// Parsed JSON object paired with its raw source text for validation.
+struct ParsedConfigFile {
+    object: BTreeMap<String, JsonValue>,
+    source: String,
+}
+
+fn read_optional_json_object(path: &Path) -> Result<Option<ParsedConfigFile>, ConfigError> {
     let is_legacy_config = path.file_name().and_then(|name| name.to_str()) == Some(".claw.json");
     let contents = match fs::read_to_string(path) {
         Ok(contents) => contents,
@@ -583,7 +680,10 @@ fn read_optional_json_object(
     };

     if contents.trim().is_empty() {
-        return Ok(Some(BTreeMap::new()));
+        return Ok(Some(ParsedConfigFile {
+            object: BTreeMap::new(),
+            source: contents,
+        }));
     }

     let parsed = match JsonValue::parse(&contents) {
@@ -600,7 +700,10 @@ fn read_optional_json_object(
             path.display()
         )));
     };
-    Ok(Some(object.clone()))
+    Ok(Some(ParsedConfigFile {
+        object: object.clone(),
+        source: contents,
+    }))
 }

 fn merge_mcp_servers(
@@ -637,6 +740,13 @@ fn parse_optional_model(root: &JsonValue) -> Option<String> {
         .map(ToOwned::to_owned)
 }

+fn parse_optional_aliases(root: &JsonValue) -> Result<BTreeMap<String, String>, ConfigError> {
+    let Some(object) = root.as_object() else {
+        return Ok(BTreeMap::new());
+    };
+    Ok(optional_string_map(object, "aliases", "merged settings")?.unwrap_or_default())
+}
+
 fn parse_optional_hooks_config(root: &JsonValue) -> Result<RuntimeHookConfig, ConfigError> {
     let Some(object) = root.as_object() else {
         return Ok(RuntimeHookConfig::default());
@@ -714,6 +824,7 @@ fn parse_optional_plugin_config(root: &JsonValue) -> Result<RuntimePluginConfig,
         optional_string(plugins, "registryPath", "merged settings.plugins")?.map(str::to_string);
     config.bundled_root =
         optional_string(plugins, "bundledRoot", "merged settings.plugins")?.map(str::to_string);
+    config.max_output_tokens = optional_u32(plugins, "maxOutputTokens", "merged settings.plugins")?;
     Ok(config)
 }

@@ -776,6 +887,33 @@ fn parse_optional_sandbox_config(root: &JsonValue) -> Result<SandboxConfig, Conf
     })
 }

+fn parse_optional_provider_fallbacks(
+    root: &JsonValue,
+) -> Result<ProviderFallbackConfig, ConfigError> {
+    let Some(object) = root.as_object() else {
+        return Ok(ProviderFallbackConfig::default());
+    };
+    let Some(value) = object.get("providerFallbacks") else {
+        return Ok(ProviderFallbackConfig::default());
+    };
+    let entry = expect_object(value, "merged settings.providerFallbacks")?;
+    let primary =
+        optional_string(entry, "primary", "merged settings.providerFallbacks")?.map(str::to_string);
+    let fallbacks = optional_string_array(entry, "fallbacks", "merged settings.providerFallbacks")?
+        .unwrap_or_default();
+    Ok(ProviderFallbackConfig { primary, fallbacks })
+}
+
+fn parse_optional_trusted_roots(root: &JsonValue) -> Result<Vec<String>, ConfigError> {
+    let Some(object) = root.as_object() else {
+        return Ok(Vec::new());
+    };
+    Ok(
+        optional_string_array(object, "trustedRoots", "merged settings.trustedRoots")?
+            .unwrap_or_default(),
+    )
+}
+
 fn parse_filesystem_mode_label(value: &str) -> Result<FilesystemIsolationMode, ConfigError> {
     match value {
         "off" => Ok(FilesystemIsolationMode::Off),
@@ -957,6 +1095,27 @@ fn optional_u16(
     }
 }

+fn optional_u32(
+    object: &BTreeMap<String, JsonValue>,
+    key: &str,
+    context: &str,
+) -> Result<Option<u32>, ConfigError> {
+    match object.get(key) {
+        Some(value) => {
+            let Some(number) = value.as_i64() else {
+                return Err(ConfigError::Parse(format!(
+                    "{context}: field {key} must be a non-negative integer"
+                )));
+            };
+            let number = u32::try_from(number).map_err(|_| {
+                ConfigError::Parse(format!("{context}: field {key} is out of range"))
+            })?;
+            Ok(Some(number))
+        }
+        None => Ok(None),
+    }
+}
+
 fn optional_u64(
     object: &BTreeMap<String, JsonValue>,
     key: &str,
@@ -1247,6 +1406,113 @@ mod tests {
         fs::remove_dir_all(root).expect("cleanup temp dir");
     }

+    #[test]
+    fn parses_provider_fallbacks_chain_with_primary_and_ordered_fallbacks() {
+        // given
+        let root = temp_dir();
+        let cwd = root.join("project");
+        let home = root.join("home").join(".claw");
+        fs::create_dir_all(cwd.join(".claw")).expect("project config dir");
+        fs::create_dir_all(&home).expect("home config dir");
+        fs::write(
+            home.join("settings.json"),
+            r#"{
+                "providerFallbacks": {
+                    "primary": "claude-opus-4-6",
+                    "fallbacks": ["grok-3", "grok-3-mini"]
+                }
+            }"#,
+        )
+        .expect("write provider fallback settings");
+
+        // when
+        let loaded = ConfigLoader::new(&cwd, &home)
+            .load()
+            .expect("config should load");
+
+        // then
+        let chain = loaded.provider_fallbacks();
+        assert_eq!(chain.primary(), Some("claude-opus-4-6"));
+        assert_eq!(
+            chain.fallbacks(),
+            &["grok-3".to_string(), "grok-3-mini".to_string()]
+        );
+        assert!(!chain.is_empty());
+
+        fs::remove_dir_all(root).expect("cleanup temp dir");
+    }
+
+    #[test]
+    fn provider_fallbacks_default_is_empty_when_unset() {
+        // given
+        let root = temp_dir();
+        let cwd = root.join("project");
+        let home = root.join("home").join(".claw");
+        fs::create_dir_all(&home).expect("home config dir");
+        fs::create_dir_all(&cwd).expect("project dir");
+        fs::write(home.join("settings.json"), "{}").expect("write empty settings");
+
+        // when
+        let loaded = ConfigLoader::new(&cwd, &home)
+            .load()
+            .expect("config should load");
+
+        // then
+        let chain = loaded.provider_fallbacks();
+        assert_eq!(chain.primary(), None);
+        assert!(chain.fallbacks().is_empty());
+        assert!(chain.is_empty());
+
+        fs::remove_dir_all(root).expect("cleanup temp dir");
+    }
+
+    #[test]
+    fn parses_trusted_roots_from_settings() {
+        // given
+        let root = temp_dir();
+        let cwd = root.join("project");
+        let home = root.join("home").join(".claw");
+        fs::create_dir_all(&home).expect("home config dir");
+        fs::create_dir_all(&cwd).expect("project dir");
+        fs::write(
+            home.join("settings.json"),
+            r#"{"trustedRoots": ["/tmp/worktrees", "/home/user/projects"]}"#,
+        )
+        .expect("write settings");
+
+        // when
+        let loaded = ConfigLoader::new(&cwd, &home)
+            .load()
+            .expect("config should load");
+
+        // then
+        let roots = loaded.trusted_roots();
+        assert_eq!(roots, ["/tmp/worktrees", "/home/user/projects"]);
+
+        fs::remove_dir_all(root).expect("cleanup temp dir");
+    }
+
+    #[test]
+    fn trusted_roots_default_is_empty_when_unset() {
+        // given
+        let root = temp_dir();
+        let cwd = root.join("project");
+        let home = root.join("home").join(".claw");
+        fs::create_dir_all(&home).expect("home config dir");
+        fs::create_dir_all(&cwd).expect("project dir");
+        fs::write(home.join("settings.json"), "{}").expect("write empty settings");
+
+        // when
+        let loaded = ConfigLoader::new(&cwd, &home)
+            .load()
+            .expect("config should load");
+
+        // then
+        assert!(loaded.trusted_roots().is_empty());
+
+        fs::remove_dir_all(root).expect("cleanup temp dir");
+    }
+
     #[test]
     fn parses_typed_mcp_and_oauth_config() {
         let root = temp_dir();
@@ -1493,6 +1759,49 @@ mod tests {
         fs::remove_dir_all(root).expect("cleanup temp dir");
     }

+    #[test]
+    fn parses_user_defined_model_aliases_from_settings() {
+        // given
+        let root = temp_dir();
+        let cwd = root.join("project");
+        let home = root.join("home").join(".claw");
+        fs::create_dir_all(cwd.join(".claw")).expect("project config dir");
+        fs::create_dir_all(&home).expect("home config dir");
+
+        fs::write(
+            home.join("settings.json"),
+            r#"{"aliases":{"fast":"claude-haiku-4-5-20251213","smart":"claude-opus-4-6"}}"#,
+        )
+        .expect("write user settings");
+        fs::write(
+            cwd.join(".claw").join("settings.local.json"),
+            r#"{"aliases":{"smart":"claude-sonnet-4-6","cheap":"grok-3-mini"}}"#,
+        )
+        .expect("write local settings");
+
+        // when
+        let loaded = ConfigLoader::new(&cwd, &home)
+            .load()
+            .expect("config should load");
+
+        // then
+        let aliases = loaded.aliases();
+        assert_eq!(
+            aliases.get("fast").map(String::as_str),
+            Some("claude-haiku-4-5-20251213")
+        );
+        assert_eq!(
+            aliases.get("smart").map(String::as_str),
+            Some("claude-sonnet-4-6")
+        );
+        assert_eq!(
+            aliases.get("cheap").map(String::as_str),
+            Some("grok-3-mini")
+        );
+
+        fs::remove_dir_all(root).expect("cleanup temp dir");
+    }
+
     #[test]
     fn empty_settings_file_loads_defaults() {
         // given
@@ -1574,12 +1883,13 @@ mod tests {
             .load()
             .expect_err("config should fail");

-        // then
+        // then — config validation now catches the mixed array before the hooks parser
         let rendered = error.to_string();
-        assert!(rendered.contains(&format!(
-            "{}: hooks: field PreToolUse must contain only strings",
-            project_settings.display()
-        )));
+        assert!(
+            rendered.contains("hooks.PreToolUse")
+                && rendered.contains("must be an array of strings"),
+            "expected validation error for hooks.PreToolUse, got: {rendered}"
+        );
+        assert!(!rendered.contains("merged settings.hooks"));

         fs::remove_dir_all(root).expect("cleanup temp dir");
@@ -1645,4 +1955,157 @@ mod tests {
         assert!(config.state_for("missing", true));
         assert!(!config.state_for("missing", false));
     }

+    #[test]
+    fn validates_unknown_top_level_keys_with_line_and_field_name() {
+        // given
+        let root = temp_dir();
+        let cwd = root.join("project");
+        let home = root.join("home").join(".claw");
+        let user_settings = home.join("settings.json");
+        fs::create_dir_all(&home).expect("home config dir");
+        fs::create_dir_all(&cwd).expect("project dir");
+        fs::write(
+            &user_settings,
+            "{\n \"model\": \"opus\",\n \"telemetry\": true\n}\n",
+        )
+        .expect("write user settings");
+
+        // when
+        let error = ConfigLoader::new(&cwd, &home)
+            .load()
+            .expect_err("config should fail");
+
+        // then
+        let rendered = error.to_string();
+        assert!(
+            rendered.contains(&user_settings.display().to_string()),
+            "error should include file path, got: {rendered}"
+        );
+        assert!(
+            rendered.contains("line 3"),
+            "error should include line number, got: {rendered}"
+        );
+        assert!(
+            rendered.contains("telemetry"),
+            "error should name the offending field, got: {rendered}"
+        );
+
+        fs::remove_dir_all(root).expect("cleanup temp dir");
+    }
+
+    #[test]
+    fn validates_deprecated_top_level_keys_with_replacement_guidance() {
+        // given
+        let root = temp_dir();
+        let cwd = root.join("project");
+        let home = root.join("home").join(".claw");
+        let user_settings = home.join("settings.json");
+        fs::create_dir_all(&home).expect("home config dir");
+        fs::create_dir_all(&cwd).expect("project dir");
+        fs::write(
+            &user_settings,
+            "{\n \"model\": \"opus\",\n \"allowedTools\": [\"Read\"]\n}\n",
+        )
+        .expect("write user settings");
+
+        // when
+        let error = ConfigLoader::new(&cwd, &home)
+            .load()
+            .expect_err("config should fail");
+
+        // then
+        let rendered = error.to_string();
+        assert!(
+            rendered.contains(&user_settings.display().to_string()),
+            "error should include file path, got: {rendered}"
+        );
+        assert!(
+            rendered.contains("line 3"),
+            "error should include line number, got: {rendered}"
+        );
+        // allowedTools is an unknown key; the validator should name it in the error
+        assert!(
+            rendered.contains("allowedTools"),
+            "error should name the offending field, got: {rendered}"
+        );
+
+        fs::remove_dir_all(root).expect("cleanup temp dir");
+    }
+
+    #[test]
+    fn validates_wrong_type_for_known_field_with_field_path() {
+        // given
+        let root = temp_dir();
+        let cwd = root.join("project");
+        let home = root.join("home").join(".claw");
+        let user_settings = home.join("settings.json");
+        fs::create_dir_all(&home).expect("home config dir");
+        fs::create_dir_all(&cwd).expect("project dir");
+        fs::write(
+            &user_settings,
+            "{\n \"hooks\": {\n \"PreToolUse\": \"not-an-array\"\n }\n}\n",
+        )
+        .expect("write user settings");
+
+        // when
+        let error = ConfigLoader::new(&cwd, &home)
+            .load()
+            .expect_err("config should fail");
+
+        // then
+        let rendered = error.to_string();
+        assert!(
+            rendered.contains(&user_settings.display().to_string()),
+            "error should include file path, got: {rendered}"
+        );
+        assert!(
+            rendered.contains("hooks"),
+            "error should include field path component 'hooks', got: {rendered}"
+        );
+        assert!(
+            rendered.contains("PreToolUse"),
+            "error should describe the type mismatch, got: {rendered}"
+        );
+        assert!(
+            rendered.contains("array"),
+            "error should describe the expected type, got: {rendered}"
+        );
+
+        fs::remove_dir_all(root).expect("cleanup temp dir");
+    }
+
+    #[test]
+    fn unknown_top_level_key_suggests_closest_match() {
+        // given
+        let root = temp_dir();
+        let cwd = root.join("project");
+        let home = root.join("home").join(".claw");
+        let user_settings = home.join("settings.json");
+        fs::create_dir_all(&home).expect("home config dir");
+        fs::create_dir_all(&cwd).expect("project dir");
+        fs::write(&user_settings, "{\n \"modle\": \"opus\"\n}\n").expect("write user settings");
+
+        // when
+        let error = ConfigLoader::new(&cwd, &home)
+            .load()
+            .expect_err("config should fail");
+
+        // then
+        let rendered = error.to_string();
+        assert!(
+            rendered.contains("modle"),
+            "error should name the offending field, got: {rendered}"
+        );
+        assert!(
+            rendered.contains("model"),
+            "error should suggest the closest known key, got: {rendered}"
+        );
+
+        fs::remove_dir_all(root).expect("cleanup temp dir");
+    }
 }
|
||||
901
rust/crates/runtime/src/config_validate.rs
Normal file
@@ -0,0 +1,901 @@
use std::collections::BTreeMap;
use std::path::Path;

use crate::config::ConfigError;
use crate::json::JsonValue;

/// Diagnostic emitted when a config file contains a suspect field.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ConfigDiagnostic {
    pub path: String,
    pub field: String,
    pub line: Option<usize>,
    pub kind: DiagnosticKind,
}

/// Classification of the diagnostic.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DiagnosticKind {
    UnknownKey {
        suggestion: Option<String>,
    },
    WrongType {
        expected: &'static str,
        got: &'static str,
    },
    Deprecated {
        replacement: &'static str,
    },
}

impl std::fmt::Display for ConfigDiagnostic {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let location = self
            .line
            .map_or_else(String::new, |line| format!(" (line {line})"));
        match &self.kind {
            DiagnosticKind::UnknownKey { suggestion: None } => {
                write!(f, "{}: unknown key \"{}\"{location}", self.path, self.field)
            }
            DiagnosticKind::UnknownKey {
                suggestion: Some(hint),
            } => {
                write!(
                    f,
                    "{}: unknown key \"{}\"{location}. Did you mean \"{}\"?",
                    self.path, self.field, hint
                )
            }
            DiagnosticKind::WrongType { expected, got } => {
                write!(
                    f,
                    "{}: field \"{}\" must be {expected}, got {got}{location}",
                    self.path, self.field
                )
            }
            DiagnosticKind::Deprecated { replacement } => {
                write!(
                    f,
                    "{}: field \"{}\" is deprecated{location}. Use \"{replacement}\" instead",
                    self.path, self.field
                )
            }
        }
    }
}

/// Result of validating a single config file.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ValidationResult {
    pub errors: Vec<ConfigDiagnostic>,
    pub warnings: Vec<ConfigDiagnostic>,
}

impl ValidationResult {
    #[must_use]
    pub fn is_ok(&self) -> bool {
        self.errors.is_empty()
    }

    fn merge(&mut self, other: Self) {
        self.errors.extend(other.errors);
        self.warnings.extend(other.warnings);
    }
}

// ---- known-key schema ----

/// Expected type for a config field.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum FieldType {
    String,
    Bool,
    Object,
    StringArray,
    Number,
}

impl FieldType {
    fn label(self) -> &'static str {
        match self {
            Self::String => "a string",
            Self::Bool => "a boolean",
            Self::Object => "an object",
            Self::StringArray => "an array of strings",
            Self::Number => "a number",
        }
    }

    fn matches(self, value: &JsonValue) -> bool {
        match self {
            Self::String => value.as_str().is_some(),
            Self::Bool => value.as_bool().is_some(),
            Self::Object => value.as_object().is_some(),
            Self::StringArray => value
                .as_array()
                .is_some_and(|arr| arr.iter().all(|v| v.as_str().is_some())),
            Self::Number => value.as_i64().is_some(),
        }
    }
}

fn json_type_label(value: &JsonValue) -> &'static str {
    match value {
        JsonValue::Null => "null",
        JsonValue::Bool(_) => "a boolean",
        JsonValue::Number(_) => "a number",
        JsonValue::String(_) => "a string",
        JsonValue::Array(_) => "an array",
        JsonValue::Object(_) => "an object",
    }
}

struct FieldSpec {
    name: &'static str,
    expected: FieldType,
}

struct DeprecatedField {
    name: &'static str,
    replacement: &'static str,
}

const TOP_LEVEL_FIELDS: &[FieldSpec] = &[
    FieldSpec {
        name: "$schema",
        expected: FieldType::String,
    },
    FieldSpec {
        name: "model",
        expected: FieldType::String,
    },
    FieldSpec {
        name: "hooks",
        expected: FieldType::Object,
    },
    FieldSpec {
        name: "permissions",
        expected: FieldType::Object,
    },
    FieldSpec {
        name: "permissionMode",
        expected: FieldType::String,
    },
    FieldSpec {
        name: "mcpServers",
        expected: FieldType::Object,
    },
    FieldSpec {
        name: "oauth",
        expected: FieldType::Object,
    },
    FieldSpec {
        name: "enabledPlugins",
        expected: FieldType::Object,
    },
    FieldSpec {
        name: "plugins",
        expected: FieldType::Object,
    },
    FieldSpec {
        name: "sandbox",
        expected: FieldType::Object,
    },
    FieldSpec {
        name: "env",
        expected: FieldType::Object,
    },
    FieldSpec {
        name: "aliases",
        expected: FieldType::Object,
    },
    FieldSpec {
        name: "providerFallbacks",
        expected: FieldType::Object,
    },
    FieldSpec {
        name: "trustedRoots",
        expected: FieldType::StringArray,
    },
];

const HOOKS_FIELDS: &[FieldSpec] = &[
    FieldSpec {
        name: "PreToolUse",
        expected: FieldType::StringArray,
    },
    FieldSpec {
        name: "PostToolUse",
        expected: FieldType::StringArray,
    },
    FieldSpec {
        name: "PostToolUseFailure",
        expected: FieldType::StringArray,
    },
];

const PERMISSIONS_FIELDS: &[FieldSpec] = &[
    FieldSpec {
        name: "defaultMode",
        expected: FieldType::String,
    },
    FieldSpec {
        name: "allow",
        expected: FieldType::StringArray,
    },
    FieldSpec {
        name: "deny",
        expected: FieldType::StringArray,
    },
    FieldSpec {
        name: "ask",
        expected: FieldType::StringArray,
    },
];

const PLUGINS_FIELDS: &[FieldSpec] = &[
    FieldSpec {
        name: "enabled",
        expected: FieldType::Object,
    },
    FieldSpec {
        name: "externalDirectories",
        expected: FieldType::StringArray,
    },
    FieldSpec {
        name: "installRoot",
        expected: FieldType::String,
    },
    FieldSpec {
        name: "registryPath",
        expected: FieldType::String,
    },
    FieldSpec {
        name: "bundledRoot",
        expected: FieldType::String,
    },
    FieldSpec {
        name: "maxOutputTokens",
        expected: FieldType::Number,
    },
];

const SANDBOX_FIELDS: &[FieldSpec] = &[
    FieldSpec {
        name: "enabled",
        expected: FieldType::Bool,
    },
    FieldSpec {
        name: "namespaceRestrictions",
        expected: FieldType::Bool,
    },
    FieldSpec {
        name: "networkIsolation",
        expected: FieldType::Bool,
    },
    FieldSpec {
        name: "filesystemMode",
        expected: FieldType::String,
    },
    FieldSpec {
        name: "allowedMounts",
        expected: FieldType::StringArray,
    },
];

const OAUTH_FIELDS: &[FieldSpec] = &[
    FieldSpec {
        name: "clientId",
        expected: FieldType::String,
    },
    FieldSpec {
        name: "authorizeUrl",
        expected: FieldType::String,
    },
    FieldSpec {
        name: "tokenUrl",
        expected: FieldType::String,
    },
    FieldSpec {
        name: "callbackPort",
        expected: FieldType::Number,
    },
    FieldSpec {
        name: "manualRedirectUrl",
        expected: FieldType::String,
    },
    FieldSpec {
        name: "scopes",
        expected: FieldType::StringArray,
    },
];

const DEPRECATED_FIELDS: &[DeprecatedField] = &[
    DeprecatedField {
        name: "permissionMode",
        replacement: "permissions.defaultMode",
    },
    DeprecatedField {
        name: "enabledPlugins",
        replacement: "plugins.enabled",
    },
];

// ---- line-number resolution ----

/// Find the 1-based line number where a JSON key first appears in the raw source.
fn find_key_line(source: &str, key: &str) -> Option<usize> {
    // Search for `"key"` followed by optional whitespace and a colon.
    let needle = format!("\"{key}\"");
    let mut search_start = 0;
    while let Some(offset) = source[search_start..].find(&needle) {
        let absolute = search_start + offset;
        let after = absolute + needle.len();
        // Verify the next non-whitespace char is `:` to confirm this is a key, not a value.
        if source[after..].chars().find(|ch| !ch.is_ascii_whitespace()) == Some(':') {
            return Some(source[..absolute].chars().filter(|&ch| ch == '\n').count() + 1);
        }
        search_start = after;
    }
    None
}

// ---- core validation ----

fn validate_object_keys(
    object: &BTreeMap<String, JsonValue>,
    known_fields: &[FieldSpec],
    prefix: &str,
    source: &str,
    path_display: &str,
) -> ValidationResult {
    let mut result = ValidationResult {
        errors: Vec::new(),
        warnings: Vec::new(),
    };

    let known_names: Vec<&str> = known_fields.iter().map(|f| f.name).collect();

    for (key, value) in object {
        let field_path = if prefix.is_empty() {
            key.clone()
        } else {
            format!("{prefix}.{key}")
        };

        if let Some(spec) = known_fields.iter().find(|f| f.name == key) {
            // Type check.
            if !spec.expected.matches(value) {
                result.errors.push(ConfigDiagnostic {
                    path: path_display.to_string(),
                    field: field_path,
                    line: find_key_line(source, key),
                    kind: DiagnosticKind::WrongType {
                        expected: spec.expected.label(),
                        got: json_type_label(value),
                    },
                });
            }
        } else if DEPRECATED_FIELDS.iter().any(|d| d.name == key) {
            // Deprecated key — handled separately, not an unknown-key error.
        } else {
            // Unknown key.
            let suggestion = suggest_field(key, &known_names);
            result.errors.push(ConfigDiagnostic {
                path: path_display.to_string(),
                field: field_path,
                line: find_key_line(source, key),
                kind: DiagnosticKind::UnknownKey { suggestion },
            });
        }
    }

    result
}

fn suggest_field(input: &str, candidates: &[&str]) -> Option<String> {
    let input_lower = input.to_ascii_lowercase();
    candidates
        .iter()
        .filter_map(|candidate| {
            let distance = simple_edit_distance(&input_lower, &candidate.to_ascii_lowercase());
            (distance <= 3).then_some((distance, *candidate))
        })
        .min_by_key(|(distance, _)| *distance)
        .map(|(_, name)| name.to_string())
}

fn simple_edit_distance(left: &str, right: &str) -> usize {
    if left.is_empty() {
        return right.len();
    }
    if right.is_empty() {
        return left.len();
    }
    let right_chars: Vec<char> = right.chars().collect();
    let mut previous: Vec<usize> = (0..=right_chars.len()).collect();
    let mut current = vec![0; right_chars.len() + 1];

    for (left_index, left_char) in left.chars().enumerate() {
        current[0] = left_index + 1;
        for (right_index, right_char) in right_chars.iter().enumerate() {
            let cost = usize::from(left_char != *right_char);
            current[right_index + 1] = (previous[right_index + 1] + 1)
                .min(current[right_index] + 1)
                .min(previous[right_index] + cost);
        }
        previous.clone_from(&current);
    }

    previous[right_chars.len()]
}

/// Validate a parsed config file's keys and types against the known schema.
///
/// Returns diagnostics (errors and deprecation warnings) without blocking the load.
pub fn validate_config_file(
    object: &BTreeMap<String, JsonValue>,
    source: &str,
    file_path: &Path,
) -> ValidationResult {
    let path_display = file_path.display().to_string();
    let mut result = validate_object_keys(object, TOP_LEVEL_FIELDS, "", source, &path_display);

    // Check deprecated fields.
    for deprecated in DEPRECATED_FIELDS {
        if object.contains_key(deprecated.name) {
            result.warnings.push(ConfigDiagnostic {
                path: path_display.clone(),
                field: deprecated.name.to_string(),
                line: find_key_line(source, deprecated.name),
                kind: DiagnosticKind::Deprecated {
                    replacement: deprecated.replacement,
                },
            });
        }
    }

    // Validate known nested objects.
    if let Some(hooks) = object.get("hooks").and_then(JsonValue::as_object) {
        result.merge(validate_object_keys(
            hooks,
            HOOKS_FIELDS,
            "hooks",
            source,
            &path_display,
        ));
    }
    if let Some(permissions) = object.get("permissions").and_then(JsonValue::as_object) {
        result.merge(validate_object_keys(
            permissions,
            PERMISSIONS_FIELDS,
            "permissions",
            source,
            &path_display,
        ));
    }
    if let Some(plugins) = object.get("plugins").and_then(JsonValue::as_object) {
        result.merge(validate_object_keys(
            plugins,
            PLUGINS_FIELDS,
            "plugins",
            source,
            &path_display,
        ));
    }
    if let Some(sandbox) = object.get("sandbox").and_then(JsonValue::as_object) {
        result.merge(validate_object_keys(
            sandbox,
            SANDBOX_FIELDS,
            "sandbox",
            source,
            &path_display,
        ));
    }
    if let Some(oauth) = object.get("oauth").and_then(JsonValue::as_object) {
        result.merge(validate_object_keys(
            oauth,
            OAUTH_FIELDS,
            "oauth",
            source,
            &path_display,
        ));
    }

    result
}

/// Check whether a file path uses an unsupported config format (e.g. TOML).
pub fn check_unsupported_format(file_path: &Path) -> Result<(), ConfigError> {
    if let Some(ext) = file_path.extension().and_then(|e| e.to_str()) {
        if ext.eq_ignore_ascii_case("toml") {
            return Err(ConfigError::Parse(format!(
                "{}: TOML config files are not supported. Use JSON (settings.json) instead",
                file_path.display()
            )));
        }
    }
    Ok(())
}

/// Format all diagnostics into a human-readable report.
#[must_use]
pub fn format_diagnostics(result: &ValidationResult) -> String {
    let mut lines = Vec::new();
    for warning in &result.warnings {
        lines.push(format!("warning: {warning}"));
    }
    for error in &result.errors {
        lines.push(format!("error: {error}"));
    }
    lines.join("\n")
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    fn test_path() -> PathBuf {
        PathBuf::from("/test/settings.json")
    }

    #[test]
    fn detects_unknown_top_level_key() {
        // given
        let source = r#"{"model": "opus", "unknownField": true}"#;
        let parsed = JsonValue::parse(source).expect("valid json");
        let object = parsed.as_object().expect("object");

        // when
        let result = validate_config_file(object, source, &test_path());

        // then
        assert_eq!(result.errors.len(), 1);
        assert_eq!(result.errors[0].field, "unknownField");
        assert!(matches!(
            result.errors[0].kind,
            DiagnosticKind::UnknownKey { .. }
        ));
    }

    #[test]
    fn detects_wrong_type_for_model() {
        // given
        let source = r#"{"model": 123}"#;
        let parsed = JsonValue::parse(source).expect("valid json");
        let object = parsed.as_object().expect("object");

        // when
        let result = validate_config_file(object, source, &test_path());

        // then
        assert_eq!(result.errors.len(), 1);
        assert_eq!(result.errors[0].field, "model");
        assert!(matches!(
            result.errors[0].kind,
            DiagnosticKind::WrongType {
                expected: "a string",
                got: "a number"
            }
        ));
    }

    #[test]
    fn detects_deprecated_permission_mode() {
        // given
        let source = r#"{"permissionMode": "plan"}"#;
        let parsed = JsonValue::parse(source).expect("valid json");
        let object = parsed.as_object().expect("object");

        // when
        let result = validate_config_file(object, source, &test_path());

        // then
        assert_eq!(result.warnings.len(), 1);
        assert_eq!(result.warnings[0].field, "permissionMode");
        assert!(matches!(
            result.warnings[0].kind,
            DiagnosticKind::Deprecated {
                replacement: "permissions.defaultMode"
            }
        ));
    }

    #[test]
    fn detects_deprecated_enabled_plugins() {
        // given
        let source = r#"{"enabledPlugins": {"tool-guard@builtin": true}}"#;
        let parsed = JsonValue::parse(source).expect("valid json");
        let object = parsed.as_object().expect("object");

        // when
        let result = validate_config_file(object, source, &test_path());

        // then
        assert_eq!(result.warnings.len(), 1);
        assert_eq!(result.warnings[0].field, "enabledPlugins");
        assert!(matches!(
            result.warnings[0].kind,
            DiagnosticKind::Deprecated {
                replacement: "plugins.enabled"
            }
        ));
    }

    #[test]
    fn reports_line_number_for_unknown_key() {
        // given
        let source = "{\n \"model\": \"opus\",\n \"badKey\": true\n}";
        let parsed = JsonValue::parse(source).expect("valid json");
        let object = parsed.as_object().expect("object");

        // when
        let result = validate_config_file(object, source, &test_path());

        // then
        assert_eq!(result.errors.len(), 1);
        assert_eq!(result.errors[0].line, Some(3));
        assert_eq!(result.errors[0].field, "badKey");
    }

    #[test]
    fn reports_line_number_for_wrong_type() {
        // given
        let source = "{\n \"model\": 42\n}";
        let parsed = JsonValue::parse(source).expect("valid json");
        let object = parsed.as_object().expect("object");

        // when
        let result = validate_config_file(object, source, &test_path());

        // then
        assert_eq!(result.errors.len(), 1);
        assert_eq!(result.errors[0].line, Some(2));
    }

    #[test]
    fn validates_nested_hooks_keys() {
        // given
        let source = r#"{"hooks": {"PreToolUse": ["cmd"], "BadHook": ["x"]}}"#;
        let parsed = JsonValue::parse(source).expect("valid json");
        let object = parsed.as_object().expect("object");

        // when
        let result = validate_config_file(object, source, &test_path());

        // then
        assert_eq!(result.errors.len(), 1);
        assert_eq!(result.errors[0].field, "hooks.BadHook");
    }

    #[test]
    fn validates_nested_permissions_keys() {
        // given
        let source = r#"{"permissions": {"allow": ["Read"], "denyAll": true}}"#;
        let parsed = JsonValue::parse(source).expect("valid json");
        let object = parsed.as_object().expect("object");

        // when
        let result = validate_config_file(object, source, &test_path());

        // then
        assert_eq!(result.errors.len(), 1);
        assert_eq!(result.errors[0].field, "permissions.denyAll");
    }

    #[test]
    fn validates_nested_sandbox_keys() {
        // given
        let source = r#"{"sandbox": {"enabled": true, "containerMode": "strict"}}"#;
        let parsed = JsonValue::parse(source).expect("valid json");
        let object = parsed.as_object().expect("object");

        // when
        let result = validate_config_file(object, source, &test_path());

        // then
        assert_eq!(result.errors.len(), 1);
        assert_eq!(result.errors[0].field, "sandbox.containerMode");
    }

    #[test]
    fn validates_nested_plugins_keys() {
        // given
        let source = r#"{"plugins": {"installRoot": "/tmp", "autoUpdate": true}}"#;
        let parsed = JsonValue::parse(source).expect("valid json");
        let object = parsed.as_object().expect("object");

        // when
        let result = validate_config_file(object, source, &test_path());

        // then
        assert_eq!(result.errors.len(), 1);
        assert_eq!(result.errors[0].field, "plugins.autoUpdate");
    }

    #[test]
    fn validates_nested_oauth_keys() {
        // given
        let source = r#"{"oauth": {"clientId": "abc", "secret": "hidden"}}"#;
        let parsed = JsonValue::parse(source).expect("valid json");
        let object = parsed.as_object().expect("object");

        // when
        let result = validate_config_file(object, source, &test_path());

        // then
        assert_eq!(result.errors.len(), 1);
        assert_eq!(result.errors[0].field, "oauth.secret");
    }

    #[test]
    fn valid_config_produces_no_diagnostics() {
        // given
        let source = r#"{
            "model": "opus",
            "hooks": {"PreToolUse": ["guard"]},
            "permissions": {"defaultMode": "plan", "allow": ["Read"]},
            "mcpServers": {},
            "sandbox": {"enabled": false}
        }"#;
        let parsed = JsonValue::parse(source).expect("valid json");
        let object = parsed.as_object().expect("object");

        // when
        let result = validate_config_file(object, source, &test_path());

        // then
        assert!(result.is_ok());
        assert!(result.warnings.is_empty());
    }

    #[test]
    fn suggests_close_field_name() {
        // given
        let source = r#"{"modle": "opus"}"#;
        let parsed = JsonValue::parse(source).expect("valid json");
        let object = parsed.as_object().expect("object");

        // when
        let result = validate_config_file(object, source, &test_path());

        // then
        assert_eq!(result.errors.len(), 1);
        match &result.errors[0].kind {
            DiagnosticKind::UnknownKey {
                suggestion: Some(s),
            } => assert_eq!(s, "model"),
            other => panic!("expected suggestion, got {other:?}"),
        }
    }

    #[test]
    fn format_diagnostics_includes_all_entries() {
        // given
        let source = r#"{"permissionMode": "plan", "badKey": 1}"#;
        let parsed = JsonValue::parse(source).expect("valid json");
        let object = parsed.as_object().expect("object");
        let result = validate_config_file(object, source, &test_path());

        // when
        let output = format_diagnostics(&result);

        // then
        assert!(output.contains("warning:"));
        assert!(output.contains("error:"));
        assert!(output.contains("badKey"));
        assert!(output.contains("permissionMode"));
    }

    #[test]
    fn check_unsupported_format_rejects_toml() {
        // given
        let path = PathBuf::from("/home/.claw/settings.toml");

        // when
        let result = check_unsupported_format(&path);

        // then
        assert!(result.is_err());
        let message = result.unwrap_err().to_string();
        assert!(message.contains("TOML"));
        assert!(message.contains("settings.toml"));
    }

    #[test]
    fn check_unsupported_format_allows_json() {
        // given
        let path = PathBuf::from("/home/.claw/settings.json");

        // when / then
        assert!(check_unsupported_format(&path).is_ok());
    }

    #[test]
    fn wrong_type_in_nested_sandbox_field() {
        // given
        let source = r#"{"sandbox": {"enabled": "yes"}}"#;
        let parsed = JsonValue::parse(source).expect("valid json");
        let object = parsed.as_object().expect("object");

        // when
        let result = validate_config_file(object, source, &test_path());

        // then
        assert_eq!(result.errors.len(), 1);
        assert_eq!(result.errors[0].field, "sandbox.enabled");
        assert!(matches!(
            result.errors[0].kind,
            DiagnosticKind::WrongType {
                expected: "a boolean",
                got: "a string"
            }
        ));
    }

    #[test]
    fn display_format_unknown_key_with_line() {
        // given
        let diag = ConfigDiagnostic {
            path: "/test/settings.json".to_string(),
            field: "badKey".to_string(),
            line: Some(5),
            kind: DiagnosticKind::UnknownKey { suggestion: None },
        };

        // when
        let output = diag.to_string();

        // then
        assert_eq!(
            output,
            r#"/test/settings.json: unknown key "badKey" (line 5)"#
        );
    }

    #[test]
    fn display_format_wrong_type_with_line() {
        // given
        let diag = ConfigDiagnostic {
            path: "/test/settings.json".to_string(),
            field: "model".to_string(),
            line: Some(2),
            kind: DiagnosticKind::WrongType {
                expected: "a string",
                got: "a number",
            },
        };

        // when
        let output = diag.to_string();

        // then
        assert_eq!(
            output,
            r#"/test/settings.json: field "model" must be a string, got a number (line 2)"#
        );
    }

    #[test]
    fn display_format_deprecated_with_line() {
        // given
        let diag = ConfigDiagnostic {
            path: "/test/settings.json".to_string(),
            field: "permissionMode".to_string(),
            line: Some(3),
            kind: DiagnosticKind::Deprecated {
                replacement: "permissions.defaultMode",
            },
        };

        // when
        let output = diag.to_string();

        // then
        assert_eq!(
            output,
            r#"/test/settings.json: field "permissionMode" is deprecated (line 3). Use "permissions.defaultMode" instead"#
        );
    }
}
@@ -504,6 +504,10 @@ where
        &self.session
    }

    pub fn session_mut(&mut self) -> &mut Session {
        &mut self.session
    }

    #[must_use]
    pub fn fork_session(&self, branch_name: Option<String>) -> Session {
        self.session.fork(branch_name)
@@ -890,6 +894,7 @@ mod tests {
            current_date: "2026-03-31".to_string(),
            git_status: None,
            git_diff: None,
            git_context: None,
            instruction_files: Vec::new(),
        })
        .with_os("linux", "6.8")
324
rust/crates/runtime/src/git_context.rs
Normal file
@@ -0,0 +1,324 @@
use std::path::Path;
use std::process::Command;

/// A single git commit entry from the log.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GitCommitEntry {
    pub hash: String,
    pub subject: String,
}

/// Git-aware context gathered at startup for injection into the system prompt.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GitContext {
    pub branch: Option<String>,
    pub recent_commits: Vec<GitCommitEntry>,
    pub staged_files: Vec<String>,
}

const MAX_RECENT_COMMITS: usize = 5;

impl GitContext {
    /// Detect the git context from the given working directory.
    ///
    /// Returns `None` when the directory is not inside a git repository.
    #[must_use]
    pub fn detect(cwd: &Path) -> Option<Self> {
        // Quick gate: is this a git repo at all?
        let rev_parse = Command::new("git")
            .args(["rev-parse", "--is-inside-work-tree"])
            .current_dir(cwd)
            .output()
            .ok()?;
        if !rev_parse.status.success() {
            return None;
        }

        Some(Self {
            branch: read_branch(cwd),
            recent_commits: read_recent_commits(cwd),
            staged_files: read_staged_files(cwd),
        })
    }

    /// Render a human-readable summary suitable for system-prompt injection.
    #[must_use]
    pub fn render(&self) -> String {
        let mut lines = Vec::new();

        if let Some(branch) = &self.branch {
            lines.push(format!("Git branch: {branch}"));
        }

        if !self.recent_commits.is_empty() {
            lines.push(String::new());
            lines.push("Recent commits:".to_string());
            for entry in &self.recent_commits {
                lines.push(format!(" {} {}", entry.hash, entry.subject));
            }
        }

        if !self.staged_files.is_empty() {
            lines.push(String::new());
            lines.push("Staged files:".to_string());
            for file in &self.staged_files {
                lines.push(format!(" {file}"));
            }
        }

        lines.join("\n")
    }
}

fn read_branch(cwd: &Path) -> Option<String> {
    let output = Command::new("git")
        .args(["rev-parse", "--abbrev-ref", "HEAD"])
        .current_dir(cwd)
        .output()
        .ok()?;
    if !output.status.success() {
        return None;
    }
    let branch = String::from_utf8(output.stdout).ok()?;
    let trimmed = branch.trim();
    if trimmed.is_empty() || trimmed == "HEAD" {
        None
    } else {
        Some(trimmed.to_string())
    }
}

fn read_recent_commits(cwd: &Path) -> Vec<GitCommitEntry> {
    let output = Command::new("git")
        .args([
            "--no-optional-locks",
            "log",
            "--oneline",
            "-n",
            &MAX_RECENT_COMMITS.to_string(),
            "--no-decorate",
        ])
        .current_dir(cwd)
        .output()
        .ok();
    let Some(output) = output else {
        return Vec::new();
    };
    if !output.status.success() {
        return Vec::new();
    }
    let stdout = String::from_utf8(output.stdout).unwrap_or_default();
    stdout
        .lines()
        .filter_map(|line| {
            let line = line.trim();
            if line.is_empty() {
                return None;
            }
            let (hash, subject) = line.split_once(' ')?;
            Some(GitCommitEntry {
                hash: hash.to_string(),
                subject: subject.to_string(),
            })
        })
        .collect()
}

fn read_staged_files(cwd: &Path) -> Vec<String> {
    let output = Command::new("git")
        .args(["--no-optional-locks", "diff", "--cached", "--name-only"])
        .current_dir(cwd)
        .output()
        .ok();
    let Some(output) = output else {
        return Vec::new();
    };
    if !output.status.success() {
        return Vec::new();
    }
    let stdout = String::from_utf8(output.stdout).unwrap_or_default();
    stdout
        .lines()
        .filter(|line| !line.trim().is_empty())
        .map(|line| line.trim().to_string())
        .collect()
}

#[cfg(test)]
mod tests {
    use super::{GitCommitEntry, GitContext};
    use std::fs;
    use std::process::Command;
    use std::time::{SystemTime, UNIX_EPOCH};

    fn temp_dir(label: &str) -> std::path::PathBuf {
        let nanos = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .expect("time should be after epoch")
            .as_nanos();
        std::env::temp_dir().join(format!("runtime-git-context-{label}-{nanos}"))
    }

    fn env_lock() -> std::sync::MutexGuard<'static, ()> {
        crate::test_env_lock()
    }

    fn ensure_valid_cwd() {
        if std::env::current_dir().is_err() {
            std::env::set_current_dir(env!("CARGO_MANIFEST_DIR"))
                .expect("test cwd should be recoverable");
        }
    }

    #[test]
    fn returns_none_for_non_git_directory() {
        // given
        let _guard = env_lock();
        ensure_valid_cwd();
        let root = temp_dir("non-git");
        fs::create_dir_all(&root).expect("create dir");

        // when
        let context = GitContext::detect(&root);

        // then
        assert!(context.is_none());
        fs::remove_dir_all(root).expect("cleanup");
    }

    #[test]
    fn detects_branch_name_and_commits() {
        // given
        let _guard = env_lock();
        ensure_valid_cwd();
        let root = temp_dir("branch-commits");
        fs::create_dir_all(&root).expect("create dir");
        git(&root, &["init", "--quiet", "--initial-branch=main"]);
        git(&root, &["config", "user.email", "tests@example.com"]);
        git(&root, &["config", "user.name", "Git Context Tests"]);
        fs::write(root.join("a.txt"), "a\n").expect("write a");
        git(&root, &["add", "a.txt"]);
        git(&root, &["commit", "-m", "first commit", "--quiet"]);
        fs::write(root.join("b.txt"), "b\n").expect("write b");
        git(&root, &["add", "b.txt"]);
        git(&root, &["commit", "-m", "second commit", "--quiet"]);

        // when
        let context = GitContext::detect(&root).expect("should detect git repo");

        // then
        assert_eq!(context.branch.as_deref(), Some("main"));
        assert_eq!(context.recent_commits.len(), 2);
        assert_eq!(context.recent_commits[0].subject, "second commit");
        assert_eq!(context.recent_commits[1].subject, "first commit");
        assert!(context.staged_files.is_empty());
        fs::remove_dir_all(root).expect("cleanup");
    }

    #[test]
    fn detects_staged_files() {
||||
// given
|
||||
let _guard = env_lock();
|
||||
ensure_valid_cwd();
|
||||
let root = temp_dir("staged");
|
||||
fs::create_dir_all(&root).expect("create dir");
|
||||
git(&root, &["init", "--quiet", "--initial-branch=main"]);
|
||||
git(&root, &["config", "user.email", "tests@example.com"]);
|
||||
git(&root, &["config", "user.name", "Git Context Tests"]);
|
||||
fs::write(root.join("init.txt"), "init\n").expect("write init");
|
||||
git(&root, &["add", "init.txt"]);
|
||||
git(&root, &["commit", "-m", "initial", "--quiet"]);
|
||||
fs::write(root.join("staged.txt"), "staged\n").expect("write staged");
|
||||
git(&root, &["add", "staged.txt"]);
|
||||
|
||||
// when
|
||||
let context = GitContext::detect(&root).expect("should detect git repo");
|
||||
|
||||
// then
|
||||
assert_eq!(context.staged_files, vec!["staged.txt"]);
|
||||
fs::remove_dir_all(root).expect("cleanup");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn render_formats_all_sections() {
|
||||
// given
|
||||
let context = GitContext {
|
||||
branch: Some("feat/test".to_string()),
|
||||
recent_commits: vec![
|
||||
GitCommitEntry {
|
||||
hash: "abc1234".to_string(),
|
||||
subject: "add feature".to_string(),
|
||||
},
|
||||
GitCommitEntry {
|
||||
hash: "def5678".to_string(),
|
||||
subject: "fix bug".to_string(),
|
||||
},
|
||||
],
|
||||
staged_files: vec!["src/main.rs".to_string()],
|
||||
};
|
||||
|
||||
// when
|
||||
let rendered = context.render();
|
||||
|
||||
// then
|
||||
assert!(rendered.contains("Git branch: feat/test"));
|
||||
assert!(rendered.contains("abc1234 add feature"));
|
||||
assert!(rendered.contains("def5678 fix bug"));
|
||||
assert!(rendered.contains("src/main.rs"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn render_omits_empty_sections() {
|
||||
// given
|
||||
let context = GitContext {
|
||||
branch: Some("main".to_string()),
|
||||
recent_commits: Vec::new(),
|
||||
staged_files: Vec::new(),
|
||||
};
|
||||
|
||||
// when
|
||||
let rendered = context.render();
|
||||
|
||||
// then
|
||||
assert!(rendered.contains("Git branch: main"));
|
||||
assert!(!rendered.contains("Recent commits:"));
|
||||
assert!(!rendered.contains("Staged files:"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn limits_to_five_recent_commits() {
|
||||
// given
|
||||
let _guard = env_lock();
|
||||
ensure_valid_cwd();
|
||||
let root = temp_dir("five-commits");
|
||||
fs::create_dir_all(&root).expect("create dir");
|
||||
git(&root, &["init", "--quiet", "--initial-branch=main"]);
|
||||
git(&root, &["config", "user.email", "tests@example.com"]);
|
||||
git(&root, &["config", "user.name", "Git Context Tests"]);
|
||||
for i in 1..=8 {
|
||||
let name = format!("file{i}.txt");
|
||||
fs::write(root.join(&name), format!("{i}\n")).expect("write file");
|
||||
git(&root, &["add", &name]);
|
||||
git(&root, &["commit", "-m", &format!("commit {i}"), "--quiet"]);
|
||||
}
|
||||
|
||||
// when
|
||||
let context = GitContext::detect(&root).expect("should detect git repo");
|
||||
|
||||
// then
|
||||
assert_eq!(context.recent_commits.len(), 5);
|
||||
assert_eq!(context.recent_commits[0].subject, "commit 8");
|
||||
assert_eq!(context.recent_commits[4].subject, "commit 4");
|
||||
fs::remove_dir_all(root).expect("cleanup");
|
||||
}
|
||||
|
||||
fn git(cwd: &std::path::Path, args: &[&str]) {
|
||||
let status = Command::new("git")
|
||||
.args(args)
|
||||
.current_dir(cwd)
|
||||
.output()
|
||||
.unwrap_or_else(|_| panic!("git {args:?} should run"))
|
||||
.status;
|
||||
assert!(status.success(), "git {args:?} failed");
|
||||
}
|
||||
}
|
||||
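`read_recent_commits` leans on the stable `git log --oneline` shape of `<short-hash><space><subject>` per line and splits on the first space only. A minimal standalone sketch of that parse (the sample hashes and subjects here are illustrative, not from the repository):

```rust
// Sketch: parse `git log --oneline`-style output the same way
// `read_recent_commits` does, via `split_once` on the first space,
// so subjects containing spaces stay intact.
fn main() {
    let stdout = "abc1234 add feature\ndef5678 fix bug\n";
    let commits: Vec<(&str, &str)> = stdout
        .lines()
        .filter_map(|line| line.trim().split_once(' '))
        .collect();
    assert_eq!(commits, vec![("abc1234", "add feature"), ("def5678", "fix bug")]);
}
```

Blank or malformed lines (no space) simply drop out through `filter_map`, matching the defensive behavior of the function above.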
@@ -10,8 +10,10 @@ mod bootstrap;
 pub mod branch_lock;
 mod compact;
 mod config;
+pub mod config_validate;
 mod conversation;
 mod file_ops;
+mod git_context;
 pub mod green_contract;
 mod hooks;
 mod json;
@@ -20,6 +22,7 @@ pub mod lsp_client;
 mod mcp;
 mod mcp_client;
 pub mod mcp_lifecycle_hardened;
+pub mod mcp_server;
 mod mcp_stdio;
 pub mod mcp_tool_bridge;
 mod oauth;
@@ -32,9 +35,10 @@ pub mod recovery_recipes;
 mod remote;
 pub mod sandbox;
 mod session;
-#[cfg(test)]
-mod session_control;
+pub mod session_control;
+pub use session_control::SessionStore;
 mod sse;
+pub mod stale_base;
 pub mod stale_branch;
 pub mod summary_compression;
 pub mod task_packet;
@@ -56,10 +60,14 @@ pub use config::{
     ConfigEntry, ConfigError, ConfigLoader, ConfigSource, McpConfigCollection,
     McpManagedProxyServerConfig, McpOAuthConfig, McpRemoteServerConfig, McpSdkServerConfig,
     McpServerConfig, McpStdioServerConfig, McpTransport, McpWebSocketServerConfig, OAuthConfig,
-    ResolvedPermissionMode, RuntimeConfig, RuntimeFeatureConfig, RuntimeHookConfig,
-    RuntimePermissionRuleConfig, RuntimePluginConfig, ScopedMcpServerConfig,
+    ProviderFallbackConfig, ResolvedPermissionMode, RuntimeConfig, RuntimeFeatureConfig,
+    RuntimeHookConfig, RuntimePermissionRuleConfig, RuntimePluginConfig, ScopedMcpServerConfig,
     CLAW_SETTINGS_SCHEMA_NAME,
 };
+pub use config_validate::{
+    check_unsupported_format, format_diagnostics, validate_config_file, ConfigDiagnostic,
+    DiagnosticKind, ValidationResult,
+};
 pub use conversation::{
     auto_compaction_threshold_from_env, ApiClient, ApiRequest, AssistantEvent, AutoCompactionEvent,
     ConversationRuntime, PromptCacheEvent, RuntimeError, StaticToolExecutor, ToolError,
@@ -70,6 +78,7 @@ pub use file_ops::{
     GrepSearchInput, GrepSearchOutput, ReadFileOutput, StructuredPatchHunk, TextFilePayload,
     WriteFileOutput,
 };
+pub use git_context::{GitCommitEntry, GitContext};
 pub use hooks::{
     HookAbortSignal, HookEvent, HookProgressEvent, HookProgressReporter, HookRunResult, HookRunner,
 };
@@ -89,6 +98,7 @@ pub use mcp_lifecycle_hardened::{
     McpDegradedReport, McpErrorSurface, McpFailedServer, McpLifecyclePhase, McpLifecycleState,
     McpLifecycleValidator, McpPhaseResult,
 };
+pub use mcp_server::{McpServer, McpServerSpec, ToolCallHandler, MCP_SERVER_PROTOCOL_VERSION};
 pub use mcp_stdio::{
     spawn_mcp_stdio_process, JsonRpcError, JsonRpcId, JsonRpcRequest, JsonRpcResponse,
     ManagedMcpTool, McpDiscoveryFailure, McpInitializeClientInfo, McpInitializeParams,
@@ -138,9 +148,13 @@ pub use sandbox::{
 };
 pub use session::{
     ContentBlock, ConversationMessage, MessageRole, Session, SessionCompaction, SessionError,
-    SessionFork,
+    SessionFork, SessionPromptEntry,
 };
 pub use sse::{IncrementalSseParser, SseEvent};
+pub use stale_base::{
+    check_base_commit, format_stale_base_warning, read_claw_base_file, resolve_expected_base,
+    BaseCommitSource, BaseCommitState,
+};
 pub use stale_branch::{
     apply_policy, check_freshness, BranchFreshness, StaleBranchAction, StaleBranchEvent,
     StaleBranchPolicy,
440
rust/crates/runtime/src/mcp_server.rs
Normal file
@@ -0,0 +1,440 @@
//! Minimal Model Context Protocol (MCP) server.
//!
//! Implements a newline-safe, LSP-framed JSON-RPC server over stdio that
//! answers `initialize`, `tools/list`, and `tools/call` requests. The framing
//! matches the client transport implemented in [`crate::mcp_stdio`] so this
//! server can be driven by either an external MCP client (e.g. Claude
//! Desktop) or `claw`'s own [`McpServerManager`](crate::McpServerManager).
//!
//! The server is intentionally small: it exposes a list of pre-built
//! [`McpTool`] descriptors and delegates `tools/call` to a caller-supplied
//! handler. Tool execution itself lives in the `tools` crate; this module is
//! purely the transport + dispatch loop.
//!
//! [`McpTool`]: crate::mcp_stdio::McpTool

use std::io;

use serde_json::{json, Value as JsonValue};
use tokio::io::{
    stdin, stdout, AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader, Stdin, Stdout,
};

use crate::mcp_stdio::{
    JsonRpcError, JsonRpcId, JsonRpcRequest, JsonRpcResponse, McpInitializeResult,
    McpInitializeServerInfo, McpListToolsResult, McpTool, McpToolCallContent, McpToolCallParams,
    McpToolCallResult,
};

/// Protocol version the server advertises during `initialize`.
///
/// Matches the version used by the built-in client in
/// [`crate::mcp_stdio`], so the two stay in lockstep.
pub const MCP_SERVER_PROTOCOL_VERSION: &str = "2025-03-26";

/// Synchronous handler invoked for every `tools/call` request.
///
/// Returning `Ok(text)` yields a single `text` content block and
/// `isError: false`. Returning `Err(message)` yields a `text` block with the
/// error and `isError: true`, mirroring the error-surfacing convention used
/// elsewhere in claw.
pub type ToolCallHandler =
    Box<dyn Fn(&str, &JsonValue) -> Result<String, String> + Send + Sync + 'static>;

/// Configuration for an [`McpServer`] instance.
///
/// Named `McpServerSpec` rather than `McpServerConfig` to avoid colliding
/// with the existing client-side [`crate::config::McpServerConfig`] that
/// describes *remote* MCP servers the runtime connects to.
pub struct McpServerSpec {
    /// Name advertised in the `serverInfo` field of the `initialize` response.
    pub server_name: String,
    /// Version advertised in the `serverInfo` field of the `initialize`
    /// response.
    pub server_version: String,
    /// Tool descriptors returned for `tools/list`.
    pub tools: Vec<McpTool>,
    /// Handler invoked for `tools/call`.
    pub tool_handler: ToolCallHandler,
}

/// Minimal MCP stdio server.
///
/// The server runs a blocking read/dispatch/write loop over the current
/// process's stdin/stdout, terminating cleanly when the peer closes the
/// stream.
pub struct McpServer {
    spec: McpServerSpec,
    stdin: BufReader<Stdin>,
    stdout: Stdout,
}

impl McpServer {
    #[must_use]
    pub fn new(spec: McpServerSpec) -> Self {
        Self {
            spec,
            stdin: BufReader::new(stdin()),
            stdout: stdout(),
        }
    }

    /// Runs the server until the client closes stdin.
    ///
    /// Returns `Ok(())` on clean EOF; any other I/O error is propagated so
    /// callers can log and exit non-zero.
    pub async fn run(&mut self) -> io::Result<()> {
        loop {
            let Some(payload) = read_frame(&mut self.stdin).await? else {
                return Ok(());
            };

            // Requests and notifications share a wire format; the absence of
            // `id` distinguishes notifications, which must never receive a
            // response.
            let message: JsonValue = match serde_json::from_slice(&payload) {
                Ok(value) => value,
                Err(error) => {
                    // Parse error with null id per JSON-RPC 2.0 §4.2.
                    let response = JsonRpcResponse::<JsonValue> {
                        jsonrpc: "2.0".to_string(),
                        id: JsonRpcId::Null,
                        result: None,
                        error: Some(JsonRpcError {
                            code: -32700,
                            message: format!("parse error: {error}"),
                            data: None,
                        }),
                    };
                    write_response(&mut self.stdout, &response).await?;
                    continue;
                }
            };

            if message.get("id").is_none() {
                // Notification: dispatch for side effects only (e.g. log),
                // but send no reply.
                continue;
            }

            let request: JsonRpcRequest<JsonValue> = match serde_json::from_value(message) {
                Ok(request) => request,
                Err(error) => {
                    let response = JsonRpcResponse::<JsonValue> {
                        jsonrpc: "2.0".to_string(),
                        id: JsonRpcId::Null,
                        result: None,
                        error: Some(JsonRpcError {
                            code: -32600,
                            message: format!("invalid request: {error}"),
                            data: None,
                        }),
                    };
                    write_response(&mut self.stdout, &response).await?;
                    continue;
                }
            };

            let response = self.dispatch(request);
            write_response(&mut self.stdout, &response).await?;
        }
    }

    fn dispatch(&self, request: JsonRpcRequest<JsonValue>) -> JsonRpcResponse<JsonValue> {
        let id = request.id.clone();
        match request.method.as_str() {
            "initialize" => self.handle_initialize(id),
            "tools/list" => self.handle_tools_list(id),
            "tools/call" => self.handle_tools_call(id, request.params),
            other => JsonRpcResponse {
                jsonrpc: "2.0".to_string(),
                id,
                result: None,
                error: Some(JsonRpcError {
                    code: -32601,
                    message: format!("method not found: {other}"),
                    data: None,
                }),
            },
        }
    }

    fn handle_initialize(&self, id: JsonRpcId) -> JsonRpcResponse<JsonValue> {
        let result = McpInitializeResult {
            protocol_version: MCP_SERVER_PROTOCOL_VERSION.to_string(),
            capabilities: json!({ "tools": {} }),
            server_info: McpInitializeServerInfo {
                name: self.spec.server_name.clone(),
                version: self.spec.server_version.clone(),
            },
        };
        JsonRpcResponse {
            jsonrpc: "2.0".to_string(),
            id,
            result: serde_json::to_value(result).ok(),
            error: None,
        }
    }

    fn handle_tools_list(&self, id: JsonRpcId) -> JsonRpcResponse<JsonValue> {
        let result = McpListToolsResult {
            tools: self.spec.tools.clone(),
            next_cursor: None,
        };
        JsonRpcResponse {
            jsonrpc: "2.0".to_string(),
            id,
            result: serde_json::to_value(result).ok(),
            error: None,
        }
    }

    fn handle_tools_call(
        &self,
        id: JsonRpcId,
        params: Option<JsonValue>,
    ) -> JsonRpcResponse<JsonValue> {
        let Some(params) = params else {
            return invalid_params_response(id, "missing params for tools/call");
        };
        let call: McpToolCallParams = match serde_json::from_value(params) {
            Ok(value) => value,
            Err(error) => {
                return invalid_params_response(id, &format!("invalid tools/call params: {error}"));
            }
        };
        let arguments = call.arguments.unwrap_or_else(|| json!({}));
        let tool_result = (self.spec.tool_handler)(&call.name, &arguments);
        let (text, is_error) = match tool_result {
            Ok(text) => (text, false),
            Err(message) => (message, true),
        };
        let mut data = std::collections::BTreeMap::new();
        data.insert("text".to_string(), JsonValue::String(text));
        let call_result = McpToolCallResult {
            content: vec![McpToolCallContent {
                kind: "text".to_string(),
                data,
            }],
            structured_content: None,
            is_error: Some(is_error),
            meta: None,
        };
        JsonRpcResponse {
            jsonrpc: "2.0".to_string(),
            id,
            result: serde_json::to_value(call_result).ok(),
            error: None,
        }
    }
}

fn invalid_params_response(id: JsonRpcId, message: &str) -> JsonRpcResponse<JsonValue> {
    JsonRpcResponse {
        jsonrpc: "2.0".to_string(),
        id,
        result: None,
        error: Some(JsonRpcError {
            code: -32602,
            message: message.to_string(),
            data: None,
        }),
    }
}

/// Reads a single LSP-framed JSON-RPC payload from `reader`.
///
/// Returns `Ok(None)` on clean EOF before any header bytes have been read,
/// matching how [`crate::mcp_stdio::McpStdioProcess`] treats stream closure.
async fn read_frame(reader: &mut BufReader<Stdin>) -> io::Result<Option<Vec<u8>>> {
    let mut content_length: Option<usize> = None;
    let mut first_header = true;
    loop {
        let mut line = String::new();
        let bytes_read = reader.read_line(&mut line).await?;
        if bytes_read == 0 {
            if first_header {
                return Ok(None);
            }
            return Err(io::Error::new(
                io::ErrorKind::UnexpectedEof,
                "MCP stdio stream closed while reading headers",
            ));
        }
        first_header = false;
        if line == "\r\n" || line == "\n" {
            break;
        }
        let header = line.trim_end_matches(['\r', '\n']);
        if let Some((name, value)) = header.split_once(':') {
            if name.trim().eq_ignore_ascii_case("Content-Length") {
                let parsed = value
                    .trim()
                    .parse::<usize>()
                    .map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))?;
                content_length = Some(parsed);
            }
        }
    }

    let content_length = content_length.ok_or_else(|| {
        io::Error::new(io::ErrorKind::InvalidData, "missing Content-Length header")
    })?;
    let mut payload = vec![0_u8; content_length];
    reader.read_exact(&mut payload).await?;
    Ok(Some(payload))
}

async fn write_response(
    stdout: &mut Stdout,
    response: &JsonRpcResponse<JsonValue>,
) -> io::Result<()> {
    let body = serde_json::to_vec(response)
        .map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))?;
    let header = format!("Content-Length: {}\r\n\r\n", body.len());
    stdout.write_all(header.as_bytes()).await?;
    stdout.write_all(&body).await?;
    stdout.flush().await
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn dispatch_initialize_returns_server_info() {
        let server = McpServer {
            spec: McpServerSpec {
                server_name: "test".to_string(),
                server_version: "9.9.9".to_string(),
                tools: Vec::new(),
                tool_handler: Box::new(|_, _| Ok(String::new())),
            },
            stdin: BufReader::new(stdin()),
            stdout: stdout(),
        };
        let request = JsonRpcRequest::<JsonValue> {
            jsonrpc: "2.0".to_string(),
            id: JsonRpcId::Number(1),
            method: "initialize".to_string(),
            params: None,
        };
        let response = server.dispatch(request);
        assert_eq!(response.id, JsonRpcId::Number(1));
        assert!(response.error.is_none());
        let result = response.result.expect("initialize result");
        assert_eq!(result["protocolVersion"], MCP_SERVER_PROTOCOL_VERSION);
        assert_eq!(result["serverInfo"]["name"], "test");
        assert_eq!(result["serverInfo"]["version"], "9.9.9");
    }

    #[test]
    fn dispatch_tools_list_returns_registered_tools() {
        let tool = McpTool {
            name: "echo".to_string(),
            description: Some("Echo".to_string()),
            input_schema: Some(json!({"type": "object"})),
            annotations: None,
            meta: None,
        };
        let server = McpServer {
            spec: McpServerSpec {
                server_name: "test".to_string(),
                server_version: "0.0.0".to_string(),
                tools: vec![tool.clone()],
                tool_handler: Box::new(|_, _| Ok(String::new())),
            },
            stdin: BufReader::new(stdin()),
            stdout: stdout(),
        };
        let request = JsonRpcRequest::<JsonValue> {
            jsonrpc: "2.0".to_string(),
            id: JsonRpcId::Number(2),
            method: "tools/list".to_string(),
            params: None,
        };
        let response = server.dispatch(request);
        assert!(response.error.is_none());
        let result = response.result.expect("tools/list result");
        assert_eq!(result["tools"][0]["name"], "echo");
    }

    #[test]
    fn dispatch_tools_call_wraps_handler_output() {
        let server = McpServer {
            spec: McpServerSpec {
                server_name: "test".to_string(),
                server_version: "0.0.0".to_string(),
                tools: Vec::new(),
                tool_handler: Box::new(|name, args| Ok(format!("called {name} with {args}"))),
            },
            stdin: BufReader::new(stdin()),
            stdout: stdout(),
        };
        let request = JsonRpcRequest::<JsonValue> {
            jsonrpc: "2.0".to_string(),
            id: JsonRpcId::Number(3),
            method: "tools/call".to_string(),
            params: Some(json!({
                "name": "echo",
                "arguments": {"text": "hi"}
            })),
        };
        let response = server.dispatch(request);
        assert!(response.error.is_none());
        let result = response.result.expect("tools/call result");
        assert_eq!(result["isError"], false);
        assert_eq!(result["content"][0]["type"], "text");
        assert!(result["content"][0]["text"]
            .as_str()
            .unwrap()
            .starts_with("called echo"));
    }

    #[test]
    fn dispatch_tools_call_surfaces_handler_error() {
        let server = McpServer {
            spec: McpServerSpec {
                server_name: "test".to_string(),
                server_version: "0.0.0".to_string(),
                tools: Vec::new(),
                tool_handler: Box::new(|_, _| Err("boom".to_string())),
            },
            stdin: BufReader::new(stdin()),
            stdout: stdout(),
        };
        let request = JsonRpcRequest::<JsonValue> {
            jsonrpc: "2.0".to_string(),
            id: JsonRpcId::Number(4),
            method: "tools/call".to_string(),
            params: Some(json!({"name": "broken"})),
        };
        let response = server.dispatch(request);
        let result = response.result.expect("tools/call result");
        assert_eq!(result["isError"], true);
        assert_eq!(result["content"][0]["text"], "boom");
    }

    #[test]
    fn dispatch_unknown_method_returns_method_not_found() {
        let server = McpServer {
            spec: McpServerSpec {
                server_name: "test".to_string(),
                server_version: "0.0.0".to_string(),
                tools: Vec::new(),
                tool_handler: Box::new(|_, _| Ok(String::new())),
            },
            stdin: BufReader::new(stdin()),
            stdout: stdout(),
        };
        let request = JsonRpcRequest::<JsonValue> {
            jsonrpc: "2.0".to_string(),
            id: JsonRpcId::Number(5),
            method: "nonsense".to_string(),
            params: None,
        };
        let response = server.dispatch(request);
        let error = response.error.expect("error payload");
        assert_eq!(error.code, -32601);
    }
}
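The `read_frame`/`write_response` pair above implements plain LSP-style `Content-Length` framing: each message is a `Content-Length: N` header, a blank line, then exactly `N` bytes of JSON. A minimal synchronous sketch of the same round-trip using only `std` (the tokio/async details are dropped; `encode_frame` and `decode_frame` are illustrative names, not part of the crate):

```rust
// Sketch of LSP-style Content-Length framing, assuming the same wire
// format as the server above: "Content-Length: N\r\n\r\n" + N body bytes.
use std::io::{BufRead, BufReader, Read};

fn encode_frame(body: &str) -> Vec<u8> {
    let mut frame = format!("Content-Length: {}\r\n\r\n", body.len()).into_bytes();
    frame.extend_from_slice(body.as_bytes());
    frame
}

fn decode_frame(reader: &mut impl BufRead) -> Option<Vec<u8>> {
    let mut content_length = None;
    loop {
        let mut line = String::new();
        if reader.read_line(&mut line).ok()? == 0 {
            return None; // EOF before a complete header block
        }
        if line == "\r\n" || line == "\n" {
            break; // blank line terminates the headers
        }
        if let Some((name, value)) = line.trim_end().split_once(':') {
            if name.trim().eq_ignore_ascii_case("Content-Length") {
                content_length = value.trim().parse::<usize>().ok();
            }
        }
    }
    let mut payload = vec![0_u8; content_length?];
    reader.read_exact(&mut payload).ok()?;
    Some(payload)
}

fn main() {
    let body = r#"{"jsonrpc":"2.0","id":1,"method":"initialize"}"#;
    let frame = encode_frame(body);
    let mut reader = BufReader::new(frame.as_slice());
    let decoded = decode_frame(&mut reader).expect("frame should decode");
    assert_eq!(decoded, body.as_bytes());
}
```

Because the length is byte-exact, the body may contain newlines freely, which is why the module doc calls the framing "newline-safe".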
@@ -4,6 +4,7 @@ use std::path::{Path, PathBuf};
 use std::process::Command;
 
 use crate::config::{ConfigError, ConfigLoader, RuntimeConfig};
+use crate::git_context::GitContext;
 
 /// Errors raised while assembling the final system prompt.
 #[derive(Debug)]
@@ -56,6 +57,7 @@ pub struct ProjectContext {
     pub current_date: String,
     pub git_status: Option<String>,
     pub git_diff: Option<String>,
+    pub git_context: Option<GitContext>,
     pub instruction_files: Vec<ContextFile>,
 }
 
@@ -71,6 +73,7 @@ impl ProjectContext {
             current_date: current_date.into(),
             git_status: None,
             git_diff: None,
+            git_context: None,
             instruction_files,
         })
     }
@@ -82,6 +85,7 @@
         let mut context = Self::discover(cwd, current_date)?;
         context.git_status = read_git_status(&context.cwd);
         context.git_diff = read_git_diff(&context.cwd);
+        context.git_context = GitContext::detect(&context.cwd);
         Ok(context)
     }
 }
@@ -299,11 +303,27 @@ fn render_project_context(project_context: &ProjectContext) -> String {
         lines.push("Git status snapshot:".to_string());
         lines.push(status.clone());
     }
+    if let Some(ref gc) = project_context.git_context {
+        if !gc.recent_commits.is_empty() {
+            lines.push(String::new());
+            lines.push("Recent commits (last 5):".to_string());
+            for c in &gc.recent_commits {
+                lines.push(format!(" {} {}", c.hash, c.subject));
+            }
+        }
+    }
     if let Some(diff) = &project_context.git_diff {
         lines.push(String::new());
        lines.push("Git diff snapshot:".to_string());
         lines.push(diff.clone());
     }
+    if let Some(git_context) = &project_context.git_context {
+        let rendered = git_context.render();
+        if !rendered.is_empty() {
+            lines.push(String::new());
+            lines.push(rendered);
+        }
+    }
     lines.join("\n")
 }
 
@@ -639,6 +659,88 @@ mod tests {
         fs::remove_dir_all(root).expect("cleanup temp dir");
     }
 
+    #[test]
+    fn discover_with_git_includes_recent_commits_and_renders_them() {
+        // given: a git repo with three commits and a current branch
+        let _guard = env_lock();
+        ensure_valid_cwd();
+        let root = temp_dir();
+        fs::create_dir_all(&root).expect("root dir");
+        std::process::Command::new("git")
+            .args(["init", "--quiet", "-b", "main"])
+            .current_dir(&root)
+            .status()
+            .expect("git init should run");
+        std::process::Command::new("git")
+            .args(["config", "user.email", "tests@example.com"])
+            .current_dir(&root)
+            .status()
+            .expect("git config email should run");
+        std::process::Command::new("git")
+            .args(["config", "user.name", "Runtime Prompt Tests"])
+            .current_dir(&root)
+            .status()
+            .expect("git config name should run");
+        for (file, message) in [
+            ("a.txt", "first commit"),
+            ("b.txt", "second commit"),
+            ("c.txt", "third commit"),
+        ] {
+            fs::write(root.join(file), "x\n").expect("write commit file");
+            std::process::Command::new("git")
+                .args(["add", file])
+                .current_dir(&root)
+                .status()
+                .expect("git add should run");
+            std::process::Command::new("git")
+                .args(["commit", "-m", message, "--quiet"])
+                .current_dir(&root)
+                .status()
+                .expect("git commit should run");
+        }
+        fs::write(root.join("d.txt"), "staged\n").expect("write staged file");
+        std::process::Command::new("git")
+            .args(["add", "d.txt"])
+            .current_dir(&root)
+            .status()
+            .expect("git add staged should run");
+
+        // when: discovering project context with git auto-include
+        let context =
+            ProjectContext::discover_with_git(&root, "2026-03-31").expect("context should load");
+        let rendered = SystemPromptBuilder::new()
+            .with_os("linux", "6.8")
+            .with_project_context(context.clone())
+            .render();
+
+        // then: branch, recent commits and staged files are present in context
+        let gc = context
+            .git_context
+            .as_ref()
+            .expect("git context should be present");
+        let commits: String = gc
+            .recent_commits
+            .iter()
+            .map(|c| c.subject.clone())
+            .collect::<Vec<_>>()
+            .join("\n");
+        assert!(commits.contains("first commit"));
+        assert!(commits.contains("second commit"));
+        assert!(commits.contains("third commit"));
+        assert_eq!(gc.recent_commits.len(), 3);
+
+        let status = context.git_status.as_deref().expect("status snapshot");
+        assert!(status.contains("## main"));
+        assert!(status.contains("A d.txt"));
+
+        assert!(rendered.contains("Recent commits (last 5):"));
+        assert!(rendered.contains("first commit"));
+        assert!(rendered.contains("Git status snapshot:"));
+        assert!(rendered.contains("## main"));
+
+        fs::remove_dir_all(root).expect("cleanup temp dir");
+    }
+
     #[test]
     fn discover_with_git_includes_diff_snapshot_for_tracked_changes() {
         let _guard = env_lock();
@@ -65,6 +65,13 @@ pub struct SessionFork {
|
||||
pub branch_name: Option<String>,
|
||||
}
|
||||
|
||||
/// A single user prompt recorded with a timestamp for history tracking.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SessionPromptEntry {
    pub timestamp_ms: u64,
    pub text: String,
}

#[derive(Debug, Clone, PartialEq, Eq)]
struct SessionPersistence {
    path: PathBuf,
@@ -88,6 +95,7 @@ pub struct Session {
    pub compaction: Option<SessionCompaction>,
    pub fork: Option<SessionFork>,
    pub workspace_root: Option<PathBuf>,
    pub prompt_history: Vec<SessionPromptEntry>,
    persistence: Option<SessionPersistence>,
}

@@ -101,6 +109,7 @@ impl PartialEq for Session {
            && self.compaction == other.compaction
            && self.fork == other.fork
            && self.workspace_root == other.workspace_root
            && self.prompt_history == other.prompt_history
    }
}

@@ -151,6 +160,7 @@ impl Session {
            compaction: None,
            fork: None,
            workspace_root: None,
            prompt_history: Vec::new(),
            persistence: None,
        }
    }
@@ -252,6 +262,7 @@ impl Session {
                branch_name: normalize_optional_string(branch_name),
            }),
            workspace_root: self.workspace_root.clone(),
            prompt_history: self.prompt_history.clone(),
            persistence: None,
        }
    }
@@ -295,6 +306,17 @@ impl Session {
                JsonValue::String(workspace_root_to_string(workspace_root)?),
            );
        }
        if !self.prompt_history.is_empty() {
            object.insert(
                "prompt_history".to_string(),
                JsonValue::Array(
                    self.prompt_history
                        .iter()
                        .map(SessionPromptEntry::to_jsonl_record)
                        .collect(),
                ),
            );
        }
        Ok(JsonValue::Object(object))
    }

@@ -339,6 +361,16 @@ impl Session {
            .get("workspace_root")
            .and_then(JsonValue::as_str)
            .map(PathBuf::from);
        let prompt_history = object
            .get("prompt_history")
            .and_then(JsonValue::as_array)
            .map(|entries| {
                entries
                    .iter()
                    .filter_map(SessionPromptEntry::from_json_opt)
                    .collect()
            })
            .unwrap_or_default();
        Ok(Self {
            version,
            session_id,
@@ -348,6 +380,7 @@ impl Session {
            compaction,
            fork,
            workspace_root,
            prompt_history,
            persistence: None,
        })
    }
|
||||
@@ -361,6 +394,7 @@ impl Session {
        let mut compaction = None;
        let mut fork = None;
        let mut workspace_root = None;
        let mut prompt_history = Vec::new();

        for (line_number, raw_line) in contents.lines().enumerate() {
            let line = raw_line.trim();
@@ -414,6 +448,13 @@ impl Session {
                        object.clone(),
                    ))?);
                }
                "prompt_history" => {
                    if let Some(entry) =
                        SessionPromptEntry::from_json_opt(&JsonValue::Object(object.clone()))
                    {
                        prompt_history.push(entry);
                    }
                }
                other => {
                    return Err(SessionError::Format(format!(
                        "unsupported JSONL record type at line {}: {other}",
@@ -433,15 +474,36 @@ impl Session {
            compaction,
            fork,
            workspace_root,
            prompt_history,
            persistence: None,
        })
    }

    /// Record a user prompt with the current wall-clock timestamp.
    ///
    /// The entry is appended to the in-memory history and, when a persistence
    /// path is configured, incrementally written to the JSONL session file.
    pub fn push_prompt_entry(&mut self, text: impl Into<String>) -> Result<(), SessionError> {
        let timestamp_ms = current_time_millis();
        let entry = SessionPromptEntry {
            timestamp_ms,
            text: text.into(),
        };
        self.prompt_history.push(entry);
        let entry_ref = self.prompt_history.last().expect("entry was just pushed");
        self.append_persisted_prompt_entry(entry_ref)
    }

    fn render_jsonl_snapshot(&self) -> Result<String, SessionError> {
        let mut lines = vec![self.meta_record()?.render()];
        if let Some(compaction) = &self.compaction {
            lines.push(compaction.to_jsonl_record()?.render());
        }
        lines.extend(
            self.prompt_history
                .iter()
                .map(|entry| entry.to_jsonl_record().render()),
        );
        lines.extend(
            self.messages
                .iter()
@@ -468,6 +530,25 @@ impl Session {
        Ok(())
    }

    fn append_persisted_prompt_entry(
        &self,
        entry: &SessionPromptEntry,
    ) -> Result<(), SessionError> {
        let Some(path) = self.persistence_path() else {
            return Ok(());
        };

        let needs_bootstrap = !path.exists() || fs::metadata(path)?.len() == 0;
        if needs_bootstrap {
            self.save_to_path(path)?;
            return Ok(());
        }

        let mut file = OpenOptions::new().append(true).open(path)?;
        writeln!(file, "{}", entry.to_jsonl_record().render())?;
        Ok(())
    }

    fn meta_record(&self) -> Result<JsonValue, SessionError> {
        let mut object = BTreeMap::new();
        object.insert(
@@ -784,6 +865,33 @@ impl SessionFork {
    }
}

impl SessionPromptEntry {
    #[must_use]
    pub fn to_jsonl_record(&self) -> JsonValue {
        let mut object = BTreeMap::new();
        object.insert(
            "type".to_string(),
            JsonValue::String("prompt_history".to_string()),
        );
        object.insert(
            "timestamp_ms".to_string(),
            JsonValue::Number(i64::try_from(self.timestamp_ms).unwrap_or(i64::MAX)),
        );
        object.insert("text".to_string(), JsonValue::String(self.text.clone()));
        JsonValue::Object(object)
    }

    fn from_json_opt(value: &JsonValue) -> Option<Self> {
        let object = value.as_object()?;
        let timestamp_ms = object
            .get("timestamp_ms")
            .and_then(JsonValue::as_i64)
            .and_then(|value| u64::try_from(value).ok())?;
        let text = object.get("text").and_then(JsonValue::as_str)?.to_string();
        Some(Self { timestamp_ms, text })
    }
}

fn message_record(message: &ConversationMessage) -> JsonValue {
    let mut object = BTreeMap::new();
    object.insert("type".to_string(), JsonValue::String("message".to_string()));
@@ -1326,3 +1434,63 @@ mod tests {
        .collect()
    }
}

/// Per-worktree session isolation: returns a session directory namespaced
/// by the workspace fingerprint of the given working directory.
/// This prevents parallel `opencode serve` instances from colliding.
/// Called by external consumers (e.g. clawhip) to enumerate sessions for a CWD.
#[allow(dead_code)]
pub fn workspace_sessions_dir(cwd: &std::path::Path) -> Result<std::path::PathBuf, SessionError> {
    let store = crate::session_control::SessionStore::from_cwd(cwd).map_err(|e| {
        SessionError::Io(std::io::Error::new(
            std::io::ErrorKind::Other,
            e.to_string(),
        ))
    })?;
    Ok(store.sessions_dir().to_path_buf())
}

#[cfg(test)]
mod workspace_sessions_dir_tests {
    use super::*;
    use std::fs;

    #[test]
    fn workspace_sessions_dir_returns_fingerprinted_path_for_valid_cwd() {
        let tmp = std::env::temp_dir().join("claw-session-dir-test");
        fs::create_dir_all(&tmp).expect("create temp dir");

        let result = workspace_sessions_dir(&tmp);
        assert!(
            result.is_ok(),
            "workspace_sessions_dir should succeed for a valid CWD, got: {:?}",
            result
        );
        let dir = result.unwrap();
        // The returned path should be non-empty and end with a hash component
        assert!(!dir.as_os_str().is_empty());
        // Two calls with the same CWD should produce identical paths (deterministic)
        let result2 = workspace_sessions_dir(&tmp).unwrap();
        assert_eq!(dir, result2, "workspace_sessions_dir must be deterministic");

        fs::remove_dir_all(&tmp).ok();
    }

    #[test]
    fn workspace_sessions_dir_differs_for_different_cwds() {
        let tmp_a = std::env::temp_dir().join("claw-session-dir-a");
        let tmp_b = std::env::temp_dir().join("claw-session-dir-b");
        fs::create_dir_all(&tmp_a).expect("create dir a");
        fs::create_dir_all(&tmp_b).expect("create dir b");

        let dir_a = workspace_sessions_dir(&tmp_a).expect("dir a");
        let dir_b = workspace_sessions_dir(&tmp_b).expect("dir b");
        assert_ne!(
            dir_a, dir_b,
            "different CWDs must produce different session dirs"
        );

        fs::remove_dir_all(&tmp_a).ok();
        fs::remove_dir_all(&tmp_b).ok();
    }
}
@@ -7,6 +7,252 @@ use std::time::UNIX_EPOCH;

use crate::session::{Session, SessionError};

/// Per-worktree session store that namespaces on-disk session files by
/// workspace fingerprint so that parallel `opencode serve` instances never
/// collide.
///
/// Create via [`SessionStore::from_cwd`] (derives the store path from the
/// server's working directory) or [`SessionStore::from_data_dir`] (honours an
/// explicit `--data-dir` flag). Both constructors produce a directory layout
/// of `<data_dir>/sessions/<workspace_hash>/` where `<workspace_hash>` is a
/// stable hex digest of the canonical workspace root.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SessionStore {
    /// Resolved root of the session namespace, e.g.
    /// `/home/user/project/.claw/sessions/a1b2c3d4e5f60718/`.
    sessions_root: PathBuf,
    /// The canonical workspace path that was fingerprinted.
    workspace_root: PathBuf,
}

impl SessionStore {
    /// Build a store from the server's current working directory.
    ///
    /// The on-disk layout becomes `<cwd>/.claw/sessions/<workspace_hash>/`.
    pub fn from_cwd(cwd: impl AsRef<Path>) -> Result<Self, SessionControlError> {
        let cwd = cwd.as_ref();
        let sessions_root = cwd
            .join(".claw")
            .join("sessions")
            .join(workspace_fingerprint(cwd));
        fs::create_dir_all(&sessions_root)?;
        Ok(Self {
            sessions_root,
            workspace_root: cwd.to_path_buf(),
        })
    }

    /// Build a store from an explicit `--data-dir` flag.
    ///
    /// The on-disk layout becomes `<data_dir>/sessions/<workspace_hash>/`
    /// where `<workspace_hash>` is derived from `workspace_root`.
    pub fn from_data_dir(
        data_dir: impl AsRef<Path>,
        workspace_root: impl AsRef<Path>,
    ) -> Result<Self, SessionControlError> {
        let workspace_root = workspace_root.as_ref();
        let sessions_root = data_dir
            .as_ref()
            .join("sessions")
            .join(workspace_fingerprint(workspace_root));
        fs::create_dir_all(&sessions_root)?;
        Ok(Self {
            sessions_root,
            workspace_root: workspace_root.to_path_buf(),
        })
    }

    /// The fully resolved sessions directory for this namespace.
    #[must_use]
    pub fn sessions_dir(&self) -> &Path {
        &self.sessions_root
    }

    /// The workspace root this store is bound to.
    #[must_use]
    pub fn workspace_root(&self) -> &Path {
        &self.workspace_root
    }

    pub fn create_handle(&self, session_id: &str) -> SessionHandle {
        let id = session_id.to_string();
        let path = self
            .sessions_root
            .join(format!("{id}.{PRIMARY_SESSION_EXTENSION}"));
        SessionHandle { id, path }
    }

    pub fn resolve_reference(&self, reference: &str) -> Result<SessionHandle, SessionControlError> {
        if is_session_reference_alias(reference) {
            let latest = self.latest_session()?;
            return Ok(SessionHandle {
                id: latest.id,
                path: latest.path,
            });
        }

        let direct = PathBuf::from(reference);
        let candidate = if direct.is_absolute() {
            direct.clone()
        } else {
            self.workspace_root.join(&direct)
        };
        let looks_like_path = direct.extension().is_some() || direct.components().count() > 1;
        let path = if candidate.exists() {
            candidate
        } else if looks_like_path {
            return Err(SessionControlError::Format(
                format_missing_session_reference(reference),
            ));
        } else {
            self.resolve_managed_path(reference)?
        };

        Ok(SessionHandle {
            id: session_id_from_path(&path).unwrap_or_else(|| reference.to_string()),
            path,
        })
    }

    pub fn resolve_managed_path(&self, session_id: &str) -> Result<PathBuf, SessionControlError> {
        for extension in [PRIMARY_SESSION_EXTENSION, LEGACY_SESSION_EXTENSION] {
            let path = self.sessions_root.join(format!("{session_id}.{extension}"));
            if path.exists() {
                return Ok(path);
            }
        }
        Err(SessionControlError::Format(
            format_missing_session_reference(session_id),
        ))
    }

    pub fn list_sessions(&self) -> Result<Vec<ManagedSessionSummary>, SessionControlError> {
        let mut sessions = Vec::new();
        let read_result = fs::read_dir(&self.sessions_root);
        let entries = match read_result {
            Ok(entries) => entries,
            Err(err) if err.kind() == std::io::ErrorKind::NotFound => return Ok(sessions),
            Err(err) => return Err(err.into()),
        };
        for entry in entries {
            let entry = entry?;
            let path = entry.path();
            if !is_managed_session_file(&path) {
                continue;
            }
            let metadata = entry.metadata()?;
            let modified_epoch_millis = metadata
                .modified()
                .ok()
                .and_then(|time| time.duration_since(UNIX_EPOCH).ok())
                .map(|duration| duration.as_millis())
                .unwrap_or_default();
            let (id, message_count, parent_session_id, branch_name) =
                match Session::load_from_path(&path) {
                    Ok(session) => {
                        let parent_session_id = session
                            .fork
                            .as_ref()
                            .map(|fork| fork.parent_session_id.clone());
                        let branch_name = session
                            .fork
                            .as_ref()
                            .and_then(|fork| fork.branch_name.clone());
                        (
                            session.session_id,
                            session.messages.len(),
                            parent_session_id,
                            branch_name,
                        )
                    }
                    Err(_) => (
                        path.file_stem()
                            .and_then(|value| value.to_str())
                            .unwrap_or("unknown")
                            .to_string(),
                        0,
                        None,
                        None,
                    ),
                };
            sessions.push(ManagedSessionSummary {
                id,
                path,
                modified_epoch_millis,
                message_count,
                parent_session_id,
                branch_name,
            });
        }
        sessions.sort_by(|left, right| {
            right
                .modified_epoch_millis
                .cmp(&left.modified_epoch_millis)
                .then_with(|| right.id.cmp(&left.id))
        });
        Ok(sessions)
    }

    pub fn latest_session(&self) -> Result<ManagedSessionSummary, SessionControlError> {
        self.list_sessions()?
            .into_iter()
            .next()
            .ok_or_else(|| SessionControlError::Format(format_no_managed_sessions()))
    }

    pub fn load_session(
        &self,
        reference: &str,
    ) -> Result<LoadedManagedSession, SessionControlError> {
        let handle = self.resolve_reference(reference)?;
        let session = Session::load_from_path(&handle.path)?;
        Ok(LoadedManagedSession {
            handle: SessionHandle {
                id: session.session_id.clone(),
                path: handle.path,
            },
            session,
        })
    }

    pub fn fork_session(
        &self,
        session: &Session,
        branch_name: Option<String>,
    ) -> Result<ForkedManagedSession, SessionControlError> {
        let parent_session_id = session.session_id.clone();
        let forked = session.fork(branch_name);
        let handle = self.create_handle(&forked.session_id);
        let branch_name = forked
            .fork
            .as_ref()
            .and_then(|fork| fork.branch_name.clone());
        let forked = forked.with_persistence_path(handle.path.clone());
        forked.save_to_path(&handle.path)?;
        Ok(ForkedManagedSession {
            parent_session_id,
            handle,
            session: forked,
            branch_name,
        })
    }
}

/// Stable hex fingerprint of a workspace path.
///
/// Uses FNV-1a (64-bit) to produce a 16-char hex string that partitions the
/// on-disk session directory per workspace root.
#[must_use]
pub fn workspace_fingerprint(workspace_root: &Path) -> String {
    let input = workspace_root.to_string_lossy();
    let mut hash = 0xcbf2_9ce4_8422_2325_u64;
    for byte in input.as_bytes() {
        hash ^= u64::from(*byte);
        hash = hash.wrapping_mul(0x0100_0000_01b3);
    }
    format!("{hash:016x}")
}

pub const PRIMARY_SESSION_EXTENSION: &str = "jsonl";
pub const LEGACY_SESSION_EXTENSION: &str = "json";
pub const LATEST_SESSION_REFERENCE: &str = "latest";
@@ -333,7 +579,7 @@ mod tests {
    use super::{
        create_managed_session_handle_for, fork_managed_session_for, is_session_reference_alias,
        list_managed_sessions_for, load_managed_session_for, resolve_session_reference_for,
        ManagedSessionSummary, LATEST_SESSION_REFERENCE,
        workspace_fingerprint, ManagedSessionSummary, SessionStore, LATEST_SESSION_REFERENCE,
    };
    use crate::session::Session;
    use std::fs;
@@ -456,4 +702,172 @@ mod tests {
        );
        fs::remove_dir_all(root).expect("temp dir should clean up");
    }

    // ------------------------------------------------------------------
    // Per-worktree session isolation (SessionStore) tests
    // ------------------------------------------------------------------

    fn persist_session_via_store(store: &SessionStore, text: &str) -> Session {
        let mut session = Session::new();
        session
            .push_user_text(text)
            .expect("session message should save");
        let handle = store.create_handle(&session.session_id);
        let session = session.with_persistence_path(handle.path.clone());
        session
            .save_to_path(&handle.path)
            .expect("session should persist");
        session
    }

    #[test]
    fn workspace_fingerprint_is_deterministic_and_differs_per_path() {
        // given
        let path_a = Path::new("/tmp/worktree-alpha");
        let path_b = Path::new("/tmp/worktree-beta");

        // when
        let fp_a1 = workspace_fingerprint(path_a);
        let fp_a2 = workspace_fingerprint(path_a);
        let fp_b = workspace_fingerprint(path_b);

        // then
        assert_eq!(fp_a1, fp_a2, "same path must produce the same fingerprint");
        assert_ne!(
            fp_a1, fp_b,
            "different paths must produce different fingerprints"
        );
        assert_eq!(fp_a1.len(), 16, "fingerprint must be a 16-char hex string");
    }

    #[test]
    fn session_store_from_cwd_isolates_sessions_by_workspace() {
        // given
        let base = temp_dir();
        let workspace_a = base.join("repo-alpha");
        let workspace_b = base.join("repo-beta");
        fs::create_dir_all(&workspace_a).expect("workspace a should exist");
        fs::create_dir_all(&workspace_b).expect("workspace b should exist");

        let store_a = SessionStore::from_cwd(&workspace_a).expect("store a should build");
        let store_b = SessionStore::from_cwd(&workspace_b).expect("store b should build");

        // when
        let session_a = persist_session_via_store(&store_a, "alpha work");
        let _session_b = persist_session_via_store(&store_b, "beta work");

        // then — each store only sees its own sessions
        let list_a = store_a.list_sessions().expect("list a");
        let list_b = store_b.list_sessions().expect("list b");
        assert_eq!(list_a.len(), 1, "store a should see exactly one session");
        assert_eq!(list_b.len(), 1, "store b should see exactly one session");
        assert_eq!(list_a[0].id, session_a.session_id);
        assert_ne!(
            store_a.sessions_dir(),
            store_b.sessions_dir(),
            "session directories must differ across workspaces"
        );
        fs::remove_dir_all(base).expect("temp dir should clean up");
    }

    #[test]
    fn session_store_from_data_dir_namespaces_by_workspace() {
        // given
        let base = temp_dir();
        let data_dir = base.join("global-data");
        let workspace_a = PathBuf::from("/tmp/project-one");
        let workspace_b = PathBuf::from("/tmp/project-two");
        fs::create_dir_all(&data_dir).expect("data dir should exist");

        let store_a =
            SessionStore::from_data_dir(&data_dir, &workspace_a).expect("store a should build");
        let store_b =
            SessionStore::from_data_dir(&data_dir, &workspace_b).expect("store b should build");

        // when
        persist_session_via_store(&store_a, "work in project-one");
        persist_session_via_store(&store_b, "work in project-two");

        // then
        assert_ne!(
            store_a.sessions_dir(),
            store_b.sessions_dir(),
            "data-dir stores must namespace by workspace"
        );
        assert_eq!(store_a.list_sessions().expect("list a").len(), 1);
        assert_eq!(store_b.list_sessions().expect("list b").len(), 1);
        assert_eq!(store_a.workspace_root(), workspace_a.as_path());
        assert_eq!(store_b.workspace_root(), workspace_b.as_path());
        fs::remove_dir_all(base).expect("temp dir should clean up");
    }

    #[test]
    fn session_store_create_and_load_round_trip() {
        // given
        let base = temp_dir();
        fs::create_dir_all(&base).expect("base dir should exist");
        let store = SessionStore::from_cwd(&base).expect("store should build");
        let session = persist_session_via_store(&store, "round-trip message");

        // when
        let loaded = store
            .load_session(&session.session_id)
            .expect("session should load via store");

        // then
        assert_eq!(loaded.handle.id, session.session_id);
        assert_eq!(loaded.session.messages.len(), 1);
        fs::remove_dir_all(base).expect("temp dir should clean up");
    }

    #[test]
    fn session_store_latest_and_resolve_reference() {
        // given
        let base = temp_dir();
        fs::create_dir_all(&base).expect("base dir should exist");
        let store = SessionStore::from_cwd(&base).expect("store should build");
        let _older = persist_session_via_store(&store, "older");
        wait_for_next_millisecond();
        let newer = persist_session_via_store(&store, "newer");

        // when
        let latest = store.latest_session().expect("latest should resolve");
        let handle = store
            .resolve_reference("latest")
            .expect("latest alias should resolve");

        // then
        assert_eq!(latest.id, newer.session_id);
        assert_eq!(handle.id, newer.session_id);
        fs::remove_dir_all(base).expect("temp dir should clean up");
    }

    #[test]
    fn session_store_fork_stays_in_same_namespace() {
        // given
        let base = temp_dir();
        fs::create_dir_all(&base).expect("base dir should exist");
        let store = SessionStore::from_cwd(&base).expect("store should build");
        let source = persist_session_via_store(&store, "parent work");

        // when
        let forked = store
            .fork_session(&source, Some("bugfix".to_string()))
            .expect("fork should succeed");
        let sessions = store.list_sessions().expect("list sessions");

        // then
        assert_eq!(
            sessions.len(),
            2,
            "forked session must land in the same namespace"
        );
        assert_eq!(forked.parent_session_id, source.session_id);
        assert_eq!(forked.branch_name.as_deref(), Some("bugfix"));
        assert!(
            forked.handle.path.starts_with(store.sessions_dir()),
            "forked session path must be inside the store namespace"
        );
        fs::remove_dir_all(base).expect("temp dir should clean up");
    }
}
429	rust/crates/runtime/src/stale_base.rs	Normal file
@@ -0,0 +1,429 @@
#![allow(clippy::must_use_candidate)]
use std::path::Path;
use std::process::Command;

/// Outcome of comparing the worktree HEAD against the expected base commit.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum BaseCommitState {
    /// HEAD matches the expected base commit.
    Matches,
    /// HEAD has diverged from the expected base.
    Diverged { expected: String, actual: String },
    /// No expected base was supplied (neither flag nor file).
    NoExpectedBase,
    /// The working directory is not inside a git repository.
    NotAGitRepo,
}

/// Where the expected base commit originated from.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum BaseCommitSource {
    Flag(String),
    File(String),
}

/// Read the `.claw-base` file from the given directory and return the trimmed
/// commit hash, or `None` when the file is absent or empty.
pub fn read_claw_base_file(cwd: &Path) -> Option<String> {
    let path = cwd.join(".claw-base");
    let content = std::fs::read_to_string(path).ok()?;
    let trimmed = content.trim();
    if trimmed.is_empty() {
        None
    } else {
        Some(trimmed.to_string())
    }
}

/// Resolve the expected base commit: prefer the `--base-commit` flag value,
/// fall back to reading `.claw-base` from `cwd`.
pub fn resolve_expected_base(flag_value: Option<&str>, cwd: &Path) -> Option<BaseCommitSource> {
    if let Some(value) = flag_value {
        let trimmed = value.trim();
        if !trimmed.is_empty() {
            return Some(BaseCommitSource::Flag(trimmed.to_string()));
        }
    }
    read_claw_base_file(cwd).map(BaseCommitSource::File)
}

/// Verify that the worktree HEAD matches `expected_base`.
///
/// Returns [`BaseCommitState::NoExpectedBase`] when no expected commit is
/// provided (the check is effectively a no-op in that case).
pub fn check_base_commit(cwd: &Path, expected_base: Option<&BaseCommitSource>) -> BaseCommitState {
    let Some(source) = expected_base else {
        return BaseCommitState::NoExpectedBase;
    };
    let expected_raw = match source {
        BaseCommitSource::Flag(value) | BaseCommitSource::File(value) => value.as_str(),
    };

    let Some(head_sha) = resolve_head_sha(cwd) else {
        return BaseCommitState::NotAGitRepo;
    };

    let Some(expected_sha) = resolve_rev(cwd, expected_raw) else {
        // If the expected ref cannot be resolved, compare raw strings as a
        // best-effort fallback (e.g. partial SHA provided by the caller).
        return if head_sha.starts_with(expected_raw) || expected_raw.starts_with(&head_sha) {
            BaseCommitState::Matches
        } else {
            BaseCommitState::Diverged {
                expected: expected_raw.to_string(),
                actual: head_sha,
            }
        };
    };

    if head_sha == expected_sha {
        BaseCommitState::Matches
    } else {
        BaseCommitState::Diverged {
            expected: expected_sha,
            actual: head_sha,
        }
    }
}

/// Format a human-readable warning when the base commit has diverged.
///
/// Returns `None` for non-warning states (`Matches`, `NoExpectedBase`).
pub fn format_stale_base_warning(state: &BaseCommitState) -> Option<String> {
    match state {
        BaseCommitState::Diverged { expected, actual } => Some(format!(
            "warning: worktree HEAD ({actual}) does not match expected base commit ({expected}). \
             Session may run against a stale codebase."
        )),
        BaseCommitState::NotAGitRepo => {
            Some("warning: stale-base check skipped — not inside a git repository.".to_string())
        }
        BaseCommitState::Matches | BaseCommitState::NoExpectedBase => None,
    }
}

fn resolve_head_sha(cwd: &Path) -> Option<String> {
    resolve_rev(cwd, "HEAD")
}

fn resolve_rev(cwd: &Path, rev: &str) -> Option<String> {
    let output = Command::new("git")
        .args(["rev-parse", rev])
        .current_dir(cwd)
        .output()
        .ok()?;
    if !output.status.success() {
        return None;
    }
    let sha = String::from_utf8(output.stdout).ok()?;
    let trimmed = sha.trim();
    if trimmed.is_empty() {
        None
    } else {
        Some(trimmed.to_string())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use std::process::Command;
    use std::time::{SystemTime, UNIX_EPOCH};

    fn temp_dir() -> std::path::PathBuf {
        let nanos = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .expect("time should be after epoch")
            .as_nanos();
        std::env::temp_dir().join(format!("runtime-stale-base-{nanos}"))
    }

    fn init_repo(path: &std::path::Path) {
        fs::create_dir_all(path).expect("create repo dir");
        run(path, &["init", "--quiet", "-b", "main"]);
        run(path, &["config", "user.email", "tests@example.com"]);
        run(path, &["config", "user.name", "Stale Base Tests"]);
        fs::write(path.join("init.txt"), "initial\n").expect("write init file");
        run(path, &["add", "."]);
        run(path, &["commit", "-m", "initial commit", "--quiet"]);
    }

    fn run(cwd: &std::path::Path, args: &[&str]) {
        let status = Command::new("git")
            .args(args)
            .current_dir(cwd)
            .status()
            .unwrap_or_else(|e| panic!("git {} failed to execute: {e}", args.join(" ")));
        assert!(
            status.success(),
            "git {} exited with {status}",
            args.join(" ")
        );
    }

    fn commit_file(repo: &std::path::Path, name: &str, msg: &str) {
        fs::write(repo.join(name), format!("{msg}\n")).expect("write file");
        run(repo, &["add", name]);
        run(repo, &["commit", "-m", msg, "--quiet"]);
    }

    fn head_sha(repo: &std::path::Path) -> String {
        let output = Command::new("git")
            .args(["rev-parse", "HEAD"])
            .current_dir(repo)
            .output()
            .expect("git rev-parse HEAD");
        String::from_utf8(output.stdout)
            .expect("valid utf8")
            .trim()
            .to_string()
    }

    #[test]
    fn matches_when_head_equals_expected_base() {
        // given
        let root = temp_dir();
        init_repo(&root);
        let sha = head_sha(&root);
        let source = BaseCommitSource::Flag(sha);

        // when
        let state = check_base_commit(&root, Some(&source));

        // then
        assert_eq!(state, BaseCommitState::Matches);
        fs::remove_dir_all(&root).expect("cleanup");
    }

    #[test]
    fn diverged_when_head_moved_past_expected_base() {
        // given
        let root = temp_dir();
        init_repo(&root);
        let old_sha = head_sha(&root);
        commit_file(&root, "extra.txt", "move head forward");
        let new_sha = head_sha(&root);
        let source = BaseCommitSource::Flag(old_sha.clone());

        // when
        let state = check_base_commit(&root, Some(&source));

        // then
        assert_eq!(
            state,
            BaseCommitState::Diverged {
                expected: old_sha,
                actual: new_sha,
            }
        );
        fs::remove_dir_all(&root).expect("cleanup");
    }

    #[test]
    fn no_expected_base_when_source_is_none() {
        // given
        let root = temp_dir();
        init_repo(&root);

        // when
        let state = check_base_commit(&root, None);

        // then
        assert_eq!(state, BaseCommitState::NoExpectedBase);
        fs::remove_dir_all(&root).expect("cleanup");
    }

    #[test]
    fn not_a_git_repo_when_outside_repo() {
        // given
        let root = temp_dir();
        fs::create_dir_all(&root).expect("create dir");
        let source = BaseCommitSource::Flag("abc1234".to_string());

        // when
        let state = check_base_commit(&root, Some(&source));

        // then
        assert_eq!(state, BaseCommitState::NotAGitRepo);
        fs::remove_dir_all(&root).expect("cleanup");
    }

    #[test]
    fn reads_claw_base_file() {
        // given
        let root = temp_dir();
        fs::create_dir_all(&root).expect("create dir");
        fs::write(root.join(".claw-base"), "abc1234def5678\n").expect("write .claw-base");

        // when
        let value = read_claw_base_file(&root);

        // then
        assert_eq!(value, Some("abc1234def5678".to_string()));
        fs::remove_dir_all(&root).expect("cleanup");
    }

    #[test]
    fn returns_none_for_missing_claw_base_file() {
        // given
        let root = temp_dir();
        fs::create_dir_all(&root).expect("create dir");

        // when
        let value = read_claw_base_file(&root);

        // then
        assert!(value.is_none());
        fs::remove_dir_all(&root).expect("cleanup");
    }

    #[test]
    fn returns_none_for_empty_claw_base_file() {
        // given
        let root = temp_dir();
        fs::create_dir_all(&root).expect("create dir");
        fs::write(root.join(".claw-base"), "   \n").expect("write empty .claw-base");

        // when
        let value = read_claw_base_file(&root);

        // then
        assert!(value.is_none());
        fs::remove_dir_all(&root).expect("cleanup");
    }

    #[test]
    fn resolve_expected_base_prefers_flag_over_file() {
        // given
        let root = temp_dir();
        fs::create_dir_all(&root).expect("create dir");
        fs::write(root.join(".claw-base"), "from_file\n").expect("write .claw-base");

        // when
        let source = resolve_expected_base(Some("from_flag"), &root);

        // then
        assert_eq!(
            source,
            Some(BaseCommitSource::Flag("from_flag".to_string()))
        );
        fs::remove_dir_all(&root).expect("cleanup");
    }

    #[test]
    fn resolve_expected_base_falls_back_to_file() {
        // given
        let root = temp_dir();
        fs::create_dir_all(&root).expect("create dir");
        fs::write(root.join(".claw-base"), "from_file\n").expect("write .claw-base");

        // when
        let source = resolve_expected_base(None, &root);

        // then
        assert_eq!(
            source,
            Some(BaseCommitSource::File("from_file".to_string()))
        );
        fs::remove_dir_all(&root).expect("cleanup");
    }

    #[test]
    fn resolve_expected_base_returns_none_when_nothing_available() {
        // given
        let root = temp_dir();
        fs::create_dir_all(&root).expect("create dir");

        // when
        let source = resolve_expected_base(None, &root);

        // then
        assert!(source.is_none());
        fs::remove_dir_all(&root).expect("cleanup");
    }

    #[test]
    fn format_warning_returns_message_for_diverged() {
        // given
        let state = BaseCommitState::Diverged {
            expected: "abc1234".to_string(),
            actual: "def5678".to_string(),
        };

        // when
        let warning = format_stale_base_warning(&state);

        // then
        let message = warning.expect("should produce warning");
|
||||
assert!(message.contains("abc1234"));
|
||||
assert!(message.contains("def5678"));
|
||||
assert!(message.contains("stale codebase"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn format_warning_returns_none_for_matches() {
|
||||
// given
|
||||
let state = BaseCommitState::Matches;
|
||||
|
||||
// when
|
||||
let warning = format_stale_base_warning(&state);
|
||||
|
||||
// then
|
||||
assert!(warning.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn format_warning_returns_none_for_no_expected_base() {
|
||||
// given
|
||||
let state = BaseCommitState::NoExpectedBase;
|
||||
|
||||
// when
|
||||
let warning = format_stale_base_warning(&state);
|
||||
|
||||
// then
|
||||
assert!(warning.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn matches_with_claw_base_file_in_real_repo() {
|
||||
// given
|
||||
let root = temp_dir();
|
||||
init_repo(&root);
|
||||
let sha = head_sha(&root);
|
||||
fs::write(root.join(".claw-base"), format!("{sha}\n")).expect("write .claw-base");
|
||||
let source = resolve_expected_base(None, &root);
|
||||
|
||||
// when
|
||||
let state = check_base_commit(&root, source.as_ref());
|
||||
|
||||
// then
|
||||
assert_eq!(state, BaseCommitState::Matches);
|
||||
fs::remove_dir_all(&root).expect("cleanup");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn diverged_with_claw_base_file_after_new_commit() {
|
||||
// given
|
||||
let root = temp_dir();
|
||||
init_repo(&root);
|
||||
let old_sha = head_sha(&root);
|
||||
fs::write(root.join(".claw-base"), format!("{old_sha}\n")).expect("write .claw-base");
|
||||
commit_file(&root, "new.txt", "advance head");
|
||||
let new_sha = head_sha(&root);
|
||||
let source = resolve_expected_base(None, &root);
|
||||
|
||||
// when
|
||||
let state = check_base_commit(&root, source.as_ref());
|
||||
|
||||
// then
|
||||
assert_eq!(
|
||||
state,
|
||||
BaseCommitState::Diverged {
|
||||
expected: old_sha,
|
||||
actual: new_sha,
|
||||
}
|
||||
);
|
||||
fs::remove_dir_all(&root).expect("cleanup");
|
||||
}
|
||||
}
|
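The precedence these tests pin down (an explicit flag wins over the `.claw-base` file, blank file contents count as absent) can be sketched as a standalone helper. This is an illustrative reduction, not the crate's real `resolve_expected_base` signature; `Source` stands in for `BaseCommitSource`.

```rust
/// Sketch of the expected-base precedence exercised above. Names here are
/// hypothetical stand-ins for the real API: flag beats file, and a file whose
/// trimmed contents are empty is treated as missing.
#[derive(Debug, PartialEq)]
enum Source {
    Flag(String),
    File(String),
}

fn resolve(flag: Option<&str>, file_value: Option<&str>) -> Option<Source> {
    if let Some(f) = flag {
        // An explicit --base-commit style flag always wins.
        return Some(Source::Flag(f.to_string()));
    }
    // Fall back to the file, ignoring whitespace-only contents.
    let v = file_value?.trim();
    if v.is_empty() {
        None
    } else {
        Some(Source::File(v.to_string()))
    }
}
```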
@@ -560,6 +560,7 @@ fn push_event(
    let timestamp = now_secs();
    let seq = worker.events.len() as u64 + 1;
    worker.updated_at = timestamp;
    worker.status = status;
    worker.events.push(WorkerEvent {
        seq,
        kind,
@@ -568,6 +569,50 @@ fn push_event(
        payload,
        timestamp,
    });
    emit_state_file(worker);
}

/// Write current worker state to `.claw/worker-state.json` under the worker's cwd.
/// This is the file-based observability surface: external observers (clawhip, orchestrators)
/// poll this file instead of requiring an HTTP route on the opencode binary.
fn emit_state_file(worker: &Worker) {
    let state_dir = std::path::Path::new(&worker.cwd).join(".claw");
    if let Err(_) = std::fs::create_dir_all(&state_dir) {
        return;
    }
    let state_path = state_dir.join("worker-state.json");
    let tmp_path = state_dir.join("worker-state.json.tmp");

    #[derive(serde::Serialize)]
    struct StateSnapshot<'a> {
        worker_id: &'a str,
        status: WorkerStatus,
        is_ready: bool,
        trust_gate_cleared: bool,
        prompt_in_flight: bool,
        last_event: Option<&'a WorkerEvent>,
        updated_at: u64,
        /// Seconds since last state transition. Clawhip uses this to detect
        /// stalled workers without computing epoch deltas.
        seconds_since_update: u64,
    }

    let now = now_secs();
    let snapshot = StateSnapshot {
        worker_id: &worker.worker_id,
        status: worker.status,
        is_ready: worker.status == WorkerStatus::ReadyForPrompt,
        trust_gate_cleared: worker.trust_gate_cleared,
        prompt_in_flight: worker.prompt_in_flight,
        last_event: worker.events.last(),
        updated_at: worker.updated_at,
        seconds_since_update: now.saturating_sub(worker.updated_at),
    };

    if let Ok(json) = serde_json::to_string_pretty(&snapshot) {
        let _ = std::fs::write(&tmp_path, json);
        let _ = std::fs::rename(&tmp_path, &state_path);
    }
}
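The write-to-temp-then-rename sequence at the end of `emit_state_file` is what keeps pollers from ever reading a half-written JSON document: on POSIX filesystems, `rename` within one filesystem replaces the target atomically. A minimal generic sketch of that pattern (function name is illustrative, not part of the crate):

```rust
use std::fs;
use std::path::Path;

/// Sketch of the write-then-rename pattern used by emit_state_file above:
/// write the payload to a sibling temp file, then rename it over the real
/// path, so a concurrent reader sees either the old snapshot or the new one,
/// never a torn write.
fn write_atomically(dir: &Path, name: &str, contents: &str) -> std::io::Result<()> {
    let tmp = dir.join(format!("{name}.tmp"));
    let target = dir.join(name);
    fs::write(&tmp, contents)?;
    fs::rename(&tmp, &target)
}
```

The temp file must live in the same directory as the target; renaming across filesystems is not atomic and can fail outright.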
fn path_matches_allowlist(cwd: &str, trusted_root: &str) -> bool {
@@ -1058,6 +1103,58 @@ mod tests {
            .any(|event| event.kind == WorkerEventKind::Failed));
    }

    #[test]
    fn emit_state_file_writes_worker_status_on_transition() {
        let cwd_path = std::env::temp_dir().join(format!(
            "claw-state-test-{}",
            std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap_or_default()
                .as_nanos()
        ));
        std::fs::create_dir_all(&cwd_path).expect("test dir should create");
        let cwd = cwd_path.to_str().expect("test path should be utf8");
        let registry = WorkerRegistry::new();
        let worker = registry.create(cwd, &[], true);

        // After create the worker is Spawning — state file should exist
        let state_path = cwd_path.join(".claw").join("worker-state.json");
        assert!(
            state_path.exists(),
            "state file should exist after worker creation"
        );

        let raw = std::fs::read_to_string(&state_path).expect("state file should be readable");
        let value: serde_json::Value =
            serde_json::from_str(&raw).expect("state file should be valid JSON");
        assert_eq!(
            value["status"].as_str(),
            Some("spawning"),
            "initial status should be spawning"
        );
        assert_eq!(value["is_ready"].as_bool(), Some(false));

        // Transition to ReadyForPrompt by observing trust-cleared text
        registry
            .observe(&worker.worker_id, "Ready for input\n>")
            .expect("observe ready should succeed");

        let raw = std::fs::read_to_string(&state_path)
            .expect("state file should be readable after observe");
        let value: serde_json::Value =
            serde_json::from_str(&raw).expect("state file should be valid JSON after observe");
        assert_eq!(
            value["status"].as_str(),
            Some("ready_for_prompt"),
            "status should be ready_for_prompt after observe"
        );
        assert_eq!(
            value["is_ready"].as_bool(),
            Some(true),
            "is_ready should be true when ReadyForPrompt"
        );
    }

    #[test]
    fn observe_completion_accepts_normal_finish_with_tokens() {
        let registry = WorkerRegistry::new();
File diff suppressed because it is too large
@@ -249,13 +249,14 @@ impl TerminalRenderer {

     #[must_use]
     pub fn render_markdown(&self, markdown: &str) -> String {
+        let normalized = normalize_nested_fences(markdown);
         let mut output = String::new();
         let mut state = RenderState::default();
         let mut code_language = String::new();
         let mut code_buffer = String::new();
         let mut in_code_block = false;

-        for event in Parser::new_ext(markdown, Options::all()) {
+        for event in Parser::new_ext(&normalized, Options::all()) {
             self.render_event(
                 event,
                 &mut state,
@@ -634,8 +635,180 @@ fn apply_code_block_background(line: &str) -> String {
    format!("\u{1b}[48;5;236m{with_background}\u{1b}[0m{trailing_newline}")
}

/// Pre-process raw markdown so that fenced code blocks whose body contains
/// fence markers of equal or greater length are wrapped with a longer fence.
///
/// LLMs frequently emit triple-backtick code blocks that contain triple-backtick
/// examples. CommonMark (and pulldown-cmark) treats the inner marker as the
/// closing fence, breaking the render. This function detects the situation and
/// upgrades the outer fence to use enough backticks (or tildes) that the inner
/// markers become ordinary content.
fn normalize_nested_fences(markdown: &str) -> String {
    // A fence line is either "labeled" (has an info string ⇒ always an opener)
    // or "bare" (no info string ⇒ could be opener or closer).
    #[derive(Debug, Clone)]
    struct FenceLine {
        char: char,
        len: usize,
        has_info: bool,
        indent: usize,
    }

    fn parse_fence_line(line: &str) -> Option<FenceLine> {
        let trimmed = line.trim_end_matches('\n').trim_end_matches('\r');
        let indent = trimmed.chars().take_while(|c| *c == ' ').count();
        if indent > 3 {
            return None;
        }
        let rest = &trimmed[indent..];
        let ch = rest.chars().next()?;
        if ch != '`' && ch != '~' {
            return None;
        }
        let len = rest.chars().take_while(|c| *c == ch).count();
        if len < 3 {
            return None;
        }
        let after = &rest[len..];
        if ch == '`' && after.contains('`') {
            return None;
        }
        let has_info = !after.trim().is_empty();
        Some(FenceLine {
            char: ch,
            len,
            has_info,
            indent,
        })
    }

    let lines: Vec<&str> = markdown.split_inclusive('\n').collect();
    // Handle final line that may lack trailing newline.
    // split_inclusive already keeps the original chunks, including a
    // final chunk without '\n' if the input doesn't end with one.

    // First pass: classify every line.
    let fence_info: Vec<Option<FenceLine>> = lines.iter().map(|l| parse_fence_line(l)).collect();

    // Second pass: pair openers with closers using a stack, recording
    // (opener_idx, closer_idx) pairs plus the max fence length found between
    // them.
    struct StackEntry {
        line_idx: usize,
        fence: FenceLine,
    }

    let mut stack: Vec<StackEntry> = Vec::new();
    // Paired blocks: (opener_line, closer_line, max_inner_fence_len)
    let mut pairs: Vec<(usize, usize, usize)> = Vec::new();

    for (i, fi) in fence_info.iter().enumerate() {
        let Some(fl) = fi else { continue };

        if fl.has_info {
            // Labeled fence ⇒ always an opener.
            stack.push(StackEntry {
                line_idx: i,
                fence: fl.clone(),
            });
        } else {
            // Bare fence ⇒ try to close the top of the stack if compatible.
            let closes_top = stack
                .last()
                .is_some_and(|top| top.fence.char == fl.char && fl.len >= top.fence.len);
            if closes_top {
                let opener = stack.pop().unwrap();
                // Find max fence length of any fence line strictly between
                // opener and closer (these are the nested fences).
                let inner_max = fence_info[opener.line_idx + 1..i]
                    .iter()
                    .filter_map(|fi| fi.as_ref().map(|f| f.len))
                    .max()
                    .unwrap_or(0);
                pairs.push((opener.line_idx, i, inner_max));
            } else {
                // Treat as opener.
                stack.push(StackEntry {
                    line_idx: i,
                    fence: fl.clone(),
                });
            }
        }
    }

    // Determine which lines need rewriting. A pair needs rewriting when
    // its opener length <= max inner fence length.
    struct Rewrite {
        char: char,
        new_len: usize,
        indent: usize,
    }
    let mut rewrites: std::collections::HashMap<usize, Rewrite> = std::collections::HashMap::new();

    for (opener_idx, closer_idx, inner_max) in &pairs {
        let opener_fl = fence_info[*opener_idx].as_ref().unwrap();
        if opener_fl.len <= *inner_max {
            let new_len = inner_max + 1;
            let info_part = {
                let trimmed = lines[*opener_idx]
                    .trim_end_matches('\n')
                    .trim_end_matches('\r');
                let rest = &trimmed[opener_fl.indent..];
                rest[opener_fl.len..].to_string()
            };
            rewrites.insert(
                *opener_idx,
                Rewrite {
                    char: opener_fl.char,
                    new_len,
                    indent: opener_fl.indent,
                },
            );
            let closer_fl = fence_info[*closer_idx].as_ref().unwrap();
            rewrites.insert(
                *closer_idx,
                Rewrite {
                    char: closer_fl.char,
                    new_len,
                    indent: closer_fl.indent,
                },
            );
            // Store info string only in the opener; closer keeps the trailing
            // portion which is already handled through the original line.
            // Actually, we rebuild both lines from scratch below, including
            // the info string for the opener.
            let _ = info_part; // consumed in rebuild
        }
    }

    if rewrites.is_empty() {
        return markdown.to_string();
    }

    // Rebuild.
    let mut out = String::with_capacity(markdown.len() + rewrites.len() * 4);
    for (i, line) in lines.iter().enumerate() {
        if let Some(rw) = rewrites.get(&i) {
            let fence_str: String = std::iter::repeat(rw.char).take(rw.new_len).collect();
            let indent_str: String = std::iter::repeat(' ').take(rw.indent).collect();
            // Recover the original info string (if any) and trailing newline.
            let trimmed = line.trim_end_matches('\n').trim_end_matches('\r');
            let fi = fence_info[i].as_ref().unwrap();
            let info = &trimmed[fi.indent + fi.len..];
            let trailing = &line[trimmed.len()..];
            out.push_str(&indent_str);
            out.push_str(&fence_str);
            out.push_str(info);
            out.push_str(trailing);
        } else {
            out.push_str(line);
        }
    }
    out
}
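The fence-upgrade rule applied above reduces to a single computation: the rewritten outer fence must be strictly longer than the longest fence marker found inside the block. A sketch of that rule in isolation (helper name is illustrative, not part of the renderer):

```rust
/// Sketch of the fence-length rule used by normalize_nested_fences above:
/// keep the opener's length when it already exceeds every inner marker,
/// otherwise emit one more backtick/tilde than the longest inner marker so
/// the inner fences become ordinary content.
fn required_fence_len(opener_len: usize, inner_max: usize) -> usize {
    if opener_len > inner_max {
        opener_len
    } else {
        inner_max + 1
    }
}
```

This mirrors the CommonMark closing-fence rule: a fence only closes a block when it uses the same character and is at least as long as the opener, so `inner_max + 1` is the minimal safe upgrade.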
fn find_stream_safe_boundary(markdown: &str) -> Option<usize> {
    let mut in_fence = false;
    let mut open_fence: Option<FenceMarker> = None;
    let mut last_boundary = None;

    for (offset, line) in markdown.split_inclusive('\n').scan(0usize, |cursor, line| {
@@ -643,20 +816,21 @@ fn find_stream_safe_boundary(markdown: &str) -> Option<usize> {
         *cursor += line.len();
         Some((start, line))
     }) {
-        let trimmed = line.trim_start();
-        if trimmed.starts_with("```") || trimmed.starts_with("~~~") {
-            in_fence = !in_fence;
-            if !in_fence {
+        let line_without_newline = line.trim_end_matches('\n');
+        if let Some(opener) = open_fence {
+            if line_closes_fence(line_without_newline, opener) {
+                open_fence = None;
                 last_boundary = Some(offset + line.len());
             }
             continue;
         }

-        if in_fence {
+        if let Some(opener) = parse_fence_opener(line_without_newline) {
+            open_fence = Some(opener);
             continue;
         }

-        if trimmed.is_empty() {
+        if line_without_newline.trim().is_empty() {
             last_boundary = Some(offset + line.len());
         }
     }
@@ -664,6 +838,46 @@ fn find_stream_safe_boundary(markdown: &str) -> Option<usize> {
    last_boundary
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct FenceMarker {
    character: char,
    length: usize,
}

fn parse_fence_opener(line: &str) -> Option<FenceMarker> {
    let indent = line.chars().take_while(|c| *c == ' ').count();
    if indent > 3 {
        return None;
    }
    let rest = &line[indent..];
    let character = rest.chars().next()?;
    if character != '`' && character != '~' {
        return None;
    }
    let length = rest.chars().take_while(|c| *c == character).count();
    if length < 3 {
        return None;
    }
    let info_string = &rest[length..];
    if character == '`' && info_string.contains('`') {
        return None;
    }
    Some(FenceMarker { character, length })
}

fn line_closes_fence(line: &str, opener: FenceMarker) -> bool {
    let indent = line.chars().take_while(|c| *c == ' ').count();
    if indent > 3 {
        return false;
    }
    let rest = &line[indent..];
    let length = rest.chars().take_while(|c| *c == opener.character).count();
    if length < opener.length {
        return false;
    }
    rest[length..].chars().all(|c| c == ' ' || c == '\t')
}

fn visible_width(input: &str) -> usize {
    strip_ansi(input).chars().count()
}
@@ -778,6 +992,60 @@ mod tests {
        assert!(strip_ansi(&code).contains("fn main()"));
    }

    #[test]
    fn streaming_state_holds_outer_fence_with_nested_inner_fence() {
        let renderer = TerminalRenderer::new();
        let mut state = MarkdownStreamState::default();

        assert_eq!(
            state.push(&renderer, "````markdown\n```rust\nfn inner() {}\n"),
            None,
            "inner triple backticks must not close the outer four-backtick fence"
        );
        assert_eq!(
            state.push(&renderer, "```\n"),
            None,
            "closing the inner fence must not flush the outer fence"
        );
        let flushed = state
            .push(&renderer, "````\n")
            .expect("closing the outer four-backtick fence flushes the buffered block");
        let plain_text = strip_ansi(&flushed);
        assert!(plain_text.contains("fn inner()"));
        assert!(plain_text.contains("```rust"));
    }

    #[test]
    fn streaming_state_distinguishes_backtick_and_tilde_fences() {
        let renderer = TerminalRenderer::new();
        let mut state = MarkdownStreamState::default();

        assert_eq!(state.push(&renderer, "~~~text\n"), None);
        assert_eq!(
            state.push(&renderer, "```\nstill inside tilde fence\n"),
            None,
            "a backtick fence cannot close a tilde-opened fence"
        );
        assert_eq!(state.push(&renderer, "```\n"), None);
        let flushed = state
            .push(&renderer, "~~~\n")
            .expect("matching tilde marker closes the fence");
        let plain_text = strip_ansi(&flushed);
        assert!(plain_text.contains("still inside tilde fence"));
    }

    #[test]
    fn renders_nested_fenced_code_block_preserves_inner_markers() {
        let terminal_renderer = TerminalRenderer::new();
        let markdown_output =
            terminal_renderer.markdown_to_ansi("````markdown\n```rust\nfn nested() {}\n```\n````");
        let plain_text = strip_ansi(&markdown_output);

        assert!(plain_text.contains("╭─ markdown"));
        assert!(plain_text.contains("```rust"));
        assert!(plain_text.contains("fn nested()"));
    }

    #[test]
    fn spinner_advances_frames() {
        let terminal_renderer = TerminalRenderer::new();
rust/crates/rusty-claude-cli/tests/compact_output.rs
Normal file
159
rust/crates/rusty-claude-cli/tests/compact_output.rs
Normal file
@@ -0,0 +1,159 @@
|
use std::fs;
use std::path::PathBuf;
use std::process::{Command, Output};
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{SystemTime, UNIX_EPOCH};

use mock_anthropic_service::{MockAnthropicService, SCENARIO_PREFIX};

static TEMP_COUNTER: AtomicU64 = AtomicU64::new(0);

#[test]
fn compact_flag_prints_only_final_assistant_text_without_tool_call_details() {
    // given a workspace pointed at the mock Anthropic service and a fixture file
    // that the read_file_roundtrip scenario will fetch through a tool call
    let runtime = tokio::runtime::Runtime::new().expect("tokio runtime should build");
    let server = runtime
        .block_on(MockAnthropicService::spawn())
        .expect("mock service should start");
    let base_url = server.base_url();

    let workspace = unique_temp_dir("compact-read-file");
    let config_home = workspace.join("config-home");
    let home = workspace.join("home");
    fs::create_dir_all(&workspace).expect("workspace should exist");
    fs::create_dir_all(&config_home).expect("config home should exist");
    fs::create_dir_all(&home).expect("home should exist");
    fs::write(workspace.join("fixture.txt"), "alpha parity line\n").expect("fixture should write");

    // when we run claw in compact text mode against a tool-using scenario
    let prompt = format!("{SCENARIO_PREFIX}read_file_roundtrip");
    let output = run_claw(
        &workspace,
        &config_home,
        &home,
        &base_url,
        &[
            "--model",
            "sonnet",
            "--permission-mode",
            "read-only",
            "--allowedTools",
            "read_file",
            "--compact",
            &prompt,
        ],
    );

    // then the command exits successfully and stdout contains exactly the final
    // assistant text with no tool call IDs, JSON envelopes, or spinner output
    assert!(
        output.status.success(),
        "compact run should succeed\nstdout:\n{}\n\nstderr:\n{}",
        String::from_utf8_lossy(&output.stdout),
        String::from_utf8_lossy(&output.stderr),
    );
    let stdout = String::from_utf8(output.stdout).expect("stdout should be utf8");
    let trimmed = stdout.trim_end_matches('\n');
    assert_eq!(
        trimmed, "read_file roundtrip complete: alpha parity line",
        "compact stdout should contain only the final assistant text"
    );
    assert!(
        !stdout.contains("toolu_"),
        "compact stdout must not leak tool_use_id ({stdout:?})"
    );
    assert!(
        !stdout.contains("\"tool_uses\""),
        "compact stdout must not leak json envelopes ({stdout:?})"
    );
    assert!(
        !stdout.contains("Thinking"),
        "compact stdout must not include the spinner banner ({stdout:?})"
    );

    fs::remove_dir_all(&workspace).expect("workspace cleanup should succeed");
}

#[test]
fn compact_flag_streaming_text_only_emits_final_message_text() {
    // given a workspace pointed at the mock Anthropic service running the
    // streaming_text scenario which only emits a single assistant text block
    let runtime = tokio::runtime::Runtime::new().expect("tokio runtime should build");
    let server = runtime
        .block_on(MockAnthropicService::spawn())
        .expect("mock service should start");
    let base_url = server.base_url();

    let workspace = unique_temp_dir("compact-streaming-text");
    let config_home = workspace.join("config-home");
    let home = workspace.join("home");
    fs::create_dir_all(&workspace).expect("workspace should exist");
    fs::create_dir_all(&config_home).expect("config home should exist");
    fs::create_dir_all(&home).expect("home should exist");

    // when we invoke claw with --compact for the streaming text scenario
    let prompt = format!("{SCENARIO_PREFIX}streaming_text");
    let output = run_claw(
        &workspace,
        &config_home,
        &home,
        &base_url,
        &[
            "--model",
            "sonnet",
            "--permission-mode",
            "read-only",
            "--compact",
            &prompt,
        ],
    );

    // then stdout should be exactly the assistant text followed by a newline
    assert!(
        output.status.success(),
        "compact streaming run should succeed\nstdout:\n{}\n\nstderr:\n{}",
        String::from_utf8_lossy(&output.stdout),
        String::from_utf8_lossy(&output.stderr),
    );
    let stdout = String::from_utf8(output.stdout).expect("stdout should be utf8");
    assert_eq!(
        stdout, "Mock streaming says hello from the parity harness.\n",
        "compact streaming stdout should contain only the final assistant text"
    );

    fs::remove_dir_all(&workspace).expect("workspace cleanup should succeed");
}

fn run_claw(
    cwd: &std::path::Path,
    config_home: &std::path::Path,
    home: &std::path::Path,
    base_url: &str,
    args: &[&str],
) -> Output {
    let mut command = Command::new(env!("CARGO_BIN_EXE_claw"));
    command
        .current_dir(cwd)
        .env_clear()
        .env("ANTHROPIC_API_KEY", "test-compact-key")
        .env("ANTHROPIC_BASE_URL", base_url)
        .env("CLAW_CONFIG_HOME", config_home)
        .env("HOME", home)
        .env("NO_COLOR", "1")
        .env("PATH", "/usr/bin:/bin")
        .args(args);
    command.output().expect("claw should launch")
}

fn unique_temp_dir(label: &str) -> PathBuf {
    let millis = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .expect("clock should be after epoch")
        .as_millis();
    let counter = TEMP_COUNTER.fetch_add(1, Ordering::Relaxed);
    std::env::temp_dir().join(format!(
        "claw-compact-{label}-{}-{millis}-{counter}",
        std::process::id()
    ))
}
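The part of `run_claw` worth reusing elsewhere is the hermetic environment: `env_clear()` wipes everything inherited from the host, then only an explicit allowlist is added back, so stray API keys, HOME overrides, or proxy settings cannot leak into the test. A minimal generic sketch of that pattern (not the test's actual helper):

```rust
use std::process::Command;

/// Sketch of the hermetic-environment pattern from run_claw above: clear all
/// inherited variables, then re-add only what the child process needs. The
/// specific variables here are illustrative.
fn hermetic(program: &str) -> Command {
    let mut command = Command::new(program);
    command
        .env_clear()
        .env("PATH", "/usr/bin:/bin")
        .env("NO_COLOR", "1");
    command
}
```

Spawning `hermetic("/usr/bin/env")` prints only the allowlisted variables, which makes environment leaks easy to assert against in tests.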
@@ -183,17 +183,24 @@ fn clean_env_cli_reaches_mock_anthropic_service_across_scripted_parity_scenarios
     }

     let captured = runtime.block_on(server.captured_requests());
-    assert_eq!(
-        captured.len(),
-        21,
-        "twelve scenarios should produce twenty-one requests"
-    );
-    assert!(captured
-        .iter()
-        .all(|request| request.path == "/v1/messages"));
-    assert!(captured.iter().all(|request| request.stream));
+    // After `be561bf` added count_tokens preflight, each turn sends an
+    // extra POST to `/v1/messages/count_tokens` before the messages POST.
+    // The original count (21) assumed messages-only requests. We now
+    // filter to `/v1/messages` and verify that subset matches the original
+    // scenario expectation.
+    let messages_only: Vec<_> = captured
+        .iter()
+        .filter(|r| r.path == "/v1/messages")
+        .collect();
+    assert_eq!(
+        messages_only.len(),
+        21,
+        "twelve scenarios should produce twenty-one /v1/messages requests (total captured: {}, includes count_tokens)",
+        captured.len()
+    );
+    assert!(messages_only.iter().all(|request| request.stream));

-    let scenarios = captured
+    let scenarios = messages_only
         .iter()
         .map(|request| request.scenario.as_str())
         .collect::<Vec<_>>();
@@ -8,6 +8,7 @@ publish.workspace = true
[dependencies]
api = { path = "../api" }
commands = { path = "../commands" }
flate2 = "1"
plugins = { path = "../plugins" }
runtime = { path = "../runtime" }
reqwest = { version = "0.12", default-features = false, features = ["blocking", "rustls-tls"] }
File diff suppressed because it is too large
548  rust/crates/tools/src/pdf_extract.rs  (Normal file)
@@ -0,0 +1,548 @@
//! Minimal PDF text extraction.
//!
//! Reads a PDF file, locates `/Contents` stream objects, decompresses with
//! flate2 when the stream uses `/FlateDecode`, and extracts text operators
//! found between `BT` / `ET` markers.

use std::io::Read as _;
use std::path::Path;

/// Extract all readable text from a PDF file.
///
/// Returns the concatenated text found inside BT/ET operators across all
/// content streams. Non-text pages or encrypted PDFs yield an empty string
/// rather than an error.
pub fn extract_text(path: &Path) -> Result<String, String> {
    let data = std::fs::read(path).map_err(|e| format!("failed to read PDF: {e}"))?;
    Ok(extract_text_from_bytes(&data))
}

/// Core extraction from raw PDF bytes — useful for testing without touching the
/// filesystem.
pub(crate) fn extract_text_from_bytes(data: &[u8]) -> String {
    let mut all_text = String::new();
    let mut offset = 0;

    while offset < data.len() {
        let Some(stream_start) = find_subsequence(&data[offset..], b"stream") else {
            break;
        };
        let abs_start = offset + stream_start;

        // Determine the byte offset right after "stream\r\n" or "stream\n".
        let content_start = skip_stream_eol(data, abs_start + b"stream".len());

        let Some(end_rel) = find_subsequence(&data[content_start..], b"endstream") else {
            break;
        };
        let content_end = content_start + end_rel;

        // Look backwards from "stream" for a FlateDecode hint in the object
        // dictionary. We scan at most 512 bytes before the stream keyword.
        let dict_window_start = abs_start.saturating_sub(512);
        let dict_window = &data[dict_window_start..abs_start];
        let is_flate = find_subsequence(dict_window, b"FlateDecode").is_some();

        // Only process streams whose parent dictionary references /Contents or
        // looks like a page content stream (contains /Length). We intentionally
        // keep this loose to cover both inline and referenced content streams.
        let raw = &data[content_start..content_end];
        let decompressed;
        let stream_bytes: &[u8] = if is_flate {
            if let Ok(buf) = inflate(raw) {
                decompressed = buf;
                &decompressed
            } else {
                offset = content_end;
                continue;
            }
        } else {
            raw
        };

        let text = extract_bt_et_text(stream_bytes);
        if !text.is_empty() {
            if !all_text.is_empty() {
                all_text.push('\n');
            }
            all_text.push_str(&text);
        }

        offset = content_end;
    }

    all_text
}

/// Inflate (zlib / deflate) compressed data via `flate2`.
fn inflate(data: &[u8]) -> Result<Vec<u8>, String> {
    let mut decoder = flate2::read::ZlibDecoder::new(data);
    let mut buf = Vec::new();
    decoder
        .read_to_end(&mut buf)
        .map_err(|e| format!("flate2 inflate error: {e}"))?;
    Ok(buf)
}
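`find_subsequence` is called throughout the scan above but its definition falls outside this excerpt. The usual windows-based implementation looks like the sketch below; this is an assumption about the helper, and the real one in the file may differ.

```rust
/// Plausible sketch of the byte-pattern search used by the stream scanning
/// above: return the offset of the first occurrence of `needle` within
/// `haystack`, or None when it never appears.
fn find_subsequence(haystack: &[u8], needle: &[u8]) -> Option<usize> {
    if needle.is_empty() {
        // An empty needle trivially matches at the start.
        return Some(0);
    }
    haystack
        .windows(needle.len())
        .position(|window| window == needle)
}
```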

/// Extract text from PDF content-stream operators between BT and ET markers.
///
/// Handles the common text-showing operators:
/// - `Tj` — show a string
/// - `TJ` — show an array of strings/numbers
/// - `'` — move to next line and show string
/// - `"` — set spacing, move to next line and show string
fn extract_bt_et_text(stream: &[u8]) -> String {
    let text = String::from_utf8_lossy(stream);
    let mut result = String::new();
    let mut in_bt = false;

    for line in text.lines() {
        let trimmed = line.trim();
        if trimmed == "BT" {
            in_bt = true;
            continue;
        }
        if trimmed == "ET" {
            in_bt = false;
            continue;
        }
        if !in_bt {
            continue;
        }

        // Tj operator: (text) Tj
        if trimmed.ends_with("Tj") {
            if let Some(s) = extract_parenthesized_string(trimmed) {
                if !result.is_empty() && !result.ends_with('\n') {
                    result.push(' ');
                }
                result.push_str(&s);
            }
        }
        // TJ operator: [ (text) 123 (text) ] TJ
        else if trimmed.ends_with("TJ") {
            let extracted = extract_tj_array(trimmed);
            if !extracted.is_empty() {
                if !result.is_empty() && !result.ends_with('\n') {
                    result.push(' ');
                }
                result.push_str(&extracted);
            }
        }
        // ' operator: (text) ' and " operator: aw ac (text) "
        else if is_newline_show_operator(trimmed) {
            if let Some(s) = extract_parenthesized_string(trimmed) {
                if !result.is_empty() {
                    result.push('\n');
                }
                result.push_str(&s);
            }
        }
    }

    result
}
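The BT/ET scan above is line-oriented: it toggles an in-text flag on the markers and only inspects operator lines while the flag is set. A minimal standalone sketch of the same idea, simplified to the `Tj` operator only (names here are illustrative, not part of this file, and the naive paren slicing skips the escape handling the real extractor does):

```rust
// Simplified sketch: collect the (…) payload of each "(text) Tj" line
// that appears between BT and ET markers.
fn sketch_bt_et(stream: &str) -> Vec<String> {
    let mut out = Vec::new();
    let mut in_bt = false;
    for line in stream.lines() {
        let t = line.trim();
        match t {
            "BT" => in_bt = true,
            "ET" => in_bt = false,
            _ if in_bt && t.ends_with("Tj") => {
                // Naive paren slicing; no escape sequences handled here.
                if let (Some(a), Some(b)) = (t.find('('), t.rfind(')')) {
                    if a < b {
                        out.push(t[a + 1..b].to_string());
                    }
                }
            }
            _ => {}
        }
    }
    out
}

fn main() {
    // Text outside BT/ET is ignored, matching the real extractor's gating.
    let stream = "BT\n/F1 12 Tf\n(Hello) Tj\nET\n(outside) Tj";
    assert_eq!(sketch_bt_et(stream), vec!["Hello".to_string()]);
    println!("ok");
}
```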

/// Returns `true` when `trimmed` looks like a `'` or `"` text-show operator.
fn is_newline_show_operator(trimmed: &str) -> bool {
    (trimmed.ends_with('\'') && trimmed.len() > 1)
        || (trimmed.ends_with('"') && trimmed.contains('('))
}

/// Pull the text from the first `(…)` group, handling escaped parens and
/// common PDF escape sequences.
fn extract_parenthesized_string(input: &str) -> Option<String> {
    let open = input.find('(')?;
    let bytes = input.as_bytes();
    let mut depth = 0;
    let mut result = String::new();
    let mut i = open;

    while i < bytes.len() {
        match bytes[i] {
            b'(' => {
                if depth > 0 {
                    result.push('(');
                }
                depth += 1;
            }
            b')' => {
                depth -= 1;
                if depth == 0 {
                    return Some(result);
                }
                result.push(')');
            }
            b'\\' if i + 1 < bytes.len() => {
                i += 1;
                match bytes[i] {
                    b'n' => result.push('\n'),
                    b'r' => result.push('\r'),
                    b't' => result.push('\t'),
                    b'\\' => result.push('\\'),
                    b'(' => result.push('('),
                    b')' => result.push(')'),
                    // Octal sequences — up to 3 digits.
                    d @ b'0'..=b'7' => {
                        let mut octal = u32::from(d - b'0');
                        for _ in 0..2 {
                            if i + 1 < bytes.len()
                                && bytes[i + 1].is_ascii_digit()
                                && bytes[i + 1] <= b'7'
                            {
                                i += 1;
                                octal = octal * 8 + u32::from(bytes[i] - b'0');
                            } else {
                                break;
                            }
                        }
                        if let Some(ch) = char::from_u32(octal) {
                            result.push(ch);
                        }
                    }
                    other => result.push(char::from(other)),
                }
            }
            ch => result.push(char::from(ch)),
        }
        i += 1;
    }

    None // unbalanced
}
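PDF literal strings allow octal escapes of up to three digits, so `\101` decodes to `A` (0o101 = 65). A standalone restatement of just that rule, with illustrative names that are not part of this file:

```rust
// Decode one octal escape of up to three digits starting at `bytes[i]`,
// returning the decoded char and the index just past the last digit consumed.
fn decode_octal(bytes: &[u8], mut i: usize) -> (char, usize) {
    let mut value = 0u32;
    let mut digits = 0;
    // Accept at most three octal digits, stopping early at any non-digit.
    while i < bytes.len() && digits < 3 && (b'0'..=b'7').contains(&bytes[i]) {
        value = value * 8 + u32::from(bytes[i] - b'0');
        i += 1;
        digits += 1;
    }
    (char::from_u32(value).unwrap_or('\u{FFFD}'), i)
}

fn main() {
    assert_eq!(decode_octal(b"101", 0), ('A', 3)); // 0o101 == 65 == 'A'
    assert_eq!(decode_octal(b"7) x", 0), ('\u{7}', 1)); // stops at the ')'
    println!("ok");
}
```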

/// Extract concatenated strings from a TJ array like `[ (Hello) -120 (World) ] TJ`.
fn extract_tj_array(input: &str) -> String {
    let mut result = String::new();
    let Some(bracket_start) = input.find('[') else {
        return result;
    };
    let Some(bracket_end) = input.rfind(']') else {
        return result;
    };
    let inner = &input[bracket_start + 1..bracket_end];

    let mut i = 0;
    let bytes = inner.as_bytes();
    while i < bytes.len() {
        if bytes[i] == b'(' {
            // Reconstruct the parenthesized string and extract it.
            if let Some(s) = extract_parenthesized_string(&inner[i..]) {
                result.push_str(&s);
                // Skip past the closing paren.
                let mut depth = 0u32;
                for &b in &bytes[i..] {
                    i += 1;
                    if b == b'(' {
                        depth += 1;
                    } else if b == b')' {
                        depth -= 1;
                        if depth == 0 {
                            break;
                        }
                    }
                }
                continue;
            }
        }
        i += 1;
    }

    result
}
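A `TJ` array interleaves strings with kerning adjustments; only the strings matter for extraction, so the numbers between the groups are skipped. A simplified standalone sketch of that walk (illustrative names, no escape or nesting handling):

```rust
// Concatenate the (…) groups of a TJ array, ignoring the kerning numbers.
fn sketch_tj_strings(op: &str) -> String {
    let mut out = String::new();
    let (Some(a), Some(b)) = (op.find('['), op.rfind(']')) else {
        return out;
    };
    let mut rest = &op[a + 1..b];
    // Walk "(…)" groups left to right; everything between them is skipped.
    while let Some(open) = rest.find('(') {
        let Some(close) = rest[open..].find(')') else { break };
        out.push_str(&rest[open + 1..open + close]);
        rest = &rest[open + close + 1..];
    }
    out
}

fn main() {
    assert_eq!(sketch_tj_strings("[ (Hello) -120 ( World) ] TJ"), "Hello World");
    println!("ok");
}
```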

/// Skip past the end-of-line marker that immediately follows the `stream`
/// keyword. Per the PDF spec this is either `\r\n` or `\n`.
fn skip_stream_eol(data: &[u8], pos: usize) -> usize {
    if pos < data.len() && data[pos] == b'\r' {
        if pos + 1 < data.len() && data[pos + 1] == b'\n' {
            return pos + 2;
        }
        return pos + 1;
    }
    if pos < data.len() && data[pos] == b'\n' {
        return pos + 1;
    }
    pos
}
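The same EOL rule can be stated as a `match` on the byte at `pos`; a behavior-equivalent standalone sketch (a bare `\r` is tolerated here too, matching the function above):

```rust
// Consume "\r\n", "\n", or a lone "\r" immediately after `pos`; otherwise
// return `pos` unchanged. `slice::get` makes the bounds checks implicit.
fn skip_stream_eol(data: &[u8], pos: usize) -> usize {
    match data.get(pos) {
        Some(b'\r') if data.get(pos + 1) == Some(&b'\n') => pos + 2,
        Some(b'\r') | Some(b'\n') => pos + 1,
        _ => pos,
    }
}

fn main() {
    assert_eq!(skip_stream_eol(b"\r\nDATA", 0), 2);
    assert_eq!(skip_stream_eol(b"\nDATA", 0), 1);
    assert_eq!(skip_stream_eol(b"DATA", 0), 0);
    println!("ok");
}
```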

/// Simple byte-subsequence search.
fn find_subsequence(haystack: &[u8], needle: &[u8]) -> Option<usize> {
    haystack
        .windows(needle.len())
        .position(|window| window == needle)
}
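`windows(n)` yields every length-`n` slice of the haystack, so `position` returns the offset of the first match. A self-contained usage sketch (note that `windows(0)` panics, so callers here always pass non-empty needles; a needle longer than the haystack yields no windows and thus `None`):

```rust
// Naive O(n·m) subsequence search via sliding windows.
fn find_subsequence(haystack: &[u8], needle: &[u8]) -> Option<usize> {
    haystack
        .windows(needle.len())
        .position(|window| window == needle)
}

fn main() {
    assert_eq!(find_subsequence(b"stream\ndata", b"stream"), Some(0));
    assert_eq!(find_subsequence(b"abc FlateDecode", b"FlateDecode"), Some(4));
    assert_eq!(find_subsequence(b"abc", b"zz"), None);
    println!("ok");
}
```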

/// Check if a user-supplied path looks like a PDF file reference.
#[must_use]
pub fn looks_like_pdf_path(text: &str) -> Option<&str> {
    for token in text.split_whitespace() {
        let cleaned = token.trim_matches(|c: char| c == '\'' || c == '"' || c == '`');
        if let Some(dot_pos) = cleaned.rfind('.') {
            if cleaned[dot_pos + 1..].eq_ignore_ascii_case("pdf") && dot_pos > 0 {
                return Some(cleaned);
            }
        }
    }
    None
}
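The heuristic above is: split on whitespace, strip surrounding quote characters, and accept the first token whose extension is `.pdf` in any case (requiring a non-empty stem). The same logic can be written iterator-style; a behavior-equivalent sketch:

```rust
// Iterator-style restatement of the ".pdf token" heuristic.
fn looks_like_pdf_path(text: &str) -> Option<&str> {
    text.split_whitespace()
        .map(|tok| tok.trim_matches(|c: char| c == '\'' || c == '"' || c == '`'))
        .find(|tok| {
            tok.rfind('.')
                .is_some_and(|dot| dot > 0 && tok[dot + 1..].eq_ignore_ascii_case("pdf"))
        })
}

fn main() {
    assert_eq!(looks_like_pdf_path("read '/tmp/a.pdf' now"), Some("/tmp/a.pdf"));
    assert_eq!(looks_like_pdf_path("file.PDF"), Some("file.PDF"));
    assert_eq!(looks_like_pdf_path(".pdf alone"), None); // empty stem rejected
    println!("ok");
}
```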

/// Auto-extract text from a PDF path mentioned in a user prompt.
///
/// Returns `Some((path, extracted_text))` when a `.pdf` path is detected and
/// the file exists, otherwise `None`.
#[must_use]
pub fn maybe_extract_pdf_from_prompt(prompt: &str) -> Option<(String, String)> {
    let pdf_path = looks_like_pdf_path(prompt)?;
    let path = Path::new(pdf_path);
    if !path.exists() {
        return None;
    }
    let text = extract_text(path).ok()?;
    if text.is_empty() {
        return None;
    }
    Some((pdf_path.to_string(), text))
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Build a minimal valid PDF with a single page containing uncompressed
    /// text. This is the smallest PDF structure that exercises the BT/ET
    /// extraction path.
    fn build_simple_pdf(text: &str) -> Vec<u8> {
        let content_stream = format!("BT\n/F1 12 Tf\n({text}) Tj\nET");
        let stream_bytes = content_stream.as_bytes();
        let mut pdf = Vec::new();

        // Header
        pdf.extend_from_slice(b"%PDF-1.4\n");

        // Object 1 — Catalog
        let obj1_offset = pdf.len();
        pdf.extend_from_slice(b"1 0 obj\n<< /Type /Catalog /Pages 2 0 R >>\nendobj\n");

        // Object 2 — Pages
        let obj2_offset = pdf.len();
        pdf.extend_from_slice(b"2 0 obj\n<< /Type /Pages /Kids [3 0 R] /Count 1 >>\nendobj\n");

        // Object 3 — Page
        let obj3_offset = pdf.len();
        pdf.extend_from_slice(
            b"3 0 obj\n<< /Type /Page /Parent 2 0 R /Contents 4 0 R >>\nendobj\n",
        );

        // Object 4 — Content stream (uncompressed)
        let obj4_offset = pdf.len();
        let length = stream_bytes.len();
        let header = format!("4 0 obj\n<< /Length {length} >>\nstream\n");
        pdf.extend_from_slice(header.as_bytes());
        pdf.extend_from_slice(stream_bytes);
        pdf.extend_from_slice(b"\nendstream\nendobj\n");

        // Cross-reference table
        let xref_offset = pdf.len();
        pdf.extend_from_slice(b"xref\n0 5\n");
        pdf.extend_from_slice(b"0000000000 65535 f \n");
        pdf.extend_from_slice(format!("{obj1_offset:010} 00000 n \n").as_bytes());
        pdf.extend_from_slice(format!("{obj2_offset:010} 00000 n \n").as_bytes());
        pdf.extend_from_slice(format!("{obj3_offset:010} 00000 n \n").as_bytes());
        pdf.extend_from_slice(format!("{obj4_offset:010} 00000 n \n").as_bytes());

        // Trailer
        pdf.extend_from_slice(b"trailer\n<< /Size 5 /Root 1 0 R >>\n");
        pdf.extend_from_slice(format!("startxref\n{xref_offset}\n%%EOF\n").as_bytes());

        pdf
    }

    /// Build a minimal PDF with flate-compressed content stream.
    fn build_flate_pdf(text: &str) -> Vec<u8> {
        use flate2::write::ZlibEncoder;
        use flate2::Compression;
        use std::io::Write as _;

        let content_stream = format!("BT\n/F1 12 Tf\n({text}) Tj\nET");
        let mut encoder = ZlibEncoder::new(Vec::new(), Compression::default());
        encoder
            .write_all(content_stream.as_bytes())
            .expect("compress");
        let compressed = encoder.finish().expect("finish");

        let mut pdf = Vec::new();
        pdf.extend_from_slice(b"%PDF-1.4\n");

        let obj1_offset = pdf.len();
        pdf.extend_from_slice(b"1 0 obj\n<< /Type /Catalog /Pages 2 0 R >>\nendobj\n");

        let obj2_offset = pdf.len();
        pdf.extend_from_slice(b"2 0 obj\n<< /Type /Pages /Kids [3 0 R] /Count 1 >>\nendobj\n");

        let obj3_offset = pdf.len();
        pdf.extend_from_slice(
            b"3 0 obj\n<< /Type /Page /Parent 2 0 R /Contents 4 0 R >>\nendobj\n",
        );

        let obj4_offset = pdf.len();
        let length = compressed.len();
        let header = format!("4 0 obj\n<< /Length {length} /Filter /FlateDecode >>\nstream\n");
        pdf.extend_from_slice(header.as_bytes());
        pdf.extend_from_slice(&compressed);
        pdf.extend_from_slice(b"\nendstream\nendobj\n");

        let xref_offset = pdf.len();
        pdf.extend_from_slice(b"xref\n0 5\n");
        pdf.extend_from_slice(b"0000000000 65535 f \n");
        pdf.extend_from_slice(format!("{obj1_offset:010} 00000 n \n").as_bytes());
        pdf.extend_from_slice(format!("{obj2_offset:010} 00000 n \n").as_bytes());
        pdf.extend_from_slice(format!("{obj3_offset:010} 00000 n \n").as_bytes());
        pdf.extend_from_slice(format!("{obj4_offset:010} 00000 n \n").as_bytes());

        pdf.extend_from_slice(b"trailer\n<< /Size 5 /Root 1 0 R >>\n");
        pdf.extend_from_slice(format!("startxref\n{xref_offset}\n%%EOF\n").as_bytes());

        pdf
    }

    #[test]
    fn extracts_uncompressed_text_from_minimal_pdf() {
        // given
        let pdf_bytes = build_simple_pdf("Hello World");

        // when
        let text = extract_text_from_bytes(&pdf_bytes);

        // then
        assert_eq!(text, "Hello World");
    }

    #[test]
    fn extracts_text_from_flate_compressed_stream() {
        // given
        let pdf_bytes = build_flate_pdf("Compressed PDF Text");

        // when
        let text = extract_text_from_bytes(&pdf_bytes);

        // then
        assert_eq!(text, "Compressed PDF Text");
    }

    #[test]
    fn handles_tj_array_operator() {
        // given
        let stream = b"BT\n/F1 12 Tf\n[ (Hello) -120 ( World) ] TJ\nET";
        // Build a raw PDF with TJ array operator instead of simple Tj.
        let content_stream = std::str::from_utf8(stream).unwrap();
        let raw = format!(
            "%PDF-1.4\n1 0 obj\n<< /Type /Catalog >>\nendobj\n\
             2 0 obj\n<< /Length {} >>\nstream\n{}\nendstream\nendobj\n%%EOF\n",
            content_stream.len(),
            content_stream
        );
        let pdf_bytes = raw.into_bytes();

        // when
        let text = extract_text_from_bytes(&pdf_bytes);

        // then
        assert_eq!(text, "Hello World");
    }

    #[test]
    fn handles_escaped_parentheses() {
        // given
        let content = b"BT\n(Hello \\(World\\)) Tj\nET";
        let raw = format!(
            "%PDF-1.4\n1 0 obj\n<< /Length {} >>\nstream\n",
            content.len()
        );
        let mut pdf_bytes = raw.into_bytes();
        pdf_bytes.extend_from_slice(content);
        pdf_bytes.extend_from_slice(b"\nendstream\nendobj\n%%EOF\n");

        // when
        let text = extract_text_from_bytes(&pdf_bytes);

        // then
        assert_eq!(text, "Hello (World)");
    }

    #[test]
    fn returns_empty_for_non_pdf_data() {
        // given
        let data = b"This is not a PDF file at all";

        // when
        let text = extract_text_from_bytes(data);

        // then
        assert!(text.is_empty());
    }

    #[test]
    fn extracts_text_from_file_on_disk() {
        // given
        let pdf_bytes = build_simple_pdf("Disk Test");
        let dir = std::env::temp_dir().join("clawd-pdf-extract-test");
        std::fs::create_dir_all(&dir).unwrap();
        let pdf_path = dir.join("test.pdf");
        std::fs::write(&pdf_path, &pdf_bytes).unwrap();

        // when
        let text = extract_text(&pdf_path).unwrap();

        // then
        assert_eq!(text, "Disk Test");

        // cleanup
        let _ = std::fs::remove_dir_all(&dir);
    }

    #[test]
    fn looks_like_pdf_path_detects_pdf_references() {
        // given / when / then
        assert_eq!(
            looks_like_pdf_path("Please read /tmp/report.pdf"),
            Some("/tmp/report.pdf")
        );
        assert_eq!(looks_like_pdf_path("Check file.PDF now"), Some("file.PDF"));
        assert_eq!(looks_like_pdf_path("no pdf here"), None);
    }

    #[test]
    fn maybe_extract_pdf_from_prompt_returns_none_for_missing_file() {
        // given
        let prompt = "Read /tmp/nonexistent-abc123.pdf please";

        // when
        let result = maybe_extract_pdf_from_prompt(prompt);

        // then
        assert!(result.is_none());
    }

    #[test]
    fn maybe_extract_pdf_from_prompt_extracts_existing_file() {
        // given
        let pdf_bytes = build_simple_pdf("Auto Extracted");
        let dir = std::env::temp_dir().join("clawd-pdf-auto-extract-test");
        std::fs::create_dir_all(&dir).unwrap();
        let pdf_path = dir.join("auto.pdf");
        std::fs::write(&pdf_path, &pdf_bytes).unwrap();
        let prompt = format!("Summarize {}", pdf_path.display());

        // when
        let result = maybe_extract_pdf_from_prompt(&prompt);

        // then
        let (path, text) = result.expect("should extract");
        assert_eq!(path, pdf_path.display().to_string());
        assert_eq!(text, "Auto Extracted");

        // cleanup
        let _ = std::fs::remove_dir_all(&dir);
    }
}