Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ prost-types = "0.13"
# HTTP server
axum = { version = "0.8", features = ["ws"] }
tower = "0.5"
tower-http = { version = "0.6", features = ["cors", "trace"] }
tower-http = { version = "0.6", features = ["cors", "trace", "request-id"] }
hyper = { version = "1.6", features = ["full"] }
hyper-util = { version = "0.1", features = ["tokio", "server-auto"] }
http = "1.2"
Expand Down
307 changes: 263 additions & 44 deletions crates/openshell-server/src/multiplex.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
//! to either the gRPC service or HTTP endpoints based on the request headers.

use bytes::Bytes;
use http::{Request, Response};
use http::{HeaderValue, Request, Response};
use http_body::Body;
use http_body_util::BodyExt;
use hyper::body::Incoming;
Expand All @@ -25,12 +25,83 @@ use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::{Duration, Instant};
use tokio::io::{AsyncRead, AsyncWrite};
use tower::{ServiceBuilder, ServiceExt};
use tower_http::trace::TraceLayer;
use tower::ServiceExt;
use tower_http::request_id::{MakeRequestId, RequestId};
use tracing::Span;

use crate::{OpenShellService, ServerState, http_router, inference::InferenceService};

/// Request-ID generator that produces a UUID v4 for each inbound request.
#[derive(Clone)]
struct UuidRequestId;

impl MakeRequestId for UuidRequestId {
fn make_request_id<B>(&mut self, _req: &Request<B>) -> Option<RequestId> {
let id = uuid::Uuid::new_v4().to_string();
Some(RequestId::new(HeaderValue::from_str(&id).unwrap()))
}
}

/// Build a tracing span for an inbound request, recording the `request_id`
/// header (set by [`UuidRequestId`] or supplied by the client).
fn make_request_span<B>(req: &Request<B>) -> Span {
let path = req.uri().path();
let request_id = req
.headers()
.get("x-request-id")
.and_then(|v| v.to_str().ok())
.unwrap_or("-");

if matches!(path, "/health" | "/healthz" | "/readyz") {
tracing::debug_span!(
"request",
method = %req.method(),
path,
request_id,
)
} else {
tracing::info_span!(
"request",
method = %req.method(),
path,
request_id,
)
}
}

/// Log response status and latency within the request span.
fn log_response<B>(res: &Response<B>, latency: Duration, _span: &Span) {
tracing::info!(
status = res.status().as_u16(),
latency_ms = latency.as_millis(),
"response"
);
}

/// Wrap a service with the standard request-ID middleware stack.
///
/// Layer order: `SetRequestId` → `TraceLayer` → `PropagateRequestId`.
macro_rules! request_id_middleware {
($service:expr) => {{
let x_request_id = ::http::HeaderName::from_static("x-request-id");
::tower::ServiceBuilder::new()
.layer(::tower_http::request_id::SetRequestIdLayer::new(
x_request_id.clone(),
UuidRequestId,
))
.layer(
::tower_http::trace::TraceLayer::new_for_http()
.make_span_with(make_request_span)
.on_request(())
.on_response(log_response),
)
.layer(::tower_http::request_id::PropagateRequestIdLayer::new(
x_request_id,
))
.service($service)
}};
}

/// Maximum inbound gRPC message size (1 MB).
///
/// Replaces tonic's implicit 4 MB default with a conservative limit to
Expand Down Expand Up @@ -64,22 +135,8 @@ impl MultiplexService {
let grpc_service = GrpcRouter::new(openshell, inference);
let http_service = http_router(self.state.clone());

let grpc_service = ServiceBuilder::new()
.layer(
TraceLayer::new_for_http()
.make_span_with(make_request_span)
.on_request(())
.on_response(log_response),
)
.service(grpc_service);
let http_service = ServiceBuilder::new()
.layer(
TraceLayer::new_for_http()
.make_span_with(make_request_span)
.on_request(())
.on_response(log_response),
)
.service(http_service);
let grpc_service = request_id_middleware!(grpc_service);
let http_service = request_id_middleware!(http_service);

let service = MultiplexedService::new(grpc_service, http_service);

Expand Down Expand Up @@ -248,31 +305,6 @@ where
}
}

fn make_request_span<B>(req: &Request<B>) -> Span {
let path = req.uri().path();
if matches!(path, "/health" | "/healthz" | "/readyz") {
tracing::debug_span!(
"request",
method = %req.method(),
path,
)
} else {
tracing::info_span!(
"request",
method = %req.method(),
path,
)
}
}

fn log_response<B>(res: &Response<B>, latency: Duration, _span: &Span) {
tracing::info!(
status = res.status().as_u16(),
latency_ms = latency.as_millis(),
"response"
);
}

fn grpc_method_from_path(path: &str) -> String {
path.rsplit('/').next().unwrap_or(path).to_string()
}
Expand Down Expand Up @@ -321,6 +353,193 @@ impl Body for BoxBody {
#[cfg(test)]
mod tests {
use super::*;
use bytes::Bytes;
use http_body_util::Empty;
use std::sync::Mutex;

#[test]
fn uuid_request_id_generates_valid_uuid() {
let mut maker = UuidRequestId;
let req = Request::builder().body(()).unwrap();
let id = maker.make_request_id(&req).expect("should produce an ID");
let value = id.header_value().to_str().unwrap();
uuid::Uuid::parse_str(value).expect("should be a valid UUID");
}

#[test]
fn uuid_request_id_generates_unique_ids() {
let mut maker = UuidRequestId;
let req = Request::builder().body(()).unwrap();
let id1 = maker.make_request_id(&req).unwrap();
let id2 = maker.make_request_id(&req).unwrap();
assert_ne!(id1.header_value(), id2.header_value());
}

async fn start_http_server_with_middleware() -> std::net::SocketAddr {
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();

let http_service = crate::http::health_router();
let http_service = request_id_middleware!(http_service);

let service = MultiplexedService::new(http_service.clone(), http_service);

tokio::spawn(async move {
loop {
let Ok((stream, _)) = listener.accept().await else {
continue;
};
let svc = service.clone();
tokio::spawn(async move {
let _ = Builder::new(TokioExecutor::new())
.serve_connection(TokioIo::new(stream), svc)
.await;
});
}
});

addr
}

async fn http1_get(
addr: std::net::SocketAddr,
path: &str,
headers: &[(&str, &str)],
) -> Response<Incoming> {
let stream = tokio::net::TcpStream::connect(addr).await.unwrap();
let (mut sender, conn) = hyper::client::conn::http1::Builder::new()
.handshake(TokioIo::new(stream))
.await
.unwrap();
tokio::spawn(async move {
let _ = conn.await;
});

let mut builder = Request::builder()
.method("GET")
.uri(format!("http://{addr}{path}"));
for (k, v) in headers {
builder = builder.header(*k, *v);
}
let req = builder.body(Empty::<Bytes>::new()).unwrap();
sender.send_request(req).await.unwrap()
}

#[tokio::test]
async fn http_response_includes_request_id() {
let addr = start_http_server_with_middleware().await;
let resp = http1_get(addr, "/healthz", &[]).await;
assert_eq!(resp.status(), 200);

let request_id = resp
.headers()
.get("x-request-id")
.expect("response should include x-request-id header");
let id_str = request_id.to_str().unwrap();
uuid::Uuid::parse_str(id_str).expect("should be a valid UUID");
}

#[tokio::test]
async fn http_preserves_client_request_id() {
let addr = start_http_server_with_middleware().await;
let client_id = "my-custom-correlation-id";
let resp = http1_get(addr, "/healthz", &[("x-request-id", client_id)]).await;
assert_eq!(resp.status(), 200);

let request_id = resp
.headers()
.get("x-request-id")
.expect("response should include x-request-id header");
assert_eq!(request_id.to_str().unwrap(), client_id);
}

#[tokio::test]
async fn each_request_gets_unique_id() {
let addr = start_http_server_with_middleware().await;

let mut ids = Vec::new();
for _ in 0..3 {
let resp = http1_get(addr, "/healthz", &[]).await;
let id = resp
.headers()
.get("x-request-id")
.unwrap()
.to_str()
.unwrap()
.to_string();
ids.push(id);
}

assert_ne!(ids[0], ids[1]);
assert_ne!(ids[1], ids[2]);
assert_ne!(ids[0], ids[2]);
}

#[tokio::test]
async fn grpc_path_includes_request_id() {
let addr = start_http_server_with_middleware().await;
let resp = http1_get(
addr,
"/openshell.v1.OpenShell/Health",
&[
("content-type", "application/grpc"),
("x-request-id", "grpc-corr-id"),
],
)
.await;

let request_id = resp
.headers()
.get("x-request-id")
.expect("gRPC-routed response should include x-request-id header");
assert_eq!(request_id.to_str().unwrap(), "grpc-corr-id");
}

#[derive(Clone)]
struct TraceBuf(Arc<Mutex<Vec<u8>>>);

impl std::io::Write for TraceBuf {
fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
self.0.lock().unwrap().extend_from_slice(buf);
Ok(buf.len())
}

fn flush(&mut self) -> std::io::Result<()> {
Ok(())
}
}

#[test]
fn request_id_appears_in_trace_span() {
use tracing_subscriber::fmt::format::FmtSpan;
use tracing_subscriber::layer::SubscriberExt;

let log_buf: Arc<Mutex<Vec<u8>>> = Arc::new(Mutex::new(Vec::new()));
let writer = TraceBuf(log_buf.clone());

let fmt_layer = tracing_subscriber::fmt::layer()
.with_writer(move || writer.clone())
.with_ansi(false)
.with_span_events(FmtSpan::CLOSE);

let subscriber = tracing_subscriber::registry().with(fmt_layer);
let _guard = tracing::subscriber::set_default(subscriber);

let req = Request::builder()
.uri("/test-path")
.header("x-request-id", "trace-test-id-12345")
.body(Empty::<Bytes>::new())
.unwrap();
let span = make_request_span(&req);
drop(span.enter());
drop(span);

let output = String::from_utf8(log_buf.lock().unwrap().clone()).unwrap();
assert!(
output.contains("trace-test-id-12345"),
"trace output should contain the request_id recorded in the span, got: {output}"
);
}

#[test]
fn grpc_method_extracts_last_segment() {
Expand Down
Loading
Loading