diff --git a/src/proto/h2/upgrade.rs b/src/proto/h2/upgrade.rs index 40a98de08a..735cbfd9dd 100644 --- a/src/proto/h2/upgrade.rs +++ b/src/proto/h2/upgrade.rs @@ -80,7 +80,7 @@ where if me.h2_tx.capacity() == 0 { // poll_capacity oddly needs a loop - 'capacity: loop { + loop { match me.h2_tx.poll_capacity(cx) { Poll::Ready(Some(Ok(0))) => {} Poll::Ready(Some(Ok(_))) => break, @@ -95,7 +95,7 @@ where "send stream capacity unexpectedly closed", ))); } - Poll::Pending => break 'capacity, + Poll::Pending => return Poll::Pending, } } } diff --git a/tests/server.rs b/tests/server.rs index 651fbdf40d..14cc976163 100644 --- a/tests/server.rs +++ b/tests/server.rs @@ -2300,6 +2300,283 @@ async fn h2_connect_empty_frames() { .unwrap(); } +#[tokio::test] +async fn h2_connect_backpressure_respected() { + let (listener, addr) = setup_tcp_listener(); + let conn = connect_async(addr).await; + + let mut builder = h2::client::Builder::new(); + builder.initial_window_size(1024); + builder.initial_connection_window_size(1024); + let (h2, connection) = builder.handshake::<_, Bytes>(conn).await.unwrap(); + tokio::spawn(async move { + connection.await.unwrap(); + }); + let mut h2 = h2.ready().await.unwrap(); + + const CHUNK: &[u8] = b"backpressure test data chunk!\n"; + const TOTAL_LEN: usize = CHUNK.len() * 2000; + + let client_handle = tokio::spawn(async move { + let request = Request::connect("localhost").body(()).unwrap(); + let (response, _send_stream) = h2.send_request(request, false).unwrap(); + let response = response.await.unwrap(); + assert_eq!(response.status(), StatusCode::OK); + + let mut body = response.into_body(); + let mut received = 0usize; + + while let Some(chunk) = body.data().await { + let chunk = chunk.unwrap(); + if chunk.is_empty() { + break; + } + let len = chunk.len(); + received += len; + let _ = body.flow_control().release_capacity(len); + } + + assert_eq!(received, TOTAL_LEN); + }); + + let svc = service_fn(move |req: Request| { + let on_upgrade = hyper::upgrade::on(req); + + tokio::spawn(async move { + let mut upgraded = TokioIo::new(on_upgrade.await.expect("on_upgrade")); + + for _ in 0..2000 { + upgraded.write_all(CHUNK).await.unwrap(); + } + + upgraded.shutdown().await.unwrap(); + }); + + future::ok::<_, hyper::Error>( + Response::builder() + .status(200) + .body(Empty::::new()) + .unwrap(), + ) + }); + + let (socket, _) = listener.accept().await.unwrap(); + let socket = TokioIo::new(socket); + http2::Builder::new(TokioExecutor) + .serve_connection(socket, svc) + .await + .unwrap(); + + client_handle.await.unwrap(); +} + +#[tokio::test] +async fn h2_connect_zero_window_then_release() { + let (listener, addr) = setup_tcp_listener(); + let conn = connect_async(addr).await; + + let mut builder = h2::client::Builder::new(); + builder.initial_window_size(65535); + let (h2, connection) = builder.handshake::<_, Bytes>(conn).await.unwrap(); + tokio::spawn(async move { + connection.await.unwrap(); + }); + let mut h2 = h2.ready().await.unwrap(); + + const DATA: &[u8] = b"Hello from upgraded stream"; + + let client_handle = tokio::spawn(async move { + let request = Request::connect("localhost").body(()).unwrap(); + let (response, _send_stream) = h2.send_request(request, false).unwrap(); + let response = response.await.unwrap(); + assert_eq!(response.status(), StatusCode::OK); + + let mut body = response.into_body(); + let mut received = Vec::new(); + + while let Some(chunk) = body.data().await { + let chunk = chunk.unwrap(); + if chunk.is_empty() { + break; + } + let len = chunk.len(); + received.extend_from_slice(&chunk); + let _ = body.flow_control().release_capacity(len); + } + + assert_eq!(&received[..], DATA); + }); + + let svc = service_fn(move |req: Request| { + let on_upgrade = hyper::upgrade::on(req); + + tokio::spawn(async move { + let mut upgraded = TokioIo::new(on_upgrade.await.expect("on_upgrade")); + upgraded.write_all(DATA).await.unwrap(); + upgraded.shutdown().await.unwrap(); + }); + + future::ok::<_, hyper::Error>( + Response::builder() + .status(200) + .body(Empty::::new()) + .unwrap(), + ) + }); + + let (socket, _) = listener.accept().await.unwrap(); + let socket = TokioIo::new(socket); + http2::Builder::new(TokioExecutor) + .serve_connection(socket, svc) + .await + .unwrap(); + + client_handle.await.unwrap(); +} + +#[tokio::test] +async fn h2_connect_reset_during_backpressure() { + let (listener, addr) = setup_tcp_listener(); + let conn = connect_async(addr).await; + + let mut builder = h2::client::Builder::new(); + builder.initial_window_size(1024); + builder.initial_connection_window_size(1024); + let (h2, connection) = builder.handshake::<_, Bytes>(conn).await.unwrap(); + tokio::spawn(async move { + let _ = connection.await; + }); + let mut h2 = h2.ready().await.unwrap(); + + let (write_err_tx, write_err_rx) = oneshot::channel::(); + let write_err_tx = Arc::new(Mutex::new(Some(write_err_tx))); + + tokio::spawn(async move { + let request = Request::connect("localhost").body(()).unwrap(); + let (response, mut send_stream) = h2.send_request(request, false).unwrap(); + let response = response.await.unwrap(); + assert_eq!(response.status(), StatusCode::OK); + + let mut body = response.into_body(); + let bytes = body.data().await.unwrap().unwrap(); + let _ = body.flow_control().release_capacity(bytes.len()); + + send_stream.send_reset(h2::Reason::CANCEL); + drop(body); + drop(send_stream); + + let got_err = write_err_rx.await.unwrap_or(false); + assert!(got_err, "server write should have failed after RST_STREAM"); + }); + + let svc = service_fn(move |req: Request| { + let on_upgrade = hyper::upgrade::on(req); + let write_err_tx = write_err_tx.clone(); + + tokio::spawn(async move { + let mut upgraded = TokioIo::new(on_upgrade.await.expect("on_upgrade")); + upgraded.write_all(b"initial").await.unwrap(); + + let large_data = vec![b'x'; 1024 * 1024]; + let write_result = upgraded.write_all(&large_data).await; + + if let Some(tx) = write_err_tx.lock().unwrap().take() { + let _ = tx.send(write_result.is_err()); + } + }); + + future::ok::<_, hyper::Error>( + Response::builder() + .status(200) + .body(Empty::::new()) + .unwrap(), + ) + }); + + let (socket, _) = listener.accept().await.unwrap(); + let socket = TokioIo::new(socket); + let _ = http2::Builder::new(TokioExecutor) + .serve_connection(socket, svc) + .await; +} + +#[tokio::test] +async fn h2_connect_backpressure_bidirectional() { + let (listener, addr) = setup_tcp_listener(); + let conn = connect_async(addr).await; + + let mut builder = h2::client::Builder::new(); + builder.initial_window_size(2048); + builder.initial_connection_window_size(4096); + let (h2, connection) = builder.handshake::<_, Bytes>(conn).await.unwrap(); + tokio::spawn(async move { + connection.await.unwrap(); + }); + let mut h2 = h2.ready().await.unwrap(); + + const PATTERN: &[u8] = b"All work and no bread makes nox a dull boy.\n"; + const REPEAT: usize = 500; + let expected_len = PATTERN.len() * REPEAT; + + let client_handle = tokio::spawn(async move { + let request = Request::connect("localhost").body(()).unwrap(); + let (response, mut send_stream) = h2.send_request(request, false).unwrap(); + let response = response.await.unwrap(); + assert_eq!(response.status(), StatusCode::OK); + + let mut body = response.into_body(); + let mut received = 0usize; + + while let Some(chunk) = body.data().await { + let chunk = chunk.unwrap(); + if chunk.is_empty() { + break; + } + let len = chunk.len(); + received += len; + let _ = body.flow_control().release_capacity(len); + } + + assert_eq!(received, expected_len); + + send_stream.send_data("client done".into(), true).unwrap(); + }); + + let svc = service_fn(move |req: Request| { + let on_upgrade = hyper::upgrade::on(req); + + tokio::spawn(async move { + let mut upgraded = TokioIo::new(on_upgrade.await.expect("on_upgrade")); + + for _ in 0..REPEAT { + upgraded.write_all(PATTERN).await.unwrap(); + } + + upgraded.shutdown().await.unwrap(); + + let mut response_buf = vec![0u8; 64]; + let n = upgraded.read(&mut response_buf).await.unwrap(); + assert_eq!(&response_buf[..n], b"client done"); + }); + + future::ok::<_, hyper::Error>( + Response::builder() + .status(200) + .body(Empty::::new()) + .unwrap(), + ) + }); + + let (socket, _) = listener.accept().await.unwrap(); + let socket = TokioIo::new(socket); + http2::Builder::new(TokioExecutor) + .serve_connection(socket, svc) + .await + .unwrap(); + + client_handle.await.unwrap(); +} + #[tokio::test] async fn parse_errors_send_4xx_response() { let (listener, addr) = setup_tcp_listener();