Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 98 additions & 0 deletions crates/rmcp/src/transport/streamable_http_server/tower.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,15 @@ pub struct StreamableHttpServerConfig {
/// or with ports:
/// allowed_hosts = ["example.com", "example.com:8080"]
pub allowed_hosts: Vec<String>,
/// Allowed browser origins for inbound `Origin` validation.
///
/// Defaults to an empty list, which disables Origin validation. When
/// non-empty, requests carrying an `Origin` header must match per RFC 6454
/// `(scheme, host, port)`; missing-`Origin` requests still pass. Entries
/// must include a scheme; `"null"` matches the browser's `Origin: null`.
/// examples:
/// allowed_origins = ["https://app.example.com", "http://localhost:8080"]
pub allowed_origins: Vec<String>,
/// Optional external session store for cross-instance recovery.
///
/// When set, [`SessionState`] (the client's `initialize` parameters) is
Expand Down Expand Up @@ -103,6 +112,7 @@ impl Default for StreamableHttpServerConfig {
json_response: false,
cancellation_token: CancellationToken::new(),
allowed_hosts: vec!["localhost".into(), "127.0.0.1".into(), "::1".into()],
allowed_origins: vec![],
session_store: None,
}
}
Expand All @@ -121,6 +131,18 @@ impl StreamableHttpServerConfig {
self.allowed_hosts.clear();
self
}
pub fn with_allowed_origins(
mut self,
allowed_origins: impl IntoIterator<Item = impl Into<String>>,
) -> Self {
self.allowed_origins = allowed_origins.into_iter().map(Into::into).collect();
self
}
/// Disable Origin validation, reverting to the default ignore-Origin behavior.
pub fn disable_allowed_origins(mut self) -> Self {
self.allowed_origins.clear();
self
}
pub fn with_sse_keep_alive(mut self, duration: Option<Duration>) -> Self {
self.sse_keep_alive = duration;
self
Expand Down Expand Up @@ -243,6 +265,59 @@ fn host_is_allowed(host: &NormalizedAuthority, allowed_hosts: &[String]) -> bool
})
}

#[derive(Debug, Clone, PartialEq, Eq)]
enum NormalizedOrigin {
Null,
Tuple {
scheme: String,
host: String,
port: Option<u16>,
},
}

fn parse_origin_value(value: &str) -> Option<NormalizedOrigin> {
let value = value.trim();
if value.is_empty() {
return None;
}
if value.eq_ignore_ascii_case("null") {
return Some(NormalizedOrigin::Null);
}
let uri = http::Uri::try_from(value).ok()?;
let scheme = uri.scheme_str()?.to_ascii_lowercase();
let authority = uri.authority()?;
Some(NormalizedOrigin::Tuple {
scheme,
host: normalize_host(authority.host()),
port: authority.port_u16(),
})
}

fn origin_is_allowed(origin: &NormalizedOrigin, allowed_origins: &[String]) -> bool {
if allowed_origins.is_empty() {
return true;
}
allowed_origins
.iter()
.filter_map(|raw| parse_origin_value(raw))
.any(|allowed| match (&allowed, origin) {
(NormalizedOrigin::Null, NormalizedOrigin::Null) => true,
(
NormalizedOrigin::Tuple {
scheme: a_scheme,
host: a_host,
port: a_port,
},
NormalizedOrigin::Tuple {
scheme: o_scheme,
host: o_host,
port: o_port,
},
) => a_scheme == o_scheme && a_host == o_host && (a_port.is_none() || a_port == o_port),
_ => false,
})
}

fn bad_request_response(message: &str) -> BoxResponse {
let body = Full::from(message.to_string()).boxed();

Expand Down Expand Up @@ -274,7 +349,30 @@ fn validate_dns_rebinding_headers(
if !host_is_allowed(&host, &config.allowed_hosts) {
return Err(forbidden_response("Forbidden: Host header is not allowed"));
}
validate_origin_header(headers, &config.allowed_origins)?;
Ok(())
}

fn validate_origin_header(
headers: &HeaderMap,
allowed_origins: &[String],
) -> Result<(), BoxResponse> {
if allowed_origins.is_empty() {
return Ok(());
}
let Some(origin_header) = headers.get(http::header::ORIGIN) else {
return Ok(());
};
let origin_str = origin_header
.to_str()
.map_err(|_| bad_request_response("Bad Request: Invalid Origin header encoding"))?;
let origin = parse_origin_value(origin_str)
.ok_or_else(|| bad_request_response("Bad Request: Invalid Origin header"))?;
if !origin_is_allowed(&origin, allowed_origins) {
return Err(forbidden_response(
"Forbidden: Origin header is not allowed",
));
}
Ok(())
}

Expand Down
108 changes: 108 additions & 0 deletions crates/rmcp/tests/test_custom_headers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1030,3 +1030,111 @@ async fn test_server_validates_host_header_port_for_dns_rebinding_protection() {
let response = service.handle(wrong_port_request).await;
assert_eq!(response.status(), http::StatusCode::FORBIDDEN);
}

#[cfg(all(feature = "transport-streamable-http-server", feature = "server"))]
mod origin_validation {
use std::sync::Arc;

use bytes::Bytes;
use http::{Method, Request, header::CONTENT_TYPE};
use http_body_util::Full;
use rmcp::{
handler::server::ServerHandler,
model::{ServerCapabilities, ServerInfo},
transport::streamable_http_server::{
StreamableHttpServerConfig, StreamableHttpService, session::local::LocalSessionManager,
},
};
use serde_json::json;

#[derive(Clone)]
struct TestHandler;

impl ServerHandler for TestHandler {
fn get_info(&self) -> ServerInfo {
ServerInfo::new(ServerCapabilities::builder().build())
}
}

fn service_with_allowed_origins(
origins: &[&str],
) -> StreamableHttpService<TestHandler, LocalSessionManager> {
StreamableHttpService::new(
|| Ok(TestHandler),
Arc::new(LocalSessionManager::default()),
StreamableHttpServerConfig::default().with_allowed_origins(origins.iter().copied()),
)
}

fn init_request(origin: Option<&str>) -> Request<Full<Bytes>> {
let init_body = json!({
"jsonrpc": "2.0",
"id": 1,
"method": "initialize",
"params": {
"protocolVersion": "2025-03-26",
"capabilities": {},
"clientInfo": {"name": "test-client", "version": "1.0.0"}
}
});
let mut builder = Request::builder()
.method(Method::POST)
.header("Accept", "application/json, text/event-stream")
.header(CONTENT_TYPE, "application/json")
.header("Host", "localhost:8080");
if let Some(origin) = origin {
builder = builder.header("Origin", origin);
}
builder
.body(Full::new(Bytes::from(init_body.to_string())))
.unwrap()
}

#[tokio::test]
async fn allowlisted_origin_is_allowed() {
let service = service_with_allowed_origins(&["http://localhost:8080"]);
let response = service
.handle(init_request(Some("http://localhost:8080")))
.await;
assert_eq!(response.status(), http::StatusCode::OK);
}

#[tokio::test]
async fn non_allowlisted_origin_is_forbidden() {
let service = service_with_allowed_origins(&["http://localhost:8080"]);
let response = service
.handle(init_request(Some("http://attacker.example")))
.await;
assert_eq!(response.status(), http::StatusCode::FORBIDDEN);
}

#[tokio::test]
async fn missing_origin_passes_through() {
let service = service_with_allowed_origins(&["http://localhost:8080"]);
let response = service.handle(init_request(None)).await;
assert_eq!(response.status(), http::StatusCode::OK);
}

#[tokio::test]
async fn scheme_mismatch_is_forbidden() {
let service = service_with_allowed_origins(&["http://localhost:8080"]);
let response = service
.handle(init_request(Some("https://localhost:8080")))
.await;
assert_eq!(response.status(), http::StatusCode::FORBIDDEN);
}

#[tokio::test]
async fn null_origin_is_allowed_when_allowlisted() {
let service = service_with_allowed_origins(&["null"]);
let response = service.handle(init_request(Some("null"))).await;
assert_eq!(response.status(), http::StatusCode::OK);
}

#[tokio::test]
async fn null_origin_is_forbidden_when_not_allowlisted() {
let service = service_with_allowed_origins(&["http://localhost:8080"]);
let response = service.handle(init_request(Some("null"))).await;
assert_eq!(response.status(), http::StatusCode::FORBIDDEN);
}
}
Loading