feat: HTTP/SSE server crate with axum (session management, event streaming)

This commit is contained in:
Sisyphus
2026-04-01 21:26:06 +09:00
parent 12e935b30f
commit 48e36d422a
3 changed files with 623 additions and 2 deletions

163
rust/Cargo.lock generated
View File

@@ -28,12 +28,86 @@ dependencies = [
"tokio", "tokio",
] ]
[[package]]
name = "async-stream"
version = "0.3.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0b5a71a6f37880a80d1d7f19efd781e4b5de42c88f0722cc13bcb6cc2cfe8476"
dependencies = [
"async-stream-impl",
"futures-core",
"pin-project-lite",
]
[[package]]
name = "async-stream-impl"
version = "0.3.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c7c24de15d275a1ecfd47a380fb4d5ec9bfe0933f309ed5e705b775596a3574d"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]] [[package]]
name = "atomic-waker" name = "atomic-waker"
version = "1.1.2" version = "1.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0" checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0"
[[package]]
name = "axum"
version = "0.8.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8b52af3cb4058c895d37317bb27508dccc8e5f2d39454016b297bf4a400597b8"
dependencies = [
"axum-core",
"bytes",
"form_urlencoded",
"futures-util",
"http",
"http-body",
"http-body-util",
"hyper",
"hyper-util",
"itoa",
"matchit",
"memchr",
"mime",
"percent-encoding",
"pin-project-lite",
"serde_core",
"serde_json",
"serde_path_to_error",
"serde_urlencoded",
"sync_wrapper",
"tokio",
"tower",
"tower-layer",
"tower-service",
"tracing",
]
[[package]]
name = "axum-core"
version = "0.5.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "08c78f31d7b1291f7ee735c1c6780ccde7785daae9a9206026862dab7d8792d1"
dependencies = [
"bytes",
"futures-core",
"http",
"http-body",
"http-body-util",
"mime",
"pin-project-lite",
"sync_wrapper",
"tower-layer",
"tower-service",
"tracing",
]
[[package]] [[package]]
name = "base64" name = "base64"
version = "0.22.1" version = "0.22.1"
@@ -262,7 +336,7 @@ checksum = "0ce92ff622d6dadf7349484f42c93271a0d49b7cc4d466a936405bacbe10aa78"
dependencies = [ dependencies = [
"cfg-if", "cfg-if",
"rustix 1.1.4", "rustix 1.1.4",
"windows-sys 0.52.0", "windows-sys 0.59.0",
] ]
[[package]] [[package]]
@@ -318,6 +392,17 @@ version = "0.3.32"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cecba35d7ad927e23624b22ad55235f2239cfa44fd10428eecbeba6d6a717718" checksum = "cecba35d7ad927e23624b22ad55235f2239cfa44fd10428eecbeba6d6a717718"
[[package]]
name = "futures-macro"
version = "0.3.32"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e835b70203e41293343137df5c0664546da5745f82ec9b84d40be8336958447b"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]] [[package]]
name = "futures-sink" name = "futures-sink"
version = "0.3.32" version = "0.3.32"
@@ -338,6 +423,7 @@ checksum = "389ca41296e6190b48053de0321d02a77f32f8a5d2461dd38762c0593805c6d6"
dependencies = [ dependencies = [
"futures-core", "futures-core",
"futures-io", "futures-io",
"futures-macro",
"futures-sink", "futures-sink",
"futures-task", "futures-task",
"memchr", "memchr",
@@ -451,6 +537,12 @@ version = "1.10.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6dbf3de79e51f3d586ab4cb9d5c3e2c14aa28ed23d180cf89b4df0454a69cc87" checksum = "6dbf3de79e51f3d586ab4cb9d5c3e2c14aa28ed23d180cf89b4df0454a69cc87"
[[package]]
name = "httpdate"
version = "1.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9"
[[package]] [[package]]
name = "hyper" name = "hyper"
version = "1.9.0" version = "1.9.0"
@@ -464,6 +556,7 @@ dependencies = [
"http", "http",
"http-body", "http-body",
"httparse", "httparse",
"httpdate",
"itoa", "itoa",
"pin-project-lite", "pin-project-lite",
"smallvec", "smallvec",
@@ -708,12 +801,24 @@ version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "112b39cec0b298b6c1999fee3e31427f74f676e4cb9879ed1a121b43661a4154" checksum = "112b39cec0b298b6c1999fee3e31427f74f676e4cb9879ed1a121b43661a4154"
[[package]]
name = "matchit"
version = "0.8.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "47e1ffaa40ddd1f3ed91f717a33c8c0ee23fff369e3aa8772b9605cc1d22f4c3"
[[package]] [[package]]
name = "memchr" name = "memchr"
version = "2.8.0" version = "2.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f8ca58f447f06ed17d5fc4043ce1b10dd205e060fb3ce5b979b8ed8e59ff3f79" checksum = "f8ca58f447f06ed17d5fc4043ce1b10dd205e060fb3ce5b979b8ed8e59ff3f79"
[[package]]
name = "mime"
version = "0.3.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a"
[[package]] [[package]]
name = "miniz_oxide" name = "miniz_oxide"
version = "0.8.9" version = "0.8.9"
@@ -1091,12 +1196,14 @@ dependencies = [
"sync_wrapper", "sync_wrapper",
"tokio", "tokio",
"tokio-rustls", "tokio-rustls",
"tokio-util",
"tower", "tower",
"tower-http", "tower-http",
"tower-service", "tower-service",
"url", "url",
"wasm-bindgen", "wasm-bindgen",
"wasm-bindgen-futures", "wasm-bindgen-futures",
"wasm-streams",
"web-sys", "web-sys",
"webpki-roots", "webpki-roots",
] ]
@@ -1145,7 +1252,7 @@ dependencies = [
"errno", "errno",
"libc", "libc",
"linux-raw-sys 0.4.15", "linux-raw-sys 0.4.15",
"windows-sys 0.52.0", "windows-sys 0.59.0",
] ]
[[package]] [[package]]
@@ -1288,6 +1395,17 @@ dependencies = [
"zmij", "zmij",
] ]
[[package]]
name = "serde_path_to_error"
version = "0.1.20"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "10a9ff822e371bb5403e391ecd83e182e0e77ba7f6fe0160b795797109d1b457"
dependencies = [
"itoa",
"serde",
"serde_core",
]
[[package]] [[package]]
name = "serde_urlencoded" name = "serde_urlencoded"
version = "0.7.1" version = "0.7.1"
@@ -1300,6 +1418,19 @@ dependencies = [
"serde", "serde",
] ]
[[package]]
name = "server"
version = "0.1.0"
dependencies = [
"async-stream",
"axum",
"reqwest",
"runtime",
"serde",
"serde_json",
"tokio",
]
[[package]] [[package]]
name = "sha2" name = "sha2"
version = "0.10.9" version = "0.10.9"
@@ -1553,6 +1684,19 @@ dependencies = [
"tokio", "tokio",
] ]
[[package]]
name = "tokio-util"
version = "0.7.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9ae9cec805b01e8fc3fd2fe289f89149a9b66dd16786abd8b19cfa7b48cb0098"
dependencies = [
"bytes",
"futures-core",
"futures-sink",
"pin-project-lite",
"tokio",
]
[[package]] [[package]]
name = "tools" name = "tools"
version = "0.1.0" version = "0.1.0"
@@ -1579,6 +1723,7 @@ dependencies = [
"tokio", "tokio",
"tower-layer", "tower-layer",
"tower-service", "tower-service",
"tracing",
] ]
[[package]] [[package]]
@@ -1617,6 +1762,7 @@ version = "0.1.44"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "63e71662fa4b2a2c3a26f570f037eb95bb1f85397f3cd8076caed2f026a6d100" checksum = "63e71662fa4b2a2c3a26f570f037eb95bb1f85397f3cd8076caed2f026a6d100"
dependencies = [ dependencies = [
"log",
"pin-project-lite", "pin-project-lite",
"tracing-core", "tracing-core",
] ]
@@ -1791,6 +1937,19 @@ dependencies = [
"unicode-ident", "unicode-ident",
] ]
[[package]]
name = "wasm-streams"
version = "0.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "15053d8d85c7eccdbefef60f06769760a563c7f0a9d6902a13d35c7800b0ad65"
dependencies = [
"futures-util",
"js-sys",
"wasm-bindgen",
"wasm-bindgen-futures",
"web-sys",
]
[[package]] [[package]]
name = "web-sys" name = "web-sys"
version = "0.3.93" version = "0.3.93"

View File

@@ -0,0 +1,20 @@
[package]
name = "server"
version.workspace = true
edition.workspace = true
license.workspace = true
publish.workspace = true
[dependencies]
async-stream = "0.3"
axum = "0.8"
runtime = { path = "../runtime" }
serde = { version = "1", features = ["derive"] }
serde_json.workspace = true
tokio = { version = "1", features = ["macros", "rt-multi-thread", "sync", "net", "time"] }
[dev-dependencies]
reqwest = { version = "0.12", default-features = false, features = ["json", "rustls-tls", "stream"] }
[lints]
workspace = true

View File

@@ -0,0 +1,442 @@
use std::collections::HashMap;
use std::convert::Infallible;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use async_stream::stream;
use axum::extract::{Path, State};
use axum::http::StatusCode;
use axum::response::sse::{Event, KeepAlive, Sse};
use axum::response::IntoResponse;
use axum::routing::{get, post};
use axum::{Json, Router};
use runtime::{ConversationMessage, Session as RuntimeSession};
use serde::{Deserialize, Serialize};
use tokio::sync::{broadcast, RwLock};
pub type SessionId = String;
pub type SessionStore = Arc<RwLock<HashMap<SessionId, Session>>>;
const BROADCAST_CAPACITY: usize = 64;
#[derive(Clone)]
pub struct AppState {
sessions: SessionStore,
next_session_id: Arc<AtomicU64>,
}
impl AppState {
#[must_use]
pub fn new() -> Self {
Self {
sessions: Arc::new(RwLock::new(HashMap::new())),
next_session_id: Arc::new(AtomicU64::new(1)),
}
}
fn allocate_session_id(&self) -> SessionId {
let id = self.next_session_id.fetch_add(1, Ordering::Relaxed);
format!("session-{id}")
}
}
impl Default for AppState {
fn default() -> Self {
Self::new()
}
}
#[derive(Clone)]
pub struct Session {
pub id: SessionId,
pub created_at: u64,
pub conversation: RuntimeSession,
events: broadcast::Sender<SessionEvent>,
}
impl Session {
fn new(id: SessionId) -> Self {
let (events, _) = broadcast::channel(BROADCAST_CAPACITY);
Self {
id,
created_at: unix_timestamp_millis(),
conversation: RuntimeSession::new(),
events,
}
}
fn subscribe(&self) -> broadcast::Receiver<SessionEvent> {
self.events.subscribe()
}
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(tag = "type", rename_all = "snake_case")]
enum SessionEvent {
Snapshot {
session_id: SessionId,
session: RuntimeSession,
},
Message {
session_id: SessionId,
message: ConversationMessage,
},
}
impl SessionEvent {
fn event_name(&self) -> &'static str {
match self {
Self::Snapshot { .. } => "snapshot",
Self::Message { .. } => "message",
}
}
fn to_sse_event(&self) -> Result<Event, serde_json::Error> {
Ok(Event::default()
.event(self.event_name())
.data(serde_json::to_string(self)?))
}
}
#[derive(Debug, Serialize)]
struct ErrorResponse {
error: String,
}
type ApiError = (StatusCode, Json<ErrorResponse>);
type ApiResult<T> = Result<T, ApiError>;
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct CreateSessionResponse {
pub session_id: SessionId,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct SessionSummary {
pub id: SessionId,
pub created_at: u64,
pub message_count: usize,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct ListSessionsResponse {
pub sessions: Vec<SessionSummary>,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct SessionDetailsResponse {
pub id: SessionId,
pub created_at: u64,
pub session: RuntimeSession,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct SendMessageRequest {
pub message: String,
}
#[must_use]
pub fn app(state: AppState) -> Router {
Router::new()
.route("/sessions", post(create_session).get(list_sessions))
.route("/sessions/{id}", get(get_session))
.route("/sessions/{id}/events", get(stream_session_events))
.route("/sessions/{id}/message", post(send_message))
.with_state(state)
}
async fn create_session(
State(state): State<AppState>,
) -> (StatusCode, Json<CreateSessionResponse>) {
let session_id = state.allocate_session_id();
let session = Session::new(session_id.clone());
state
.sessions
.write()
.await
.insert(session_id.clone(), session);
(
StatusCode::CREATED,
Json(CreateSessionResponse { session_id }),
)
}
async fn list_sessions(State(state): State<AppState>) -> Json<ListSessionsResponse> {
let sessions = state.sessions.read().await;
let mut summaries = sessions
.values()
.map(|session| SessionSummary {
id: session.id.clone(),
created_at: session.created_at,
message_count: session.conversation.messages.len(),
})
.collect::<Vec<_>>();
summaries.sort_by(|left, right| left.id.cmp(&right.id));
Json(ListSessionsResponse {
sessions: summaries,
})
}
async fn get_session(
State(state): State<AppState>,
Path(id): Path<SessionId>,
) -> ApiResult<Json<SessionDetailsResponse>> {
let sessions = state.sessions.read().await;
let session = sessions
.get(&id)
.ok_or_else(|| not_found(format!("session `{id}` not found")))?;
Ok(Json(SessionDetailsResponse {
id: session.id.clone(),
created_at: session.created_at,
session: session.conversation.clone(),
}))
}
async fn send_message(
State(state): State<AppState>,
Path(id): Path<SessionId>,
Json(payload): Json<SendMessageRequest>,
) -> ApiResult<StatusCode> {
let message = ConversationMessage::user_text(payload.message);
let broadcaster = {
let mut sessions = state.sessions.write().await;
let session = sessions
.get_mut(&id)
.ok_or_else(|| not_found(format!("session `{id}` not found")))?;
session.conversation.messages.push(message.clone());
session.events.clone()
};
let _ = broadcaster.send(SessionEvent::Message {
session_id: id,
message,
});
Ok(StatusCode::NO_CONTENT)
}
async fn stream_session_events(
State(state): State<AppState>,
Path(id): Path<SessionId>,
) -> ApiResult<impl IntoResponse> {
let (snapshot, mut receiver) = {
let sessions = state.sessions.read().await;
let session = sessions
.get(&id)
.ok_or_else(|| not_found(format!("session `{id}` not found")))?;
(
SessionEvent::Snapshot {
session_id: session.id.clone(),
session: session.conversation.clone(),
},
session.subscribe(),
)
};
let stream = stream! {
if let Ok(event) = snapshot.to_sse_event() {
yield Ok::<Event, Infallible>(event);
}
loop {
match receiver.recv().await {
Ok(event) => {
if let Ok(sse_event) = event.to_sse_event() {
yield Ok::<Event, Infallible>(sse_event);
}
}
Err(broadcast::error::RecvError::Lagged(_)) => continue,
Err(broadcast::error::RecvError::Closed) => break,
}
}
};
Ok(Sse::new(stream).keep_alive(KeepAlive::new().interval(Duration::from_secs(15))))
}
fn unix_timestamp_millis() -> u64 {
SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("system time should be after epoch")
.as_millis() as u64
}
fn not_found(message: String) -> ApiError {
(
StatusCode::NOT_FOUND,
Json(ErrorResponse { error: message }),
)
}
#[cfg(test)]
mod tests {
use super::{
app, AppState, CreateSessionResponse, ListSessionsResponse, SessionDetailsResponse,
};
use reqwest::Client;
use std::net::SocketAddr;
use std::time::Duration;
use tokio::net::TcpListener;
use tokio::task::JoinHandle;
use tokio::time::timeout;
struct TestServer {
address: SocketAddr,
handle: JoinHandle<()>,
}
impl TestServer {
async fn spawn() -> Self {
let listener = TcpListener::bind("127.0.0.1:0")
.await
.expect("test listener should bind");
let address = listener
.local_addr()
.expect("listener should report local address");
let handle = tokio::spawn(async move {
axum::serve(listener, app(AppState::default()))
.await
.expect("server should run");
});
Self { address, handle }
}
fn url(&self, path: &str) -> String {
format!("http://{}{}", self.address, path)
}
}
impl Drop for TestServer {
fn drop(&mut self) {
self.handle.abort();
}
}
async fn create_session(client: &Client, server: &TestServer) -> CreateSessionResponse {
client
.post(server.url("/sessions"))
.send()
.await
.expect("create request should succeed")
.error_for_status()
.expect("create request should return success")
.json::<CreateSessionResponse>()
.await
.expect("create response should parse")
}
async fn next_sse_frame(response: &mut reqwest::Response, buffer: &mut String) -> String {
loop {
if let Some(index) = buffer.find("\n\n") {
let frame = buffer[..index].to_string();
let remainder = buffer[index + 2..].to_string();
*buffer = remainder;
return frame;
}
let next_chunk = timeout(Duration::from_secs(5), response.chunk())
.await
.expect("SSE stream should yield within timeout")
.expect("SSE stream should remain readable")
.expect("SSE stream should stay open");
buffer.push_str(&String::from_utf8_lossy(&next_chunk));
}
}
#[tokio::test]
async fn creates_and_lists_sessions() {
let server = TestServer::spawn().await;
let client = Client::new();
// given
let created = create_session(&client, &server).await;
// when
let sessions = client
.get(server.url("/sessions"))
.send()
.await
.expect("list request should succeed")
.error_for_status()
.expect("list request should return success")
.json::<ListSessionsResponse>()
.await
.expect("list response should parse");
let details = client
.get(server.url(&format!("/sessions/{}", created.session_id)))
.send()
.await
.expect("details request should succeed")
.error_for_status()
.expect("details request should return success")
.json::<SessionDetailsResponse>()
.await
.expect("details response should parse");
// then
assert_eq!(created.session_id, "session-1");
assert_eq!(sessions.sessions.len(), 1);
assert_eq!(sessions.sessions[0].id, created.session_id);
assert_eq!(sessions.sessions[0].message_count, 0);
assert_eq!(details.id, "session-1");
assert!(details.session.messages.is_empty());
}
#[tokio::test]
async fn streams_message_events_and_persists_message_flow() {
let server = TestServer::spawn().await;
let client = Client::new();
// given
let created = create_session(&client, &server).await;
let mut response = client
.get(server.url(&format!("/sessions/{}/events", created.session_id)))
.send()
.await
.expect("events request should succeed")
.error_for_status()
.expect("events request should return success");
let mut buffer = String::new();
let snapshot_frame = next_sse_frame(&mut response, &mut buffer).await;
// when
let send_status = client
.post(server.url(&format!("/sessions/{}/message", created.session_id)))
.json(&super::SendMessageRequest {
message: "hello from test".to_string(),
})
.send()
.await
.expect("message request should succeed")
.status();
let message_frame = next_sse_frame(&mut response, &mut buffer).await;
let details = client
.get(server.url(&format!("/sessions/{}", created.session_id)))
.send()
.await
.expect("details request should succeed")
.error_for_status()
.expect("details request should return success")
.json::<SessionDetailsResponse>()
.await
.expect("details response should parse");
// then
assert_eq!(send_status, reqwest::StatusCode::NO_CONTENT);
assert!(snapshot_frame.contains("event: snapshot"));
assert!(snapshot_frame.contains("\"session_id\":\"session-1\""));
assert!(message_frame.contains("event: message"));
assert!(message_frame.contains("hello from test"));
assert_eq!(details.session.messages.len(), 1);
assert_eq!(
details.session.messages[0],
runtime::ConversationMessage::user_text("hello from test")
);
}
}