diff --git a/Cargo.lock b/Cargo.lock index e9584eae852b..cfe56f7345d8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4846,6 +4846,7 @@ dependencies = [ "memchr", "num_cpus", "object 0.39.0", + "pin-project-lite", "pulley-interpreter", "rand 0.10.1", "rayon", diff --git a/Cargo.toml b/Cargo.toml index f994c65e4519..2f062c257cb4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -91,6 +91,7 @@ hyper = { workspace = true, optional = true } http = { workspace = true, optional = true } http-body-util = { workspace = true, optional = true } futures = { workspace = true, optional = true } +pin-project-lite = { workspace = true, optional = true } [target.'cfg(unix)'.dependencies] rustix = { workspace = true, features = ["mm", "process"] } @@ -596,6 +597,7 @@ serve = [ "component-model", "dep:http-body-util", "dep:http", + "dep:pin-project-lite", "wasmtime-cli-flags/async", "wasmtime-wasi-http?/p2", ] diff --git a/crates/wasi-http/src/handler.rs b/crates/wasi-http/src/handler.rs index 1ea5244cbe2b..e7020eed232f 100644 --- a/crates/wasi-http/src/handler.rs +++ b/crates/wasi-http/src/handler.rs @@ -1,26 +1,278 @@ //! Provides utilities useful for dispatching incoming HTTP requests //! `wasi:http/handler` guest instances. 
+use crate::p2::bindings::http::types as p2_types; #[cfg(feature = "p3")] use crate::p3; -use futures::stream::{FuturesUnordered, Stream}; +use bytes::Bytes; +use futures::{ + channel::oneshot, + future::{Either, FutureExt}, + stream::{FuturesUnordered, Stream}, +}; +use http_body_util::{BodyExt, combinators::UnsyncBoxBody}; +use p3::bindings::http::types as p3_types; use std::collections::VecDeque; use std::collections::btree_map::{BTreeMap, Entry}; +use std::error; +use std::fmt; use std::future; +use std::mem; +use std::ops::DerefMut; use std::pin::{Pin, pin}; use std::sync::{ Arc, Mutex, - atomic::{ - AtomicBool, AtomicU64, AtomicUsize, - Ordering::{Relaxed, SeqCst}, - }, + atomic::{AtomicBool, AtomicUsize, Ordering::Relaxed}, }; -use std::task::Poll; -use std::time::{Duration, Instant}; +use std::task::{Context, Poll}; +use std::time::Instant; use tokio::sync::Notify; -use wasmtime::AsContextMut; use wasmtime::component::Accessor; -use wasmtime::{Result, Store, StoreContextMut, format_err}; +use wasmtime::error::Context as _; +use wasmtime::{AsContextMut, Result, Store, format_err}; + +/// Represents either a `wasi:http/types@0.2.x` or `wasi:http/types@0.3.x` `error-code`. +pub enum ErrorCode { + /// A `wasi:http/types@0.2.x` `error-code`. + #[cfg(feature = "p2")] + P2(p2_types::ErrorCode), + /// A `wasi:http/types@0.3.x` `error-code`. 
+ #[cfg(feature = "p3")] + P3(p3_types::ErrorCode), +} + +impl From for ErrorCode { + fn from(code: p2_types::ErrorCode) -> Self { + Self::P2(code) + } +} + +impl From for ErrorCode { + fn from(code: p3_types::ErrorCode) -> Self { + Self::P3(code) + } +} + +impl From for p2_types::ErrorCode { + fn from(code: ErrorCode) -> p2_types::ErrorCode { + match code { + ErrorCode::P2(code) => code, + ErrorCode::P3(code) => code.into(), + } + } +} + +impl From for p3_types::ErrorCode { + fn from(code: ErrorCode) -> p3_types::ErrorCode { + match code { + ErrorCode::P2(code) => code.into(), + ErrorCode::P3(code) => code, + } + } +} + +impl From for p3_types::ErrorCode { + fn from(code: p2_types::ErrorCode) -> Self { + match code { + p2_types::ErrorCode::DnsTimeout => Self::DnsTimeout, + p2_types::ErrorCode::DnsError(payload) => Self::DnsError(p3_types::DnsErrorPayload { + rcode: payload.rcode, + info_code: payload.info_code, + }), + p2_types::ErrorCode::DestinationNotFound => Self::DestinationNotFound, + p2_types::ErrorCode::DestinationUnavailable => Self::DestinationUnavailable, + p2_types::ErrorCode::DestinationIpProhibited => Self::DestinationIpProhibited, + p2_types::ErrorCode::DestinationIpUnroutable => Self::DestinationIpUnroutable, + p2_types::ErrorCode::ConnectionRefused => Self::ConnectionRefused, + p2_types::ErrorCode::ConnectionTerminated => Self::ConnectionTerminated, + p2_types::ErrorCode::ConnectionTimeout => Self::ConnectionTimeout, + p2_types::ErrorCode::ConnectionReadTimeout => Self::ConnectionReadTimeout, + p2_types::ErrorCode::ConnectionWriteTimeout => Self::ConnectionWriteTimeout, + p2_types::ErrorCode::ConnectionLimitReached => Self::ConnectionLimitReached, + p2_types::ErrorCode::TlsProtocolError => Self::TlsProtocolError, + p2_types::ErrorCode::TlsCertificateError => Self::TlsCertificateError, + p2_types::ErrorCode::TlsAlertReceived(payload) => { + Self::TlsAlertReceived(p3_types::TlsAlertReceivedPayload { + alert_id: payload.alert_id, + alert_message: 
payload.alert_message, + }) + } + p2_types::ErrorCode::HttpRequestDenied => Self::HttpRequestDenied, + p2_types::ErrorCode::HttpRequestLengthRequired => Self::HttpRequestLengthRequired, + p2_types::ErrorCode::HttpRequestBodySize(payload) => Self::HttpRequestBodySize(payload), + p2_types::ErrorCode::HttpRequestMethodInvalid => Self::HttpRequestMethodInvalid, + p2_types::ErrorCode::HttpRequestUriInvalid => Self::HttpRequestUriInvalid, + p2_types::ErrorCode::HttpRequestUriTooLong => Self::HttpRequestUriTooLong, + p2_types::ErrorCode::HttpRequestHeaderSectionSize(payload) => { + Self::HttpRequestHeaderSectionSize(payload) + } + p2_types::ErrorCode::HttpRequestHeaderSize(payload) => { + Self::HttpRequestHeaderSize(payload.map(|payload| p3_types::FieldSizePayload { + field_name: payload.field_name, + field_size: payload.field_size, + })) + } + p2_types::ErrorCode::HttpRequestTrailerSectionSize(payload) => { + Self::HttpRequestTrailerSectionSize(payload) + } + p2_types::ErrorCode::HttpRequestTrailerSize(payload) => { + Self::HttpRequestTrailerSize(p3_types::FieldSizePayload { + field_name: payload.field_name, + field_size: payload.field_size, + }) + } + p2_types::ErrorCode::HttpResponseIncomplete => Self::HttpResponseIncomplete, + p2_types::ErrorCode::HttpResponseHeaderSectionSize(payload) => { + Self::HttpResponseHeaderSectionSize(payload) + } + p2_types::ErrorCode::HttpResponseHeaderSize(payload) => { + Self::HttpResponseHeaderSize(p3_types::FieldSizePayload { + field_name: payload.field_name, + field_size: payload.field_size, + }) + } + p2_types::ErrorCode::HttpResponseBodySize(payload) => { + Self::HttpResponseBodySize(payload) + } + p2_types::ErrorCode::HttpResponseTrailerSectionSize(payload) => { + Self::HttpResponseTrailerSectionSize(payload) + } + p2_types::ErrorCode::HttpResponseTrailerSize(payload) => { + Self::HttpResponseTrailerSize(p3_types::FieldSizePayload { + field_name: payload.field_name, + field_size: payload.field_size, + }) + } + 
p2_types::ErrorCode::HttpResponseTransferCoding(payload) => { + Self::HttpResponseTransferCoding(payload) + } + p2_types::ErrorCode::HttpResponseContentCoding(payload) => { + Self::HttpResponseContentCoding(payload) + } + p2_types::ErrorCode::HttpResponseTimeout => Self::HttpResponseTimeout, + p2_types::ErrorCode::HttpUpgradeFailed => Self::HttpUpgradeFailed, + p2_types::ErrorCode::HttpProtocolError => Self::HttpProtocolError, + p2_types::ErrorCode::LoopDetected => Self::LoopDetected, + p2_types::ErrorCode::ConfigurationError => Self::ConfigurationError, + p2_types::ErrorCode::InternalError(payload) => Self::InternalError(payload), + } + } +} + +impl From for p2_types::ErrorCode { + fn from(code: p3_types::ErrorCode) -> Self { + match code { + p3_types::ErrorCode::DnsTimeout => Self::DnsTimeout, + p3_types::ErrorCode::DnsError(payload) => Self::DnsError(p2_types::DnsErrorPayload { + rcode: payload.rcode, + info_code: payload.info_code, + }), + p3_types::ErrorCode::DestinationNotFound => Self::DestinationNotFound, + p3_types::ErrorCode::DestinationUnavailable => Self::DestinationUnavailable, + p3_types::ErrorCode::DestinationIpProhibited => Self::DestinationIpProhibited, + p3_types::ErrorCode::DestinationIpUnroutable => Self::DestinationIpUnroutable, + p3_types::ErrorCode::ConnectionRefused => Self::ConnectionRefused, + p3_types::ErrorCode::ConnectionTerminated => Self::ConnectionTerminated, + p3_types::ErrorCode::ConnectionTimeout => Self::ConnectionTimeout, + p3_types::ErrorCode::ConnectionReadTimeout => Self::ConnectionReadTimeout, + p3_types::ErrorCode::ConnectionWriteTimeout => Self::ConnectionWriteTimeout, + p3_types::ErrorCode::ConnectionLimitReached => Self::ConnectionLimitReached, + p3_types::ErrorCode::TlsProtocolError => Self::TlsProtocolError, + p3_types::ErrorCode::TlsCertificateError => Self::TlsCertificateError, + p3_types::ErrorCode::TlsAlertReceived(payload) => { + Self::TlsAlertReceived(p2_types::TlsAlertReceivedPayload { + alert_id: 
payload.alert_id, + alert_message: payload.alert_message, + }) + } + p3_types::ErrorCode::HttpRequestDenied => Self::HttpRequestDenied, + p3_types::ErrorCode::HttpRequestLengthRequired => Self::HttpRequestLengthRequired, + p3_types::ErrorCode::HttpRequestBodySize(payload) => Self::HttpRequestBodySize(payload), + p3_types::ErrorCode::HttpRequestMethodInvalid => Self::HttpRequestMethodInvalid, + p3_types::ErrorCode::HttpRequestUriInvalid => Self::HttpRequestUriInvalid, + p3_types::ErrorCode::HttpRequestUriTooLong => Self::HttpRequestUriTooLong, + p3_types::ErrorCode::HttpRequestHeaderSectionSize(payload) => { + Self::HttpRequestHeaderSectionSize(payload) + } + p3_types::ErrorCode::HttpRequestHeaderSize(payload) => { + Self::HttpRequestHeaderSize(payload.map(|payload| p2_types::FieldSizePayload { + field_name: payload.field_name, + field_size: payload.field_size, + })) + } + p3_types::ErrorCode::HttpRequestTrailerSectionSize(payload) => { + Self::HttpRequestTrailerSectionSize(payload) + } + p3_types::ErrorCode::HttpRequestTrailerSize(payload) => { + Self::HttpRequestTrailerSize(p2_types::FieldSizePayload { + field_name: payload.field_name, + field_size: payload.field_size, + }) + } + p3_types::ErrorCode::HttpResponseIncomplete => Self::HttpResponseIncomplete, + p3_types::ErrorCode::HttpResponseHeaderSectionSize(payload) => { + Self::HttpResponseHeaderSectionSize(payload) + } + p3_types::ErrorCode::HttpResponseHeaderSize(payload) => { + Self::HttpResponseHeaderSize(p2_types::FieldSizePayload { + field_name: payload.field_name, + field_size: payload.field_size, + }) + } + p3_types::ErrorCode::HttpResponseBodySize(payload) => { + Self::HttpResponseBodySize(payload) + } + p3_types::ErrorCode::HttpResponseTrailerSectionSize(payload) => { + Self::HttpResponseTrailerSectionSize(payload) + } + p3_types::ErrorCode::HttpResponseTrailerSize(payload) => { + Self::HttpResponseTrailerSize(p2_types::FieldSizePayload { + field_name: payload.field_name, + field_size: 
payload.field_size, + }) + } + p3_types::ErrorCode::HttpResponseTransferCoding(payload) => { + Self::HttpResponseTransferCoding(payload) + } + p3_types::ErrorCode::HttpResponseContentCoding(payload) => { + Self::HttpResponseContentCoding(payload) + } + p3_types::ErrorCode::HttpResponseTimeout => Self::HttpResponseTimeout, + p3_types::ErrorCode::HttpUpgradeFailed => Self::HttpUpgradeFailed, + p3_types::ErrorCode::HttpProtocolError => Self::HttpProtocolError, + p3_types::ErrorCode::LoopDetected => Self::LoopDetected, + p3_types::ErrorCode::ConfigurationError => Self::ConfigurationError, + p3_types::ErrorCode::InternalError(payload) => Self::InternalError(payload), + } + } +} + +/// Represents either a p2 or p3 `WasiHttpCtxView` getter. +pub enum ViewFn { + /// A p2 getter. + #[cfg(feature = "p2")] + P2(fn(&mut T) -> crate::p2::WasiHttpCtxView), + /// A p3 getter. + #[cfg(feature = "p3")] + P3(fn(&mut T) -> p3::WasiHttpCtxView), +} + +impl Clone for ViewFn { + fn clone(&self) -> Self { + match self { + &Self::P2(view) => Self::P2(view), + &Self::P3(view) => Self::P3(view), + } + } +} + +impl Copy for ViewFn {} + +/// A Request to be handled using `ProxyHandler::handle`. +pub type Request = http::Request>; + +/// A Response returned by `ProxyHandler::handle`. +pub type Response = http::Response>; /// Alternative p2 bindings generated with `exports: { default: async | store }` /// so we can use `TypedFunc::call_concurrent` with both p2 and p3 instances. @@ -58,7 +310,8 @@ pub enum ProxyPre { } impl ProxyPre { - async fn instantiate_async(&self, store: impl AsContextMut) -> Result + /// Instantiates the pre-instance. + pub async fn instantiate_async(&self, store: impl AsContextMut) -> Result where T: Send, { @@ -82,24 +335,17 @@ pub enum Proxy { P3(p3::bindings::Service), } -/// Represents a task to run using a `wasi:http/incoming-handler@0.2.x` or -/// `wasi:http/handler@0.3.x` instance. 
-pub type TaskFn = Box< - dyn for<'a> FnOnce(&'a Accessor, &'a Proxy) -> Pin + Send + 'a>> - + Send, ->; - /// Async MPMC channel where each item is delivered to at most one consumer. struct Queue { queue: Mutex>, - notify: Notify, + notify_push: Notify, } impl Default for Queue { fn default() -> Self { Self { queue: Default::default(), - notify: Default::default(), + notify_push: Default::default(), } } } @@ -109,21 +355,16 @@ impl Queue { self.queue.lock().unwrap().is_empty() } - fn push(&self, item: T) { - self.queue.lock().unwrap().push_back(item); - self.notify.notify_one(); - } - fn try_pop(&self) -> Option { self.queue.lock().unwrap().pop_front() } async fn pop(&self) -> T { - // This code comes from the Unbound MPMC Channel example in [the + // This code comes from the Unbounded MPMC Channel example in [the // `tokio::sync::Notify` // docs](https://docs.rs/tokio/latest/tokio/sync/struct.Notify.html). - let mut notified = pin!(self.notify.notified()); + let mut notified = pin!(self.notify_push.notified()); loop { notified.as_mut().enable(); @@ -131,69 +372,171 @@ impl Queue { return item; } notified.as_mut().await; - notified.set(self.notify.notified()); + notified.set(self.notify_push.notified()); } } } -/// Bundles a [`Store`] with a callback to write a profile (if configured). -pub struct StoreBundle { - /// The [`Store`] to use to handle requests. - pub store: Store, - /// Callback to write a profile (if enabled) once all requests have been - /// handled. - pub write_profile: Box) + Send>, +/// Represents the status of a `ProxyHandler` worker task. +#[derive(Clone, Copy, Eq, PartialEq, Debug)] +pub enum WorkerStatus { + /// The worker is not handling any requests, nor is it doing any post-return + /// work. It _might_ be doing background work which the guest has indicated + /// can be interrupted and/or abandoned at any time, i.e. does not prevent + /// the instance from being disposed. 
+ Idle, + /// The instance is handling one or more requests, waiting for each to + /// either produce a response or expire. + Requests, + /// All requests handled so far have either produced a response or expired, + /// but the guest has post-return work which needs to finish before the + /// instance can be considered idle. + PostReturn, } -/// Represents the application-specific state of a web server. -pub trait HandlerState: 'static + Sync + Send { - /// The type of the associated data for [`Store`]s created using - /// [`Self::new_store`]. +/// Represents the application-specific state of a `ProxyHandler` worker. +/// +/// [`HandlerState::instantiate`] returns an implementation of this trait for +/// each component instance (and thus each worker) created. The worker uses it +/// to determine when to exit. +pub trait WorkerExpiration: 'static + Send + Sync { + /// Poll whether the worker has expired. + /// + /// This will return `Poll::Ready(())` if the worker has expired, meaning + /// the component instance should be dropped. Otherwise, it will return + /// `Poll::Pending` and wake the `Waker` if and when it should be polled + /// again. + /// + /// `state` represents the current state of the worker, and `start` + /// represents when it transitioned into that state (or in the case of + /// `WorkerStatus::Requests`, when the most recent outstanding request + /// was accepted). + fn poll( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + state: WorkerStatus, + start: Instant, + ) -> Poll<()>; +} + +/// Represents the application-specific state of a `ProxyHandler` worker. +/// +/// [`HandlerState::instantiate`] returns an implementation of this trait for +/// each component instance (and thus each worker) created. The worker uses it +/// to determine how many requests to accept, how long to wait for the guest to +/// produce responses, etc. +pub trait WorkerState: 'static + Send + Sync { + /// The type of the associated data for the [`Store`] belonging to this worker. 
 type StoreData: Send; - /// Create a new [`Store`] for handling one or more requests. - /// - /// The `req_id` parameter is the value passed in the call to - /// [`ProxyHandler::spawn`] that created the worker to which the new `Store` - /// will belong. See that function's documentation for details. - fn new_store(&self, req_id: Option) -> Result>; + /// Indicate whether the worker should accept another request given the + /// current number it is already handling concurrently and the total it has + /// handled so far. + fn should_accept_request(&self, concurrent_count: usize, total_count: usize) -> ShouldAccept; - /// Maximum time allowed to handle a request. + /// Notification that a request has been accepted by the worker. /// - /// In practice, a guest may be allowed to run up to 2x this time in the - /// case of instance reuse to avoid penalizing concurrent requests being - /// handled by the same instance. - fn request_timeout(&self) -> Duration; - - /// Maximum time to keep an idle instance around before dropping it. - fn idle_instance_timeout(&self) -> Duration; + /// If the future returned by this function resolves before the guest has + /// produced a response, the request will be considered "expired" and the + /// original `ProxyHandler::handle` future will resolve to an + /// `Err(ExpirationError.into())`. In addition, the worker + /// will stop accepting new requests but will continue running until all + /// requests that have been accepted by the worker have either produced a + /// response or expired, at which point the state of the worker will + /// transition to either `WorkerStatus::PostReturn` or `WorkerStatus::Idle`. + /// + /// Note that the returned future is polled from within the + /// `Store::run_concurrent` event loop, and due to #11869 and #11870, it may + /// not be polled at all for arbitrary lengths of time. 
Consequently, the + /// [`WorkerExpiration`] implementation (which is polled from _outside_ the + /// `Store::run_concurrent` event loop) must also enforce request expiration + /// as a second level of defence if desired. + /// + /// For example, if a request timeout of N seconds is to be enforced, the + /// [`WorkerExpiration::poll`] implementation, when called with + /// `WorkerStatus::Requests`, should calculate the time elapsed since the most + /// recent outstanding request was accepted as indicated by the `start` + /// parameter. If that time is greater than N seconds, we can expire the + /// instance immediately, confident that all outstanding requests have + /// expired. + /// + /// Once #11869 and #11870 have been addressed, this "second level of + /// defence" will no longer be necessary. + fn on_request_start( + &self, + request: &Request, + ) -> Pin + 'static + Send + Sync>>; + + /// Dispose of the store belonging to the now-exited worker. + /// + /// This may be used to e.g. collect metrics from the store or its + /// associated data before the store is dropped, as well as e.g. retry + /// failed instantiations after the store is dropped. + /// + /// If the store is being dropped due to an error (e.g. a guest trap or a + /// host panic) `result` will be `Err(_)`; otherwise it will be `Ok(())`. + fn drop(&self, store: Store, result: Result<(), wasmtime::Error>); +} - /// Maximum number of requests to handle using a single instance before - /// dropping it. - fn max_instance_reuse_count(&self) -> usize; +/// Represents the combination of a store and instance with which to handle +/// requests. +pub struct Instance { + /// The store to use to handle requests. + pub store: Store, + /// The instance to use to handle requests. + pub proxy: Proxy, + /// `WasiHttpCtxView` getter function. + pub view: ViewFn, + /// See [`WorkerExpiration`]. + pub expiration: E, + /// See [`WorkerState`]. 
+ pub state: S, +} - /// Maximum number of requests to handle concurrently using a single - /// instance. - fn max_instance_concurrent_reuse_count(&self) -> usize; +/// Indicates whether a worker should accept new requests. +pub enum ShouldAccept { + /// Yes, it should. + Yes, + /// No, it shouldn't (but ask again later). + No, + /// No, it shouldn't (and don't ask again). + Never, +} - /// Called when a worker exits with an error. - fn handle_worker_error(&self, error: wasmtime::Error); +/// Represents the application-specific state of a web server. +pub trait HandlerState: 'static + Sync + Send + Sized { + /// The type of the associated data for [`Store`]s created using + /// [`Self::instantiate`]. + type StoreData: Send; + /// The type of the `WorkerExpiration` implementation to be returned from + /// [`Self::instantiate`]. + type WorkerExpiration: WorkerExpiration; + /// The type of the `WorkerState` implementation to be returned from + /// [`Self::instantiate`]. + type WorkerState: WorkerState; + + /// Create a new store and instance for handling one or more requests. + /// + /// Note that the implementer is responsible for applying a timeout to the + /// guest instantiation if appropriate (e.g. as part of an overall request + /// timeout). + fn instantiate( + &self, + ) -> impl Future< + Output = Result>, + > + Send; } struct ProxyHandlerInner { state: S, - instance_pre: ProxyPre, - next_id: AtomicU64, - task_queue: Queue>, + request_queue: Queue<(Request, oneshot::Sender>)>, worker_count: AtomicUsize, } -/// Helper utility to track the start times of tasks accepted by a worker. +/// Tracks request start times. /// -/// This is used to ensure that timeouts are enforced even when the -/// `StoreContextMut::run_concurrent` event loop is unable to make progress due -/// to the guest either busy looping or being blocked on a synchronous call to a -/// host function which has exclusive access to the `Store`. 
+/// This is useful for keeping a [`WorkerState`] apprised of the most recently +/// accepted outstanding request. #[derive(Default)] struct StartTimes(BTreeMap); @@ -217,8 +560,8 @@ impl StartTimes { } } - fn earliest(&self) -> Option { - self.0.first_key_value().map(|(&k, _)| k) + fn most_recent(&self) -> Option { + self.0.last_key_value().map(|(&k, _)| k) } } @@ -240,114 +583,188 @@ where if available { self.handler.0.worker_count.fetch_add(1, Relaxed); } else { - // Here we use `SeqCst` to ensure the load/store is ordered - // correctly with respect to the `Queue::is_empty` check we do - // below. - let count = self.handler.0.worker_count.fetch_sub(1, SeqCst); + // Decrement the count _before_ checking if the request queue is + // empty. This helps ensure that `ProxyHandler::spawn` sees the + // new value before deciding whether to spawn a new worker. + let count = self.handler.0.worker_count.fetch_sub(1, Relaxed); + assert!(count >= 1); + // This addresses what would otherwise be a race condition in // `ProxyHandler::spawn` where it only starts a worker if the // available worker count is zero. If we decrement the count to // zero right after `ProxyHandler::spawn` checks it, then no // worker will be started; thus it becomes our responsibility to // start a worker here instead. 
- if count == 1 && !self.handler.0.task_queue.is_empty() { - self.handler.start_worker(None, None); + if count == 1 && !self.handler.0.request_queue.is_empty() { + self.handler.start_worker(None); } } } } - async fn run(mut self, task: Option>, req_id: Option) { - if let Err(error) = self.run_(task, req_id).await { - self.handler.0.state.handle_worker_error(error); + async fn run( + self, + request: Option<(Request, oneshot::Sender>)>, + ) { + match self.handler.0.state.instantiate().await { + Ok(Instance { + store, + proxy, + view, + expiration, + state, + }) => { + self.run_(store, proxy, view, expiration, state, request) + .await + } + + Err(error) => { + let error = Arc::new(error); + if let Some((request, tx)) = request { + _ = tx.send(Err(InstantiationError { + request: Mutex::new(request), + error, + } + .into())); + } else { + // In this case, the worker was spawned to handle any queued + // requests. Since we can't handle those requests, we send + // them all an instantiation error. + for (request, tx) in mem::take( + self.handler + .0 + .request_queue + .queue + .lock() + .unwrap() + .deref_mut(), + ) { + _ = tx.send(Err(InstantiationError { + request: Mutex::new(request), + error: error.clone(), + } + .into())); + } + } + } } } async fn run_( - &mut self, - task: Option>, - req_id: Option, - ) -> Result<()> { + mut self, + store: Store, + proxy: Proxy, + view: ViewFn, + expiration: S::WorkerExpiration, + state: S::WorkerState, + request: Option<(Request, oneshot::Sender>)>, + ) { // NB: The code the follows is rather subtle in that it is structured - // carefully to provide a few key invariants related to how instance - // reuse and request timeouts interact: - // - // - A task must never be allowed to run for more than 2x the request - // timeout, if any. - // - // - Every task we accept here must be allowed to run for at least 1x - // the request timeout, if any. 
+ // carefully to give the `HandlerState` implementation full control over + // the component instance lifetime. Specifically, we must keep the + // `HandlerState` informed of the worker's state and how long it has + // been in that state, as well as allow it to expire the instance based + // on whatever combination of timeouts, dynamic resource usage, etc. it + // may take into consideration. // - // - When more than one task is run concurrently in the same instance, - // we must stop accepting new tasks as soon as any existing task reaches - // the request timeout. This serves to cap the amount of time we need - // to keep the instance alive before _all_ tasks have either completed - // or timed out. + // Note that, when more than one request is handled concurrently in the + // same instance, we must stop accepting new requests as soon as any + // existing request reaches its expiration. This serves to cap the + // amount of time we need to keep the instance alive before _all_ + // requests have either completed or expired. // - // As of this writing, there's an additional wrinkle that makes - // guaranteeing those invariants particularly tricky: per #11869 and - // #11870, busy guest loops, epoch interruption, and host functions - // registered using `Linker::func_{wrap,new}_async` all require - // blocking, exclusive access to the `Store`, which effectively prevents - // the `StoreContextMut::run_concurrent` event loop from making - // progress. That, in turn, prevents any concurrent tasks from - // executing, and also prevents the `AsyncFnOnce` passed to - // `run_concurrent` from being polled. Consequently, we must rely on a - // "second line of defense" to ensure tasks are timed out promptly, - // which is to check for timeouts _outside_ the `run_concurrent` future. - // Once the aforementioned issues have been addressed, we'll be able to - // remove that check and its associated baggage. 
- - let handler = &self.handler.0; - - let StoreBundle { - mut store, - write_profile, - } = handler.state.new_store(req_id)?; - - let request_timeout = handler.state.request_timeout(); - let idle_instance_timeout = handler.state.idle_instance_timeout(); - let max_instance_reuse_count = handler.state.max_instance_reuse_count(); - let max_instance_concurrent_reuse_count = - handler.state.max_instance_concurrent_reuse_count(); - - let proxy = &handler.instance_pre.instantiate_async(&mut store).await?; + // As of this writing, there's an additional wrinkle that makes tracking + // expiration particularly tricky: per #11869 and #11870, busy guest + // loops, epoch interruption, and host functions registered using + // `Linker::func_{wrap,new}_async` all require blocking, exclusive + // access to the `Store`, which effectively prevents the + // `StoreContextMut::run_concurrent` event loop from making progress. + // That, in turn, prevents any concurrent tasks from executing, and also + // prevents the `AsyncFnOnce` passed to `run_concurrent` from being + // polled. Consequently, we must poll `S::WorkerExpiration` from _outside_ + // the `run_concurrent` future to ensure expirations are enforced. Once + // the aforementioned issues have been addressed, we'll be able to + // simplify the code and eliminate the need for communication between + // the "inside" future and the "outside" one. + + // Wrap `store` in an object which, prior to leaving this scope, will + // pass the `store` to `WorkerState::drop`. 
+ struct Dropper { + state: S::WorkerState, + store: Option>, + } + + impl Drop for Dropper { + fn drop(&mut self) { + if let Some(store) = self.store.take() { + self.state + .drop(store, Err(wasmtime::format_err!("worker panicked"))); + } + } + } + + let mut dropper = Dropper:: { + state, + store: Some(store), + }; + + let proxy = &proxy; + let accept_concurrent = AtomicBool::new(true); - let task_start_times = Mutex::new(StartTimes::default()); + let status = Mutex::new((WorkerStatus::Idle, Instant::now())); + let mut expiration = pin!(expiration); - let mut future = pin!(store.run_concurrent(async |accessor| { + let function = async |accessor: &Accessor<_>| { let mut reuse_count = 0; - let mut timed_out = false; + let mut may_accept = true; let mut futures = FuturesUnordered::new(); + let mut start_times = StartTimes::default(); - let accept_task = |task: TaskFn, - futures: &mut FuturesUnordered<_>, - reuse_count: &mut usize| { + let accept_request = |request: Request, + tx: oneshot::Sender>, + futures: &mut FuturesUnordered<_>, + start_times: &mut StartTimes, + reuse_count: &mut usize| { // Set `accept_concurrent` to false, conservatively assuming // that the new task will be CPU-bound, at least to begin with. // Only once the `StoreContextMut::run_concurrent` event loop // returns `Pending` will we set `accept_concurrent` back to - // true and consider accepting more tasks. + // true and consider accepting more requests. // // This approach avoids taking on more than one CPU-bound task // at a time, which would hurt throughput vs. leaving the - // additional tasks for other workers to handle. + // additional requests for other workers to handle. 
accept_concurrent.store(false, Relaxed); *reuse_count += 1; - let start_time = Instant::now().checked_add(request_timeout); - if let Some(start_time) = start_time { - task_start_times.lock().unwrap().add(start_time); - } - - futures.push(tokio::time::timeout(request_timeout, async move { - (task)(accessor, proxy).await; - start_time - })); + // Notify the `HandlerState` that we're starting to handle a + // request and retrieve the deadline by which it must produce a + // response. + // + // If it fails to produce a response by the deadline, we'll stop + // accepting new requests and eventually exit the worker. + let expiration = dropper.state.on_request_start(&request); + + let start_time = Instant::now(); + start_times.add(start_time); + *status.try_lock().unwrap() = (WorkerStatus::Requests, start_time); + + futures.push(async move { + Ok::<_, wasmtime::Error>(( + handle(accessor, proxy, request, view, tx, expiration).await?, + start_time, + )) + }); }; - if let Some(task) = task { - accept_task(task, &mut futures, &mut reuse_count); + if let Some((request, tx)) = request { + accept_request( + request, + tx, + &mut futures, + &mut start_times, + &mut reuse_count, + ); } // This is the main driver loop for this worker. This is modeled as @@ -355,25 +772,18 @@ where // Events are sourced from the locals here, pinned outside of the // `poll_fn` closure. let mut futures = pin!(futures); - let mut idle_timeout_set = false; - let mut idle_timeout = pin!(tokio::time::sleep(Duration::MAX)); let handler = self.handler.clone(); - let mut incoming_tasks = pin!(futures::stream::unfold( - &handler.0.task_queue, + let mut incoming_requests = pin!(futures::stream::unfold( + &handler.0.request_queue, |queue| async move { - let task = queue.pop().await; - Some((task, queue)) + let pair = queue.pop().await; + Some((pair, queue)) } )); future::poll_fn(|cx| { - // See docs about the idle timeout handling at the very bottom - // for what this is doing. 
- let prev_idle_timeout_set = idle_timeout_set; - idle_timeout_set = false; - loop { - // First, and crucially first , poll `futures` first. This - // way we'll discover any tasks that may have timed out, at + // First, and crucially first, poll `futures`. This way + // we'll discover any tasks that may have timed out, at // which point we'll stop accepting new tasks altogether // (see below for details). This is especially important in // the case where the task was blocked on a synchronous call @@ -383,29 +793,34 @@ where // task first, then we'd have to wait for _that_ task to // finish or time out before we could kill the instance. match futures.as_mut().poll_next(cx) { - // Task completed; carry on! - Poll::Ready(Some(Ok(start_time))) => { - if let Some(start_time) = start_time { - task_start_times.lock().unwrap().remove(start_time); + // A request either produced a response or expired. + Poll::Ready(Some(Ok((responded, start_time)))) => { + // Remove its start time from the map and update the + // state. + start_times.remove(start_time); + *status.try_lock().unwrap() = + if let Some(start_time) = start_times.most_recent() { + (WorkerStatus::Requests, start_time) + } else { + (WorkerStatus::PostReturn, Instant::now()) + }; + + if responded { + // Response produced; carry on! + } else { + // Request expired; stop accepting new requests, but + // continue polling until any other, in-progress + // tasks until they have either finished or expired. + // This effectively kicks off a "graceful shutdown" + // of the worker, allowing any other concurrent + // tasks time to finish before we drop the instance. + may_accept = false; } } - // Task timed out; stop accepting new tasks, but - // continue polling until any other, in-progress tasks - // until they have either finished or timed out. This - // effectively kicks off a "graceful shutdown" of the - // worker, allowing any other concurrent tasks time to - // finish before we drop the instance. 
- // - // TODO: We should also send a cancel request to the - // timed-out task to give it a chance to shut down - // gracefully (and delay dropping the instance for a - // reasonable amount of time), but as of this writing - // Wasmtime does not yet provide an API for doing that. - // See issue #11833. - Poll::Ready(Some(Err(_))) => { - timed_out = true; - reuse_count = max_instance_reuse_count; + // Instance trapped. + Poll::Ready(Some(Err(error))) => { + break Poll::Ready(Err(error)); } Poll::Ready(None) | Poll::Pending => {} @@ -421,9 +836,20 @@ where // have capacity for another task if either we have no tasks // at all or all our tasks really are blocked on I/O. self.set_available( - reuse_count < max_instance_reuse_count - && futures.len() < max_instance_concurrent_reuse_count - && (futures.is_empty() || accept_concurrent.load(Relaxed)), + may_accept + && match dropper + .state + .should_accept_request(futures.len(), reuse_count) + { + ShouldAccept::Yes => { + futures.is_empty() || accept_concurrent.load(Relaxed) + } + ShouldAccept::No => false, + ShouldAccept::Never => { + may_accept = false; + false + } + }, ); // If we're available for accepting more requests after the @@ -431,9 +857,16 @@ where // successful then push it into `futures` and turn this loop // again to see where we're at next time around. if self.available - && let Poll::Ready(Some(task)) = incoming_tasks.as_mut().poll_next(cx) + && let Poll::Ready(Some((request, tx))) = + incoming_requests.as_mut().poll_next(cx) { - accept_task(task, &mut futures, &mut reuse_count); + accept_request( + request, + tx, + &mut futures, + &mut start_times, + &mut reuse_count, + ); continue; } @@ -456,133 +889,91 @@ where // then we're done with this iteration of `poll`. We'll get // woken up when anything changes, but otherwise it's time // to let something else happen. - // - // This is all skipped if something has timed out though. 
In - // that situation we're basically no longer interested in - // this store so we're no longer cooperatively trying to let - // it keep going. - if !timed_out && !accessor.poll_no_interesting_tasks(cx).is_ready() { + if accessor.poll_no_interesting_tasks(cx).is_pending() { break Poll::Pending; } // And now at this point we (a) have no `futures`, (b) no - // new connections came in, and (c) the store is completely - // devoid of interesting work. In this situation if we're - // not actually capable of accepting any more work, then - // we're completely done and it's time to exit this worker. - if !self.available { - break Poll::Ready(()); + // new requests are available, and (c) the store is + // completely devoid of interesting work. In this situation + // if we're not actually capable of accepting any more work, + // then we're completely done and it's time to exit this + // worker. + if !may_accept { + break Poll::Ready(Ok(())); } - // And now, finally, we wait for a timeout. Here we're just - // like above except that we're candidate for accepting more - // work in the future. If this is our first time here then - // reset the idle timeout to `idle_instance_timeout` from - // now, but othrewise just go take a look at `idle_timeout` - // and see if it's elapsed yet. - // - // Note that the way that this entire loop is structured is - // that we've already polled all the interesting sources of - // events we're interested in at this point, for example - // `futures`, `accessor`, and `incoming_tasks`. Here we add - // `idle_timeout` to that set and once anything is ready and - // fires then this entire loop will restart and we'll check - // everything again. - // - // Also note that the idle timeout is supposed to start when - // the store is itself entirely idle. The way this loop is - // structured is that when we entire this `poll` closure the - // `idle_timeout_set` variable is unconditionally set to - // `false`. 
That way if we exit out for some other reason, - // such as getting work, then the idle timeout will get - // reset next time we fall down here. Otherwise though if we - // fell down this far we actually want to preserve - // `idle_timeout_set` from when we first started, so that's - // restored here. - idle_timeout_set = prev_idle_timeout_set; - if !idle_timeout_set { - idle_timeout - .as_mut() - .reset(tokio::time::Instant::now() + idle_instance_timeout); - idle_timeout_set = true; + // Finally, at this point we're idle but still eligible to + // accept new work, so update the state if appropriate and + // then return pending while we wait for new work. + { + let mut status = status.try_lock().unwrap(); + if status.0 != WorkerStatus::Idle { + *status = (WorkerStatus::Idle, Instant::now()); + } } - break idle_timeout.as_mut().poll(cx); + break Poll::Pending; } }) - .await; - - accessor.with(|mut access| write_profile(access.as_context_mut())); + .await + }; - if timed_out { - Err(format_err!("guest timed out")) - } else { - wasmtime::error::Ok(()) - } - })); + let result = { + let mut future = pin!( + dropper + .store + .as_mut() + .unwrap() + .run_concurrent(function) + .map(|v| v.flatten()) + ); - let mut sleep = pin!(tokio::time::sleep(Duration::MAX)); + future::poll_fn(|cx| { + let poll = future.as_mut().poll(cx); + if poll.is_pending() { + // If the future returns `Pending`, that's either because it's + // idle (in which case it can definitely accept a new request) or + // because all its tasks are awaiting I/O, in which case it may + // have capacity for additional tasks to run concurrently. + // + // However, per #11869 and #11870, if one of the tasks is + // blocked on a sync call to a host function which has exclusive + // access to the `Store`, the `StoreContextMut::run_concurrent` + // event loop will be unable to make progress until that call + // finishes. 
Similarly, if the task loops indefinitely, subject + // only to epoch interruption, the event loop will also be + // stuck. Either way, any request expirations created inside + // the `AsyncFnOnce` we passed to `run_concurrent` won't have a + // chance to trigger. Consequently, we poll for instance + // expiration here, outside the event loop, based on the most + // recently recorded state of the worker. + + let (status, start) = *status.try_lock().unwrap(); + + if let Poll::Ready(()) = expiration.as_mut().poll(cx, status, start) { + return Poll::Ready(match status { + WorkerStatus::Requests | WorkerStatus::PostReturn => { + Err(format_err!("guest timed out")) + } + WorkerStatus::Idle => Ok(()), + }); + } - future::poll_fn(|cx| { - let poll = future.as_mut().poll(cx); - if poll.is_pending() { - // If the future returns `Pending`, that's either because it's - // idle (in which case it can definitely accept a new task) or - // because all its tasks are awaiting I/O, in which case it may - // have capacity for additional tasks to run concurrently. - // - // However, if one of the tasks is blocked on a sync call to a - // host function which has exclusive access to the `Store`, the - // `StoreContextMut::run_concurrent` event loop will be unable - // to make progress until that call finishes. Similarly, if the - // task loops indefinitely, subject only to epoch interruption, - // the event loop will also be stuck. Either way, any task - // timeouts created inside the `AsyncFnOnce` we passed to - // `run_concurrent` won't have a chance to trigger. - // Consequently, we need to _also_ enforce timeouts here, - // outside the event loop. - // - // Therefore, we check if the oldest outstanding task has been - // running for at least `request_timeout*2`, which is the - // maximum time needed for any other concurrent tasks to - // complete or time out, at which point we can safely discard - // the instance. 
If that deadline has not yet arrived, we - // schedule a wakeup to occur when it does. - // - // We uphold the "never kill an instance with a task which has - // been running for less than the request timeout" invariant - // here by noting that this timeout will only trigger if the - // `AsyncFnOnce` we passed to `run_concurrent` has been unable - // to run for at least the past `request_timeout` amount of - // time, meaning it can't possibly have accepted a task newer - // than that. - if let Some(deadline) = task_start_times - .lock() - .unwrap() - .earliest() - .and_then(|v| v.checked_add(request_timeout.saturating_mul(2))) - { - sleep.as_mut().reset(deadline.into()); - // Note that this will schedule a wakeup for later if the - // deadline has not yet arrived: - if sleep.as_mut().poll(cx).is_ready() { - // Deadline has been reached; kill the instance with an - // error. - return Poll::Ready(Err(format_err!("guest timed out"))); + // Otherwise, if the instance has not yet expired, we set + // `accept_concurrent` to true and, if it wasn't already true + // before, poll the future one more time so it can ask for + // another request if appropriate. + if !accept_concurrent.swap(true, Relaxed) { + return future.as_mut().poll(cx); } } - // Otherwise, if no timeouts have elapsed, we set - // `accept_concurrent` to true and, if it wasn't already true - // before, poll the future one more time so it can ask for - // another task if appropriate. - if !accept_concurrent.swap(true, Relaxed) { - return future.as_mut().poll(cx); - } - } + poll + }) + .await + }; - poll - }) - .await? + dropper.state.drop(dropper.store.take().unwrap(), result); } } @@ -608,76 +999,150 @@ impl Clone for ProxyHandler { } } +/// This error is returned if, when handling the request, a new worker and +/// associated instance needed to be created, but instantiation failed, e.g. due +/// to reaching a pooling allocator limit or running out of memory. 
In this +/// case, the caller may be able to recover and retry (e.g. after waiting for +/// existing instances to be dropped and/or freeing memory used by caches, +/// etc.). Otherwise, it will probably need to return an HTTP 500 error. +pub struct InstantiationError { + /// The original request passed to `ProxyHandler::handle`. + /// + /// This is wrapped in a `Mutex` to satisfy the `Send + Sync` bounds + /// required by `wasmtime::Error`. + pub request: Mutex, + /// The original instantiation error. + /// + /// This is wrapped in an `Arc` because a single instantiation error may + /// affect multiple requests, and each caller will be given a clone. + pub error: Arc, +} + +impl fmt::Display for InstantiationError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { + write!(f, "instantiation error: {}", self.error) + } +} + +impl fmt::Debug for InstantiationError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { + write!(f, "instantiation error: {:?}", self.error) + } +} + +impl error::Error for InstantiationError {} + +/// Returned when the guest failed to produce a response before the expiration +/// returned by `HandlerState::on_request_start` elapsed. +pub struct ExpirationError; + +impl fmt::Display for ExpirationError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { + fmt::Debug::fmt(self, f) + } +} + +impl fmt::Debug for ExpirationError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { + write!(f, "guest timed out") + } +} + +impl error::Error for ExpirationError {} + +/// A worker trapped or panicked and failed to produce a result. 
+pub struct TrapOrPanicError; + +impl fmt::Display for TrapOrPanicError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { + fmt::Debug::fmt(self, f) + } +} + +impl fmt::Debug for TrapOrPanicError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { + write!(f, "worker trapped or panicked") + } +} + +impl error::Error for TrapOrPanicError {} + impl ProxyHandler where S: HandlerState, { /// Create a new `ProxyHandler` with the specified application state and /// pre-instance. - pub fn new(state: S, instance_pre: ProxyPre) -> Self { + pub fn new(state: S) -> Self { Self(Arc::new(ProxyHandlerInner { state, - instance_pre, - next_id: AtomicU64::from(0), - task_queue: Default::default(), + request_queue: Default::default(), worker_count: AtomicUsize::from(0), })) } - /// Push a task to the task queue for this handler. + /// Handle the specified request, returning a response on success or the + /// tuple of the request and error on failure. /// - /// This will either spawn a new background worker to run the task or - /// deliver it to an already-running worker. + /// This function will return a `wasmtime::Error` on failure, which may be + /// downcast to a more specific type in certain scenarios: /// - /// The `req_id` will be passed to `::new_store` _if_ a - /// new worker is started for this task. It is intended to be used as a - /// "request identifier" corresponding to that task and can be used e.g. to - /// prefix all logging from the `Store` with that identifier. Note that a - /// non-`None` value only makes sense when `::max_instance_reuse_count == 1`; otherwise the identifier - /// will not match subsequent tasks handled by the worker. 
- pub fn spawn(&self, req_id: Option, task: TaskFn) { - match self.0.state.max_instance_reuse_count() { - 0 => panic!("`max_instance_reuse_count` must be at least 1"), - _ => { - if self.0.worker_count.load(Relaxed) == 0 { - // There are no available workers; skip the queue and pass - // the task directly to the worker, which improves - // performance as measured by `wasmtime-server-rps.sh` by - // about 15%. - self.start_worker(Some(task), req_id); - } else { - self.0.task_queue.push(task); - // Start a new worker to handle the task if the last worker - // just went unavailable. See also `Worker::set_available` - // for what happens if the available worker count goes to - // zero right after we check it here, and note that we only - // check the count _after_ we've pushed the task to the - // queue. We use `SeqCst` here to ensure that we get an - // updated view of `worker_count` as it exists after the - // `Queue::push` above. - // - // The upshot is that at least one (or more) of the - // following will happen: - // - // - An existing worker will accept the task - // - We'll start a new worker here to accept the task - // - `Worker::set_available` will start a new worker to accept the task - // - // I.e. it should not be possible for the task to be - // orphaned indefinitely in the queue without being - // accepted. - if self.0.worker_count.load(SeqCst) == 0 { - self.start_worker(None, None); - } - } + /// - [`InstantiationError`] if a new worker was created to handle the + /// request but could not instantiate the guest component. + /// + /// - [`ExpirationError`] if the request expired before it produced a + /// response. See [`HandlerState::on_request_start`] for details. + /// + /// - [`TrapOrPanicError`] if the worker responsible for handling the + /// request trapped or panicked before it produced a response. This may be + /// used when a trap occurs but cannot be traced to a specific request, + /// e.g. during concurrent request handling. 
+ /// + /// In other failure cases (e.g. `wasi:http/types#error-code` return values + /// and/or traps when executing synchronous WASIp2 handler functions), the + /// original error returned by the handler will be returned. + pub async fn handle(&self, request: Request) -> Result { + let (tx, rx) = oneshot::channel(); + + if self.0.worker_count.load(Relaxed) == 0 { + // There are no available workers; skip the queue and pass + // the request directly to the worker, which improves + // performance as measured by `wasmtime-server-rps.sh` by + // about 15%. + self.start_worker(Some((request, tx))); + } else { + let mut queue = self.0.request_queue.queue.lock().unwrap(); + queue.push_back((request, tx)); + + // Start a new worker to handle the request if the last worker just + // went unavailable. See also `Worker::set_available` for what + // happens if the available worker count goes to zero right after we + // check it here, and note that we only check the count _after_ + // we've pushed the request to the queue. + // + // The upshot is that at least one (or more) of the + // following will happen: + // + // - An existing worker will accept the request + // - We'll start a new worker here to accept the request + // - `Worker::set_available` will start a new worker to accept the request + // + // I.e. it should not be possible for the request to be orphaned + // indefinitely in the queue without being accepted except in the + // case of a panic or an instantiation error. In the case of an + // instantiation error, we'll give the request back to the caller in + // an `Err(_)`, allowing the application to decide what to do next. + if self.0.worker_count.load(Relaxed) == 0 { + let (request, tx) = queue.pop_back().unwrap(); + drop(queue); + self.start_worker(Some((request, tx))); + } else { + drop(queue); + self.0.request_queue.notify_push.notify_one(); } } - } - /// Generate a unique request ID. 
- pub fn next_req_id(&self) -> u64 { - self.0.next_id.fetch_add(1, Relaxed) + rx.await.map_err(|_| TrapOrPanicError)? } /// Return a reference to the application state. @@ -685,18 +1150,145 @@ where &self.0.state } - /// Return a reference to the pre-instance. - pub fn instance_pre(&self) -> &ProxyPre { - &self.0.instance_pre - } - - fn start_worker(&self, task: Option>, req_id: Option) { + fn start_worker( + &self, + request: Option<(Request, oneshot::Sender>)>, + ) { tokio::spawn( Worker { handler: self.clone(), available: false, } - .run(task, req_id), + .run(request), ); } } + +async fn handle( + accessor: &Accessor, + proxy: &Proxy, + request: Request, + view: ViewFn, + tx: oneshot::Sender>, + expiration: impl Future, +) -> Result { + let expiration = pin!(expiration); + + match (proxy, view) { + #[cfg(feature = "p3")] + (Proxy::P3(guest), ViewFn::P3(view)) => { + let (request, body) = request.into_parts(); + let body = body.map_err(p3_types::ErrorCode::from); + let request = http::Request::from_parts(request, body); + let (request, request_io_result) = p3::Request::from_http(request); + + let request = accessor.with(|mut store| { + Ok::<_, wasmtime::Error>(view(store.data_mut()).table.push(request)?) + })?; + + let handle = pin!(async move { + let response = guest + .wasi_http_handler() + .call_handle(accessor, request) + .await?; + + let response = accessor.with(|mut store| { + let response = view(store.get()).table.delete(response?)?; + Ok::<_, wasmtime::Error>(response.into_http_with_getter( + &mut store, + request_io_result, + view, + )?) + })?; + + Ok(response.map(move |body| body.map_err(ErrorCode::from).boxed_unsync())) + }); + + // TODO: We should also use `oneshot::Sender::poll_close` to be + // notified when the receiver is dropped, in which case we should + // expire the request since the response is no longer of interest to + // the original `ProxyHandler::handle` caller. 
+ let (result, sent) = match futures::future::select(handle, expiration).await { + Either::Left((result, _)) => (result, true), + // TODO: We should also send a cancel request to the expired + // task to give it a chance to shut down gracefully, but as of + // this writing Wasmtime does not yet provide an API for doing + // that. See issue #11833. Instead, we let it continue running + // as a background task until it either returns a response + // (which we'll ignore) or the instance itself has expired. + Either::Right(((), _)) => (Err(ExpirationError.into()), false), + }; + + _ = tx.send(result); + + Ok(sent) + } + (Proxy::P2(guest), ViewFn::P2(view)) => { + // Here we wrap the sender in an `Arc>>`, with one + // clone used in the `response-outparam` and the other used to send + // an error if the request expires or the handler returns without + // producing a response. + let tx = Arc::new(Mutex::new(Some(tx))); + + let (request, out) = accessor.with({ + let tx = tx.clone(); + move |mut access| { + let request = view(access.data_mut()) + .new_incoming_request(p2_types::Scheme::Http, request)?; + + let out = view(access.data_mut()).new_response_outparam_from_callback( + move |value| { + if let Some(tx) = tx.lock().unwrap().take() { + _ = tx.send( + value + .map(|v| { + v.map(move |body| { + body.map_err(ErrorCode::from).boxed_unsync() + }) + }) + .map_err(wasmtime::Error::from), + ); + } + }, + )?; + + wasmtime::error::Ok((request, out)) + } + })?; + + let handle = pin!( + guest + .wasi_http_incoming_handler() + .call_handle(accessor, request, out) + ); + + const MESSAGE: &str = "guest never invoked `response-outparam::set` method"; + + struct Dropper(Arc>>>>); + + impl Drop for Dropper { + fn drop(&mut self) { + if let Some(tx) = self.0.lock().unwrap().take() { + _ = tx.send(Err(format_err!("{MESSAGE}"))); + } + } + } + + let tx = Dropper(tx); + + // See corresponding TODO comment for the p3 case above. 
+ let (result, sent) = match futures::future::select(handle, expiration).await { + Either::Left((result, _)) => (result.context(MESSAGE), true), + // See corresponding TODO comment for the p3 case above. + Either::Right(((), _)) => (Err(ExpirationError.into()), false), + }; + + if let Some(tx) = tx.0.lock().unwrap().take() { + _ = tx.send(result.and_then(|()| Err(format_err!("{MESSAGE}")))); + } + + Ok(sent) + } + _ => unreachable!(), + } +} diff --git a/crates/wasi-http/src/p2/types.rs b/crates/wasi-http/src/p2/types.rs index dd845ad7b2e5..7018b392d8a5 100644 --- a/crates/wasi-http/src/p2/types.rs +++ b/crates/wasi-http/src/p2/types.rs @@ -146,9 +146,9 @@ impl WasiHttpCtxView<'_> { /// The concrete type behind a `wasi:http/types.response-outparam` resource. pub struct HostResponseOutparam { - /// The sender for sending a response. - pub result: - tokio::sync::oneshot::Sender, types::ErrorCode>>, + /// The callback sending a response. + pub send: + Box, types::ErrorCode>) + Send + Sync>, } impl WasiHttpCtxView<'_> { @@ -159,7 +159,29 @@ impl WasiHttpCtxView<'_> { Result, types::ErrorCode>, >, ) -> wasmtime::Result> { - let id = self.table.push(HostResponseOutparam { result })?; + let id = self.table.push(HostResponseOutparam { + send: Box::new(move |value| { + // Giving the API doesn't return any error, it's probably + // better to ignore the error than trap the guest, in case of + // host timeout and dropped the receiver side of the channel. + // See also: #10784 + _ = result.send(value) + }), + })?; + Ok(id) + } + + /// Create a new outgoing response from an `FnOnce`. 
+ pub fn new_response_outparam_from_callback( + &mut self, + callback: impl FnOnce(Result, types::ErrorCode>) + + Send + + Sync + + 'static, + ) -> wasmtime::Result> { + let id = self.table.push(HostResponseOutparam { + send: Box::new(callback), + })?; Ok(id) } } diff --git a/crates/wasi-http/src/p2/types_impl.rs b/crates/wasi-http/src/p2/types_impl.rs index e403784a7b52..44d1d5eda2cf 100644 --- a/crates/wasi-http/src/p2/types_impl.rs +++ b/crates/wasi-http/src/p2/types_impl.rs @@ -401,11 +401,7 @@ impl types::HostResponseOutparam for WasiHttpCtxView<'_> { }; let resp = self.table.delete(id)?; - // Giving the API doesn't return any error, it's probably - // better to ignore the error than trap the guest, in case of - // host timeout and dropped the receiver side of the channel. - // See also: #10784 - let _ = resp.result.send(val); + (resp.send)(val); Ok(()) } diff --git a/src/commands/serve.rs b/src/commands/serve.rs index c69678f2417f..aa95ecc00d6a 100644 --- a/src/commands/serve.rs +++ b/src/commands/serve.rs @@ -1,11 +1,10 @@ use crate::common::{HttpHooks, Profile, RunCommon, RunTarget}; use bytes::Bytes; use clap::Parser; -use futures::future::FutureExt; use http::{Response, StatusCode}; use http_body_util::BodyExt as _; use http_body_util::combinators::UnsyncBoxBody; -use hyper::body::{Body, Frame, SizeHint}; +use pin_project_lite::pin_project; use std::convert::Infallible; use std::ffi::OsString; use std::net::SocketAddr; @@ -15,22 +14,27 @@ use std::{ path::PathBuf, sync::{ Arc, Mutex, - atomic::{AtomicBool, Ordering}, + atomic::{AtomicBool, AtomicU64, Ordering}, }, - time::Duration, + time::{Duration, Instant}, }; use tokio::io::{self, AsyncWrite}; use tokio::sync::Notify; use wasmtime::component::{Component, Linker}; +#[cfg(feature = "gdbstub")] +use wasmtime::error::Context as _; use wasmtime::{ - Engine, Result, Store, StoreContextMut, StoreLimits, UpdateDeadline, bail, error::Context as _, + AsContextMut as _, Engine, Result, Store, StoreContextMut, 
StoreLimits, UpdateDeadline, bail, }; use wasmtime_cli_flags::opt::WasmtimeOptionValue; use wasmtime_wasi::p2::{StreamError, StreamResult}; use wasmtime_wasi::{WasiCtx, WasiCtxBuilder, WasiCtxView, WasiView}; #[cfg(feature = "component-model-async")] use wasmtime_wasi_http::handler::p2::bindings as p2; -use wasmtime_wasi_http::handler::{HandlerState, Proxy, ProxyHandler, ProxyPre, StoreBundle}; +use wasmtime_wasi_http::handler::{ + self, HandlerState, Instance, ProxyHandler, ProxyPre, ShouldAccept, ViewFn, WorkerExpiration, + WorkerState, WorkerStatus, +}; use wasmtime_wasi_http::io::TokioIo; use wasmtime_wasi_http::{WasiHttpCtx, p2::WasiHttpView}; @@ -67,6 +71,8 @@ struct Host { #[cfg(feature = "profiling")] guest_profiler: Option>, + + write_profile: Option, } impl WasiView for Host { @@ -347,20 +353,20 @@ impl ServeCommand { .await } - fn new_store(&self, engine: &Engine, req_id: Option) -> Result> { + fn new_store(&self, engine: &Engine, instance_id: Option) -> Result> { let mut builder = WasiCtxBuilder::new(); self.run.configure_wasip2(&mut builder)?; - if let Some(req_id) = req_id { - builder.env("REQUEST_ID", req_id.to_string()); + if let Some(instance_id) = instance_id { + builder.env("INSTANCE_ID", instance_id.to_string()); } let stdout_prefix: String; let stderr_prefix: String; - match req_id { - Some(req_id) if !self.no_logging_prefix => { - stdout_prefix = format!("stdout [{req_id}] :: "); - stderr_prefix = format!("stderr [{req_id}] :: "); + match instance_id { + Some(instance_id) if !self.no_logging_prefix => { + stdout_prefix = format!("stdout [{instance_id}] :: "); + stderr_prefix = format!("stderr [{instance_id}] :: "); } _ => { stdout_prefix = "".to_string(); @@ -390,6 +396,7 @@ impl ServeCommand { wasi_keyvalue: None, #[cfg(feature = "profiling")] guest_profiler: None, + write_profile: None, }; if self.run.common.wasi.nn == Some(true) { @@ -659,19 +666,18 @@ impl ServeCommand { 1 }; - let handler = ProxyHandler::new( - HostHandlerState { - cmd: 
self, - engine, - component, - max_instance_reuse_count, - max_instance_concurrent_reuse_count, - // Give one shutdown guard to this handler which will track the - // full lifetime of any instances spawned. - _shutdown_guard: Box::new(shutdown.clone().increment()), - }, + let handler = ProxyHandler::new(HostHandlerState { + cmd: self, + engine, + component, + max_instance_reuse_count, + max_instance_concurrent_reuse_count, instance, - ); + next_instance_id: AtomicU64::default(), + // Give one shutdown guard to this handler which will track the + // full lifetime of any instances spawned. + _shutdown_guard: Box::new(shutdown.clone().increment()), + }); loop { // Wait for a socket, but also "race" against shutdown to break out @@ -765,46 +771,138 @@ impl ServeCommand { } } +pin_project! { + struct HostWorkerExpiration { + idle_timeout: Duration, + request_timeout: Duration, + #[pin] + sleep: tokio::time::Sleep, + } +} + +impl WorkerExpiration for HostWorkerExpiration { + fn poll( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + status: WorkerStatus, + start: Instant, + ) -> Poll<()> { + let mut me = self.project(); + + let timeout = match status { + WorkerStatus::Idle => *me.idle_timeout, + // TODO: add a dedicated `post_return_timeout` config setting + // instead of reusing `request_timeout` for + // `WorkerStatus::PostReturn` here + WorkerStatus::Requests | WorkerStatus::PostReturn => *me.request_timeout, + }; + + if let Some(deadline) = start.checked_add(timeout) { + let deadline = deadline.into(); + if deadline != me.sleep.deadline() { + me.sleep.as_mut().reset(deadline); + } + me.sleep.poll(cx) + } else { + Poll::Pending + } + } +} + +struct HostWorkerState { + instance_id: u64, + max_instance_reuse_count: usize, + max_instance_concurrent_reuse_count: usize, + request_timeout: Duration, +} + +impl WorkerState for HostWorkerState { + type StoreData = Host; + + fn should_accept_request(&self, concurrent_count: usize, total_count: usize) -> ShouldAccept { + if 
total_count >= self.max_instance_reuse_count { + ShouldAccept::Never + } else if concurrent_count >= self.max_instance_concurrent_reuse_count { + ShouldAccept::No + } else { + ShouldAccept::Yes + } + } + + fn on_request_start( + &self, + req: &handler::Request, + ) -> Pin + 'static + Send + Sync>> { + log::info!( + "Instance {} handling request {} {}", + self.instance_id, + req.method(), + req.uri() + ); + + Box::pin(tokio::time::sleep(self.request_timeout)) + } + + fn drop(&self, mut store: Store, result: Result<(), wasmtime::Error>) { + if let Err(error) = result { + eprintln!("worker failed: {error:?}"); + } + + if let Some(write_profile) = store.data_mut().write_profile.take() { + write_profile(store.as_context_mut()); + } + + drop(store); + } +} + struct HostHandlerState { cmd: ServeCommand, engine: Engine, component: Component, max_instance_reuse_count: usize, max_instance_concurrent_reuse_count: usize, + instance: ProxyPre, + next_instance_id: AtomicU64, _shutdown_guard: Box, } impl HandlerState for HostHandlerState { type StoreData = Host; + type WorkerExpiration = HostWorkerExpiration; + type WorkerState = HostWorkerState; - fn new_store(&self, req_id: Option) -> Result> { - let mut store = self.cmd.new_store(&self.engine, req_id)?; + async fn instantiate( + &self, + ) -> Result> { + let instance_id = self.next_instance_id.fetch_add(1, Ordering::Relaxed); + let mut store = self.cmd.new_store(&self.engine, Some(instance_id))?; let write_profile = setup_epoch_handler(&self.cmd, &mut store, self.component.clone())?; + store.data_mut().write_profile = Some(write_profile); - Ok(StoreBundle { - store, - write_profile, - }) - } - - fn request_timeout(&self) -> Duration { - self.cmd.run.common.wasm.timeout.unwrap_or(Duration::MAX) - } - - fn idle_instance_timeout(&self) -> Duration { - self.cmd.idle_instance_timeout - } - - fn max_instance_reuse_count(&self) -> usize { - self.max_instance_reuse_count - } + let proxy = self.instance.instantiate_async(&mut 
store).await?; - fn max_instance_concurrent_reuse_count(&self) -> usize { - self.max_instance_concurrent_reuse_count - } + let view = match &self.instance { + ProxyPre::P2(_) => ViewFn::P2(wasmtime_wasi_http::p2::WasiHttpView::http), + ProxyPre::P3(_) => ViewFn::P3(wasmtime_wasi_http::p3::WasiHttpView::http), + }; - fn handle_worker_error(&self, error: wasmtime::Error) { - eprintln!("worker error: {error}"); + Ok(Instance { + store, + proxy, + view, + expiration: HostWorkerExpiration { + idle_timeout: self.cmd.idle_instance_timeout, + request_timeout: self.cmd.run.common.wasm.timeout.unwrap_or(Duration::MAX), + sleep: tokio::time::sleep(Duration::MAX), + }, + state: HostWorkerState { + max_instance_reuse_count: self.max_instance_reuse_count, + max_instance_concurrent_reuse_count: self.max_instance_concurrent_reuse_count, + instance_id, + request_timeout: self.cmd.run.common.wasm.timeout.unwrap_or(Duration::MAX), + }, + }) } } @@ -1134,162 +1232,21 @@ async fn handle_request( handler: ProxyHandler, req: Request, ) -> Result>> { - use tokio::sync::oneshot; - - let req_id = handler.next_req_id(); - - log::info!( - "Request {req_id} handling {} to {}", - req.method(), - req.uri() - ); - - // Here we must declare different channel types for p2 and p3 since p2's - // `WasiHttpView::new_response_outparam` expects a specific kind of sender - // that uses `p2::http::types::ErrorCode`, and we don't want to have to - // convert from the p3 `ErrorCode` to the p2 one, only to convert again to - // `wasmtime::Error`. 
- - type P2Response = Result< - hyper::Response, - p2::http::types::ErrorCode, - >; - type P3Response = hyper::Response>; - - enum Sender { - P2(oneshot::Sender), - P3(oneshot::Sender), - } - - enum Receiver { - P2(oneshot::Receiver), - P3(oneshot::Receiver), - } - - let (tx, rx) = match handler.instance_pre() { - ProxyPre::P2(_) => { - let (tx, rx) = oneshot::channel(); - (Sender::P2(tx), Receiver::P2(rx)) - } - ProxyPre::P3(_) => { - let (tx, rx) = oneshot::channel(); - (Sender::P3(tx), Receiver::P3(rx)) - } - }; - - handler.spawn( - if handler.state().max_instance_reuse_count() == 1 { - Some(req_id) - } else { - None - }, - Box::new(move |store, proxy| { - Box::pin( - async move { - match proxy { - Proxy::P2(proxy) => { - let Sender::P2(tx) = tx else { unreachable!() }; - let (req, out) = store.with(move |mut store| { - let req = store - .data_mut() - .http() - .new_incoming_request(p2::http::types::Scheme::Http, req)?; - let out = store.data_mut().http().new_response_outparam(tx)?; - wasmtime::error::Ok((req, out)) - })?; - - proxy - .wasi_http_incoming_handler() - .call_handle(store, req, out) - .await - } - Proxy::P3(proxy) => { - use wasmtime_wasi_http::p3::bindings::http::types::{ - ErrorCode, Request, - }; - - let Sender::P3(tx) = tx else { unreachable!() }; - let (req, body) = req.into_parts(); - let body = body.map_err(ErrorCode::from_hyper_request_error); - let req = http::Request::from_parts(req, body); - let (request, request_io_result) = Request::from_http(req); - let res = proxy.handle(store, request).await??; - let res = store - .with(|mut store| res.into_http(&mut store, request_io_result))?; - - // With the guest response now transformed into a - // host-compatible response layer one more wrapper - // around the body. This layer is solely responsible - // for dropping a channel half on destruction, and - // this enables waiting here until the body is - // consumed by waiting for this destruction to - // happen. 
- let (resp_body_tx, resp_body_rx) = oneshot::channel(); - let res = res.map(|body| { - let body = body.map_err(|e| e.into()); - P3BodyWrapper { - _tx: resp_body_tx, - body, - } - .boxed_unsync() - }); - - // If `wasmtime serve` is waiting on this response - // and actually got it then wait for the body to - // finish, otherwise it's thrown away so skip that - // step. - if tx.send(res).is_ok() { - _ = resp_body_rx.await; - } - - Ok(()) - } - } - } - .map(move |result| { - if let Err(error) = result { - eprintln!("[{req_id}] :: {error:?}"); - } - }), - ) - }), - ); - - return Ok(match rx { - Receiver::P2(rx) => rx - .await - .context("guest never invoked `response-outparam::set` method")? - .map_err(|e| wasmtime::Error::from(e))? - .map(|body| body.map_err(|e| e.into()).boxed_unsync()), - Receiver::P3(rx) => rx.await?, - }); - - // Forwarding implementation of `Body` to an inner `B` with the sole purpose - // of carrying `_tx` to its destruction. - struct P3BodyWrapper { - body: B, - _tx: oneshot::Sender<()>, - } - - impl Body for P3BodyWrapper { - type Data = B::Data; - type Error = B::Error; - - fn poll_frame( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll, Self::Error>>> { - Pin::new(&mut self.body).poll_frame(cx) - } - - fn is_end_stream(&self) -> bool { - self.body.is_end_stream() - } - - fn size_hint(&self) -> SizeHint { - self.body.size_hint() - } - } + use wasmtime_wasi_http::p3::bindings::http::types::ErrorCode; + + handler + .handle(req.map(|body| { + body.map_err(ErrorCode::from_hyper_request_error) + .map_err(handler::ErrorCode::from) + .boxed_unsync() + })) + .await + .map(|v| { + v.map(|body| { + body.map_err(|code| ErrorCode::from(code).into()) + .boxed_unsync() + }) + }) } #[derive(Clone)]