Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/proto/h2/upgrade.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ where

if me.h2_tx.capacity() == 0 {
// poll_capacity oddly needs a loop
'capacity: loop {
loop {
match me.h2_tx.poll_capacity(cx) {
Poll::Ready(Some(Ok(0))) => {}
Poll::Ready(Some(Ok(_))) => break,
Expand All @@ -95,7 +95,7 @@ where
"send stream capacity unexpectedly closed",
)));
}
Poll::Pending => break 'capacity,
Poll::Pending => return Poll::Pending,
}
}
}
Expand Down
277 changes: 277 additions & 0 deletions tests/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2300,6 +2300,283 @@ async fn h2_connect_empty_frames() {
.unwrap();
}

#[tokio::test]
async fn h2_connect_backpressure_respected() {
let (listener, addr) = setup_tcp_listener();
let conn = connect_async(addr).await;

let mut builder = h2::client::Builder::new();
builder.initial_window_size(1024);
builder.initial_connection_window_size(1024);
let (h2, connection) = builder.handshake::<_, Bytes>(conn).await.unwrap();
tokio::spawn(async move {
connection.await.unwrap();
});
let mut h2 = h2.ready().await.unwrap();

const CHUNK: &[u8] = b"backpressure test data chunk!\n";
const TOTAL_LEN: usize = CHUNK.len() * 2000;

let client_handle = tokio::spawn(async move {
let request = Request::connect("localhost").body(()).unwrap();
let (response, _send_stream) = h2.send_request(request, false).unwrap();
let response = response.await.unwrap();
assert_eq!(response.status(), StatusCode::OK);

let mut body = response.into_body();
let mut received = 0usize;

while let Some(chunk) = body.data().await {
let chunk = chunk.unwrap();
if chunk.is_empty() {
break;
}
let len = chunk.len();
received += len;
let _ = body.flow_control().release_capacity(len);
}

assert_eq!(received, TOTAL_LEN);
});

let svc = service_fn(move |req: Request<IncomingBody>| {
let on_upgrade = hyper::upgrade::on(req);

tokio::spawn(async move {
let mut upgraded = TokioIo::new(on_upgrade.await.expect("on_upgrade"));

for _ in 0..2000 {
upgraded.write_all(CHUNK).await.unwrap();
}

upgraded.shutdown().await.unwrap();
});

future::ok::<_, hyper::Error>(
Response::builder()
.status(200)
.body(Empty::<Bytes>::new())
.unwrap(),
)
});

let (socket, _) = listener.accept().await.unwrap();
let socket = TokioIo::new(socket);
http2::Builder::new(TokioExecutor)
.serve_connection(socket, svc)
.await
.unwrap();

client_handle.await.unwrap();
}

#[tokio::test]
async fn h2_connect_zero_window_then_release() {
let (listener, addr) = setup_tcp_listener();
let conn = connect_async(addr).await;

let mut builder = h2::client::Builder::new();
builder.initial_window_size(65535);
let (h2, connection) = builder.handshake::<_, Bytes>(conn).await.unwrap();
tokio::spawn(async move {
connection.await.unwrap();
});
let mut h2 = h2.ready().await.unwrap();

const DATA: &[u8] = b"Hello from upgraded stream";

let client_handle = tokio::spawn(async move {
let request = Request::connect("localhost").body(()).unwrap();
let (response, _send_stream) = h2.send_request(request, false).unwrap();
let response = response.await.unwrap();
assert_eq!(response.status(), StatusCode::OK);

let mut body = response.into_body();
let mut received = Vec::new();

while let Some(chunk) = body.data().await {
let chunk = chunk.unwrap();
if chunk.is_empty() {
break;
}
let len = chunk.len();
received.extend_from_slice(&chunk);
let _ = body.flow_control().release_capacity(len);
}

assert_eq!(&received[..], DATA);
});

let svc = service_fn(move |req: Request<IncomingBody>| {
let on_upgrade = hyper::upgrade::on(req);

tokio::spawn(async move {
let mut upgraded = TokioIo::new(on_upgrade.await.expect("on_upgrade"));
upgraded.write_all(DATA).await.unwrap();
upgraded.shutdown().await.unwrap();
});

future::ok::<_, hyper::Error>(
Response::builder()
.status(200)
.body(Empty::<Bytes>::new())
.unwrap(),
)
});

let (socket, _) = listener.accept().await.unwrap();
let socket = TokioIo::new(socket);
http2::Builder::new(TokioExecutor)
.serve_connection(socket, svc)
.await
.unwrap();

client_handle.await.unwrap();
}

#[tokio::test]
async fn h2_connect_reset_during_backpressure() {
let (listener, addr) = setup_tcp_listener();
let conn = connect_async(addr).await;

let mut builder = h2::client::Builder::new();
builder.initial_window_size(1024);
builder.initial_connection_window_size(1024);
let (h2, connection) = builder.handshake::<_, Bytes>(conn).await.unwrap();
tokio::spawn(async move {
let _ = connection.await;
});
let mut h2 = h2.ready().await.unwrap();

let (write_err_tx, write_err_rx) = oneshot::channel::<bool>();
let write_err_tx = Arc::new(Mutex::new(Some(write_err_tx)));

tokio::spawn(async move {
let request = Request::connect("localhost").body(()).unwrap();
let (response, mut send_stream) = h2.send_request(request, false).unwrap();
let response = response.await.unwrap();
assert_eq!(response.status(), StatusCode::OK);

let mut body = response.into_body();
let bytes = body.data().await.unwrap().unwrap();
let _ = body.flow_control().release_capacity(bytes.len());

send_stream.send_reset(h2::Reason::CANCEL);
drop(body);
drop(send_stream);

let got_err = write_err_rx.await.unwrap_or(false);
assert!(got_err, "server write should have failed after RST_STREAM");
});

let svc = service_fn(move |req: Request<IncomingBody>| {
let on_upgrade = hyper::upgrade::on(req);
let write_err_tx = write_err_tx.clone();

tokio::spawn(async move {
let mut upgraded = TokioIo::new(on_upgrade.await.expect("on_upgrade"));
upgraded.write_all(b"initial").await.unwrap();

let large_data = vec![b'x'; 1024 * 1024];
let write_result = upgraded.write_all(&large_data).await;

if let Some(tx) = write_err_tx.lock().unwrap().take() {
let _ = tx.send(write_result.is_err());
}
});

future::ok::<_, hyper::Error>(
Response::builder()
.status(200)
.body(Empty::<Bytes>::new())
.unwrap(),
)
});

let (socket, _) = listener.accept().await.unwrap();
let socket = TokioIo::new(socket);
let _ = http2::Builder::new(TokioExecutor)
.serve_connection(socket, svc)
.await;
}

#[tokio::test]
async fn h2_connect_backpressure_bidirectional() {
let (listener, addr) = setup_tcp_listener();
let conn = connect_async(addr).await;

let mut builder = h2::client::Builder::new();
builder.initial_window_size(2048);
builder.initial_connection_window_size(4096);
let (h2, connection) = builder.handshake::<_, Bytes>(conn).await.unwrap();
tokio::spawn(async move {
connection.await.unwrap();
});
let mut h2 = h2.ready().await.unwrap();

const PATTERN: &[u8] = b"All work and no bread makes nox a dull boy.\n";
const REPEAT: usize = 500;
let expected_len = PATTERN.len() * REPEAT;

let client_handle = tokio::spawn(async move {
let request = Request::connect("localhost").body(()).unwrap();
let (response, mut send_stream) = h2.send_request(request, false).unwrap();
let response = response.await.unwrap();
assert_eq!(response.status(), StatusCode::OK);

let mut body = response.into_body();
let mut received = 0usize;

while let Some(chunk) = body.data().await {
let chunk = chunk.unwrap();
if chunk.is_empty() {
break;
}
let len = chunk.len();
received += len;
let _ = body.flow_control().release_capacity(len);
}

assert_eq!(received, expected_len);

send_stream.send_data("client done".into(), true).unwrap();
});

let svc = service_fn(move |req: Request<IncomingBody>| {
let on_upgrade = hyper::upgrade::on(req);

tokio::spawn(async move {
let mut upgraded = TokioIo::new(on_upgrade.await.expect("on_upgrade"));

for _ in 0..REPEAT {
upgraded.write_all(PATTERN).await.unwrap();
}

upgraded.shutdown().await.unwrap();

let mut response_buf = vec![0u8; 64];
let n = upgraded.read(&mut response_buf).await.unwrap();
assert_eq!(&response_buf[..n], b"client done");
});

future::ok::<_, hyper::Error>(
Response::builder()
.status(200)
.body(Empty::<Bytes>::new())
.unwrap(),
)
});

let (socket, _) = listener.accept().await.unwrap();
let socket = TokioIo::new(socket);
http2::Builder::new(TokioExecutor)
.serve_connection(socket, svc)
.await
.unwrap();

client_handle.await.unwrap();
}

#[tokio::test]
async fn parse_errors_send_4xx_response() {
let (listener, addr) = setup_tcp_listener();
Expand Down
Loading