From d7f49b23cf3ebea9271c70e9250b0df7cc75ce0e Mon Sep 17 00:00:00 2001 From: sauagarwa Date: Thu, 23 Apr 2026 19:56:44 -0400 Subject: [PATCH] feat(server): add request-ID middleware for request correlation Add a UUID-based request-ID middleware using tower-http's request-id feature. Each inbound request receives a unique x-request-id header (or preserves a client-supplied one), which is recorded in the tracing span and propagated to the response. This enables operators to correlate log lines across the middleware stack for a single request under concurrent load, and lets clients reference specific requests in bug reports. Signed-off-by: sauagarwa --- Cargo.lock | 1 + Cargo.toml | 2 +- crates/openshell-server/src/multiplex.rs | 307 +++++++++++++++--- .../tests/multiplex_integration.rs | 73 +++++ 4 files changed, 338 insertions(+), 45 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 0e59eb64f..32e98f02a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5790,6 +5790,7 @@ dependencies = [ "tower-layer", "tower-service", "tracing", + "uuid", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index cffad2cc1..d1bcb6886 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -25,7 +25,7 @@ prost-types = "0.13" # HTTP server axum = { version = "0.8", features = ["ws"] } tower = "0.5" -tower-http = { version = "0.6", features = ["cors", "trace"] } +tower-http = { version = "0.6", features = ["cors", "trace", "request-id"] } hyper = { version = "1.6", features = ["full"] } hyper-util = { version = "0.1", features = ["tokio", "server-auto"] } http = "1.2" diff --git a/crates/openshell-server/src/multiplex.rs b/crates/openshell-server/src/multiplex.rs index e0c159958..1e71ba062 100644 --- a/crates/openshell-server/src/multiplex.rs +++ b/crates/openshell-server/src/multiplex.rs @@ -7,7 +7,7 @@ //! to either the gRPC service or HTTP endpoints based on the request headers. use bytes::Bytes; -use http::{Request, Response}; +use http::{HeaderValue, Request, Response}; use http_body::Body; use http_body_util::BodyExt; use hyper::body::Incoming; @@ -25,12 +25,83 @@ use std::sync::Arc; use std::task::{Context, Poll}; use std::time::{Duration, Instant}; use tokio::io::{AsyncRead, AsyncWrite}; -use tower::{ServiceBuilder, ServiceExt}; -use tower_http::trace::TraceLayer; +use tower::ServiceExt; +use tower_http::request_id::{MakeRequestId, RequestId}; use tracing::Span; use crate::{OpenShellService, ServerState, http_router, inference::InferenceService}; +/// Request-ID generator that produces a UUID v4 for each inbound request. +#[derive(Clone)] +struct UuidRequestId; + +impl MakeRequestId for UuidRequestId { + fn make_request_id(&mut self, _req: &Request) -> Option { + let id = uuid::Uuid::new_v4().to_string(); + Some(RequestId::new(HeaderValue::from_str(&id).unwrap())) + } +} + +/// Build a tracing span for an inbound request, recording the `request_id` +/// header (set by [`UuidRequestId`] or supplied by the client). +fn make_request_span(req: &Request) -> Span { + let path = req.uri().path(); + let request_id = req + .headers() + .get("x-request-id") + .and_then(|v| v.to_str().ok()) + .unwrap_or("-"); + + if matches!(path, "/health" | "/healthz" | "/readyz") { + tracing::debug_span!( + "request", + method = %req.method(), + path, + request_id, + ) + } else { + tracing::info_span!( + "request", + method = %req.method(), + path, + request_id, + ) + } +} + +/// Log response status and latency within the request span. +fn log_response(res: &Response, latency: Duration, _span: &Span) { + tracing::info!( + status = res.status().as_u16(), + latency_ms = latency.as_millis(), + "response" + ); +} + +/// Wrap a service with the standard request-ID middleware stack. +/// +/// Layer order: `SetRequestId` → `TraceLayer` → `PropagateRequestId`. +macro_rules! request_id_middleware { + ($service:expr) => {{ + let x_request_id = ::http::HeaderName::from_static("x-request-id"); + ::tower::ServiceBuilder::new() + .layer(::tower_http::request_id::SetRequestIdLayer::new( + x_request_id.clone(), + UuidRequestId, + )) + .layer( + ::tower_http::trace::TraceLayer::new_for_http() + .make_span_with(make_request_span) + .on_request(()) + .on_response(log_response), + ) + .layer(::tower_http::request_id::PropagateRequestIdLayer::new( + x_request_id, + )) + .service($service) + }}; +} + /// Maximum inbound gRPC message size (1 MB). /// /// Replaces tonic's implicit 4 MB default with a conservative limit to @@ -64,22 +135,8 @@ impl MultiplexService { let grpc_service = GrpcRouter::new(openshell, inference); let http_service = http_router(self.state.clone()); - let grpc_service = ServiceBuilder::new() - .layer( - TraceLayer::new_for_http() - .make_span_with(make_request_span) - .on_request(()) - .on_response(log_response), - ) - .service(grpc_service); - let http_service = ServiceBuilder::new() - .layer( - TraceLayer::new_for_http() - .make_span_with(make_request_span) - .on_request(()) - .on_response(log_response), - ) - .service(http_service); + let grpc_service = request_id_middleware!(grpc_service); + let http_service = request_id_middleware!(http_service); let service = MultiplexedService::new(grpc_service, http_service); @@ -248,31 +305,6 @@ where } } -fn make_request_span(req: &Request) -> Span { - let path = req.uri().path(); - if matches!(path, "/health" | "/healthz" | "/readyz") { - tracing::debug_span!( - "request", - method = %req.method(), - path, - ) - } else { - tracing::info_span!( - "request", - method = %req.method(), - path, - ) - } -} - -fn log_response(res: &Response, latency: Duration, _span: &Span) { - tracing::info!( - status = res.status().as_u16(), - latency_ms = latency.as_millis(), - "response" - ); -} - fn grpc_method_from_path(path: &str) -> String { path.rsplit('/').next().unwrap_or(path).to_string() } @@ -321,6 +353,193 @@ impl Body for BoxBody { #[cfg(test)] mod tests { use super::*; + use bytes::Bytes; + use http_body_util::Empty; + use std::sync::Mutex; + + #[test] + fn uuid_request_id_generates_valid_uuid() { + let mut maker = UuidRequestId; + let req = Request::builder().body(()).unwrap(); + let id = maker.make_request_id(&req).expect("should produce an ID"); + let value = id.header_value().to_str().unwrap(); + uuid::Uuid::parse_str(value).expect("should be a valid UUID"); + } + + #[test] + fn uuid_request_id_generates_unique_ids() { + let mut maker = UuidRequestId; + let req = Request::builder().body(()).unwrap(); + let id1 = maker.make_request_id(&req).unwrap(); + let id2 = maker.make_request_id(&req).unwrap(); + assert_ne!(id1.header_value(), id2.header_value()); + } + + async fn start_http_server_with_middleware() -> std::net::SocketAddr { + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + let http_service = crate::http::health_router(); + let http_service = request_id_middleware!(http_service); + + let service = MultiplexedService::new(http_service.clone(), http_service); + + tokio::spawn(async move { + loop { + let Ok((stream, _)) = listener.accept().await else { + continue; + }; + let svc = service.clone(); + tokio::spawn(async move { + let _ = Builder::new(TokioExecutor::new()) + .serve_connection(TokioIo::new(stream), svc) + .await; + }); + } + }); + + addr + } + + async fn http1_get( + addr: std::net::SocketAddr, + path: &str, + headers: &[(&str, &str)], + ) -> Response { + let stream = tokio::net::TcpStream::connect(addr).await.unwrap(); + let (mut sender, conn) = hyper::client::conn::http1::Builder::new() + .handshake(TokioIo::new(stream)) + .await + .unwrap(); + tokio::spawn(async move { + let _ = conn.await; + }); + + let mut builder = Request::builder() + .method("GET") + .uri(format!("http://{addr}{path}")); + for (k, v) in headers { + builder = builder.header(*k, *v); + } + let req = builder.body(Empty::::new()).unwrap(); + sender.send_request(req).await.unwrap() + } + + #[tokio::test] + async fn http_response_includes_request_id() { + let addr = start_http_server_with_middleware().await; + let resp = http1_get(addr, "/healthz", &[]).await; + assert_eq!(resp.status(), 200); + + let request_id = resp + .headers() + .get("x-request-id") + .expect("response should include x-request-id header"); + let id_str = request_id.to_str().unwrap(); + uuid::Uuid::parse_str(id_str).expect("should be a valid UUID"); + } + + #[tokio::test] + async fn http_preserves_client_request_id() { + let addr = start_http_server_with_middleware().await; + let client_id = "my-custom-correlation-id"; + let resp = http1_get(addr, "/healthz", &[("x-request-id", client_id)]).await; + assert_eq!(resp.status(), 200); + + let request_id = resp + .headers() + .get("x-request-id") + .expect("response should include x-request-id header"); + assert_eq!(request_id.to_str().unwrap(), client_id); + } + + #[tokio::test] + async fn each_request_gets_unique_id() { + let addr = start_http_server_with_middleware().await; + + let mut ids = Vec::new(); + for _ in 0..3 { + let resp = http1_get(addr, "/healthz", &[]).await; + let id = resp + .headers() + .get("x-request-id") + .unwrap() + .to_str() + .unwrap() + .to_string(); + ids.push(id); + } + + assert_ne!(ids[0], ids[1]); + assert_ne!(ids[1], ids[2]); + assert_ne!(ids[0], ids[2]); + } + + #[tokio::test] + async fn grpc_path_includes_request_id() { + let addr = start_http_server_with_middleware().await; + let resp = http1_get( + addr, + "/openshell.v1.OpenShell/Health", + &[ + ("content-type", "application/grpc"), + ("x-request-id", "grpc-corr-id"), + ], + ) + .await; + + let request_id = resp + .headers() + .get("x-request-id") + .expect("gRPC-routed response should include x-request-id header"); + assert_eq!(request_id.to_str().unwrap(), "grpc-corr-id"); + } + + #[derive(Clone)] + struct TraceBuf(Arc>>); + + impl std::io::Write for TraceBuf { + fn write(&mut self, buf: &[u8]) -> std::io::Result { + self.0.lock().unwrap().extend_from_slice(buf); + Ok(buf.len()) + } + + fn flush(&mut self) -> std::io::Result<()> { + Ok(()) + } + } + + #[test] + fn request_id_appears_in_trace_span() { + use tracing_subscriber::fmt::format::FmtSpan; + use tracing_subscriber::layer::SubscriberExt; + + let log_buf: Arc>> = Arc::new(Mutex::new(Vec::new())); + let writer = TraceBuf(log_buf.clone()); + + let fmt_layer = tracing_subscriber::fmt::layer() + .with_writer(move || writer.clone()) + .with_ansi(false) + .with_span_events(FmtSpan::CLOSE); + + let subscriber = tracing_subscriber::registry().with(fmt_layer); + let _guard = tracing::subscriber::set_default(subscriber); + + let req = Request::builder() + .uri("/test-path") + .header("x-request-id", "trace-test-id-12345") + .body(Empty::::new()) + .unwrap(); + let span = make_request_span(&req); + drop(span.enter()); + drop(span); + + let output = String::from_utf8(log_buf.lock().unwrap().clone()).unwrap(); + assert!( + output.contains("trace-test-id-12345"), + "trace output should contain the request_id recorded in the span, got: {output}" + ); + } #[test] fn grpc_method_extracts_last_segment() { diff --git a/crates/openshell-server/tests/multiplex_integration.rs b/crates/openshell-server/tests/multiplex_integration.rs index dd14c63ec..9c466d4f8 100644 --- a/crates/openshell-server/tests/multiplex_integration.rs +++ b/crates/openshell-server/tests/multiplex_integration.rs @@ -343,3 +343,76 @@ async fn serves_grpc_and_http_on_same_port() { server.abort(); } + +/// Verify tonic metadata ↔ HTTP header roundtrip for `x-request-id`. +/// +/// This intentionally constructs its own request-ID layers from +/// `tower-http`'s public API rather than reusing the production macro +/// (which is crate-private). Production middleware composition and +/// layer ordering are covered by the unit tests in `multiplex::tests`. +#[tokio::test] +async fn grpc_response_propagates_request_id() { + use tower::ServiceBuilder; + use tower_http::request_id::{ + MakeRequestId, PropagateRequestIdLayer, RequestId, SetRequestIdLayer, + }; + + #[derive(Clone)] + struct TestUuidRequestId; + + impl MakeRequestId for TestUuidRequestId { + fn make_request_id(&mut self, _req: &Request) -> Option { + let id = uuid::Uuid::new_v4().to_string(); + Some(RequestId::new(http::HeaderValue::from_str(&id).unwrap())) + } + } + + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + let x_request_id = http::HeaderName::from_static("x-request-id"); + let grpc_service = ServiceBuilder::new() + .layer(SetRequestIdLayer::new( + x_request_id.clone(), + TestUuidRequestId, + )) + .layer(PropagateRequestIdLayer::new(x_request_id)) + .service(OpenShellServer::new(TestOpenShell)); + let http_service = health_router(); + let service = MultiplexedService::new(grpc_service, http_service); + + tokio::spawn(async move { + loop { + let Ok((stream, _)) = listener.accept().await else { + continue; + }; + let svc = service.clone(); + tokio::spawn(async move { + let _ = Builder::new(TokioExecutor::new()) + .serve_connection(TokioIo::new(stream), svc) + .await; + }); + } + }); + + let mut client = OpenShellClient::connect(format!("http://{addr}")) + .await + .unwrap(); + + // Server generates a UUID when client omits x-request-id. + let response = client.health(HealthRequest {}).await.unwrap(); + let generated = response + .metadata() + .get("x-request-id") + .expect("gRPC response should include server-generated x-request-id"); + uuid::Uuid::parse_str(generated.to_str().unwrap()).expect("should be a valid UUID"); + + // Server preserves a client-supplied x-request-id. + let mut request = tonic::Request::new(HealthRequest {}); + request + .metadata_mut() + .insert("x-request-id", "grpc-corr-id".parse().unwrap()); + let response = client.health(request).await.unwrap(); + let echoed = response.metadata().get("x-request-id").unwrap(); + assert_eq!(echoed.to_str().unwrap(), "grpc-corr-id"); +}