diff --git a/crates/guest-rust/src/rt/async_support.rs b/crates/guest-rust/src/rt/async_support.rs index 2a5218ed1..baa5b7497 100644 --- a/crates/guest-rust/src/rt/async_support.rs +++ b/crates/guest-rust/src/rt/async_support.rs @@ -1,12 +1,13 @@ #![deny(missing_docs)] +use self::try_lock::TryLock; use alloc::boxed::Box; use alloc::collections::BTreeMap; use alloc::sync::Arc; use alloc::task::Wake; use core::ffi::c_void; use core::future::Future; -use core::mem; +use core::mem::{self, ManuallyDrop}; use core::pin::Pin; use core::ptr; use core::sync::atomic::{AtomicU32, Ordering}; @@ -104,15 +105,28 @@ use spawn_disabled as spawn; /// Represents a task created by either a call to an async-lifted export or a /// future run using `block_on` or `start_task`. -struct FutureState<'a> { +struct TaskState<'a> { /// Remaining work to do (if any) before this task can be considered "done". /// /// Note that we won't tell the host the task is done until this is drained /// and `waitables` is empty. tasks: spawn::Tasks<'a>, - /// The waitable set containing waitables created by this task, if any. - waitable_set: Option, + /// Dual-mode rust-level Waker and C ABI level "task" for wasip3 + /// integration. + shared: Arc, + + /// Clone of `shared` field, but represented as `std::task::Waker`. + waker: Waker, + + /// State related to supporting inter-task wakeup scenarios. + inter_task_wakeup: inter_task_wakeup::State, +} + +struct SharedTaskState { + /// One of `SLEEP_STATE_*` indicating the current status. + sleep_state: AtomicU32, + inter_task_stream: inter_task_wakeup::WakerState, /// State of all waitables in `waitable_set`, and the ptr/callback they're /// associated with. @@ -123,56 +137,44 @@ struct FutureState<'a> { // the `wasi_snapshot_preview1` adapter when targeting `wasm32-wasip2` and // later, and that's expensive enough that we'd prefer to avoid it for apps // which otherwise make no use of the adapter. - waitables: BTreeMap, - - /// Raw structure used to pass to `cabi::wasip3_task_set` - wasip3_task: cabi::wasip3_task, - - /// Rust-level state for the waker, notably a bool as to whether this has - /// been woken. - waker: Arc, + // + // Also note that the `TryLock` here should never be contended, but it's + // used for interior mutability. + waitables: TryLock>, - /// Clone of `waker` field, but represented as `std::task::Waker`. - waker_clone: Waker, + /// The waitable set containing waitables created by this task, if any. + // + // Note the `TryLock` is the same as `waitables` above, it's serving the + // purpose of interior mutability. + waitable_set: TryLock>, +} - /// State related to supporting inter-task wakeup scenarios. - inter_task_wakeup: inter_task_wakeup::State, +/// An entry of `SharedTaskState::waitables` which is added through the C ABI. +struct CabiWaitable { + callback: unsafe extern "C" fn(*mut c_void, u32), + callback_ptr: *mut c_void, } -impl FutureState<'_> { - fn new(future: BoxFuture<'_>) -> FutureState<'_> { - let waker = Arc::new(FutureWaker::default()); - FutureState { - waker_clone: waker.clone().into(), - waker, +unsafe impl Send for CabiWaitable {} + +impl TaskState<'_> { + fn new(future: BoxFuture<'_>) -> TaskState<'_> { + let shared = Arc::new(SharedTaskState { + sleep_state: AtomicU32::new(0), + inter_task_stream: Default::default(), + waitables: Default::default(), + waitable_set: Default::default(), + }); + TaskState { + waker: shared.clone().into(), + shared, tasks: spawn::Tasks::new(future), - waitable_set: None, - waitables: BTreeMap::new(), - wasip3_task: cabi::wasip3_task { - // This pointer is filled in before calling `wasip3_task_set`. - ptr: ptr::null_mut(), - version: cabi::WASIP3_TASK_V1, - waitable_register, - waitable_unregister, - }, inter_task_wakeup: Default::default(), } } - fn get_or_create_waitable_set(&mut self) -> &WaitableSet { - self.waitable_set.get_or_insert_with(WaitableSet::new) - } - - fn add_waitable(&mut self, waitable: u32) { - self.get_or_create_waitable_set().join(waitable) - } - - fn remove_waitable(&mut self, waitable: u32) { - WaitableSet::remove_waitable_from_all_sets(waitable) - } - fn remaining_work(&self) -> bool { - !self.waitables.is_empty() + !self.shared.waitables.try_lock().unwrap().is_empty() } /// Handles the `event{0,1,2}` event codes and returns a corresponding @@ -190,7 +192,7 @@ impl FutureState<'_> { // Cancellation is mapped to destruction in Rust, so return a // code/bool indicating we're done. The caller will then - // appropriately deallocate this `FutureState` which will + // appropriately deallocate this `TaskState` which will // transitively run all destructors. return CallbackCode::Exit; } @@ -200,7 +202,7 @@ impl FutureState<'_> { self.with_p3_task_set(|me| { // Transition our sleep state to ensure that the inter-task stream // isn't used since there's no need to use that here. - me.waker + me.shared .sleep_state .store(SLEEP_STATE_WOKEN, Ordering::Relaxed); @@ -221,13 +223,13 @@ impl FutureState<'_> { me.cancel_inter_task_stream_read(); loop { - let mut context = Context::from_waker(&me.waker_clone); + let mut context = Context::from_waker(&me.waker); // On each turn of this loop reset the state to "polling" // which clears out any pending wakeup if one was sent. This // in theory helps minimize wakeups from previous iterations // happening in this iteration. - me.waker + me.shared .sleep_state .store(SLEEP_STATE_POLLING, Ordering::Relaxed); @@ -242,7 +244,8 @@ impl FutureState<'_> { Poll::Ready(()) => { assert!(me.tasks.is_empty()); if me.remaining_work() { - let waitable = me.waitable_set.as_ref().unwrap().as_raw(); + let set = me.shared.waitable_set.try_lock().unwrap(); + let waitable = set.as_ref().unwrap().as_raw(); break CallbackCode::Wait(waitable); } else { break CallbackCode::Exit; @@ -255,10 +258,12 @@ impl FutureState<'_> { // something. Poll::Pending => { assert!(!me.tasks.is_empty()); - if me.waker.sleep_state.load(Ordering::Relaxed) == SLEEP_STATE_WOKEN { + if me.shared.sleep_state.load(Ordering::Relaxed) == SLEEP_STATE_WOKEN { if me.remaining_work() { - let (event0, event1, event2) = - me.waitable_set.as_ref().unwrap().poll(); + let (event0, event1, event2) = { + let set = me.shared.waitable_set.try_lock().unwrap(); + set.as_ref().unwrap().poll() + }; if event0 != EVENT_NONE { me.deliver_waitable_event(event1, event2); continue; @@ -270,11 +275,12 @@ impl FutureState<'_> { // Transition our state to "sleeping" so wakeup // notifications know that they need to signal the // inter-task stream. - me.waker + me.shared .sleep_state .store(SLEEP_STATE_SLEEPING, Ordering::Relaxed); me.read_inter_task_stream(); - let waitable = me.waitable_set.as_ref().unwrap().as_raw(); + let set = me.shared.waitable_set.try_lock().unwrap(); + let waitable = set.as_ref().unwrap().as_raw(); break CallbackCode::Wait(waitable); } } @@ -286,7 +292,7 @@ impl FutureState<'_> { /// waitable should be present because it's part of the waitable set which /// is kept in-sync with our map. fn deliver_waitable_event(&mut self, waitable: u32, code: u32) { - self.remove_waitable(waitable); + WaitableSet::remove_waitable_from_all_sets(waitable); if self .inter_task_wakeup @@ -295,17 +301,19 @@ impl FutureState<'_> { return; } - let (ptr, callback) = self.waitables.remove(&waitable).unwrap(); + let c = { + let mut waitables = self.shared.waitables.try_lock().unwrap(); + waitables.remove(&waitable).unwrap() + }; unsafe { - callback(ptr, code); + (c.callback)(c.callback_ptr, code); } } fn with_p3_task_set(&mut self, f: impl FnOnce(&mut Self) -> R) -> R { - // Finish our `wasip3_task` by initializing its self-referential pointer, - // and then register it for the duration of this function with - // `wasip3_task_set`. The previous value of `wasip3_task_set` will get - // restored when this function returns. + // Initialize a temporary `wasip3_task` structure on the stack and + // inform `wasip3_task_set` that we're now within that task. Note the + // RAII guard to reset the task back to its previous contents. struct ResetTask(*mut cabi::wasip3_task); impl Drop for ResetTask { fn drop(&mut self) { @@ -314,16 +322,32 @@ impl FutureState<'_> { } } } - let self_raw = self as *mut FutureState<'_>; - self.wasip3_task.ptr = self_raw.cast(); - let prev = unsafe { cabi::wasip3_task_set(&mut self.wasip3_task) }; + // The `ptr` field of `wasip3_task` is to `SharedTaskState` which is + // what's cloned/handed out/etc. + let shared_raw: *const SharedTaskState = &*self.shared; + let mut wasip3_task = cabi::wasip3_task_v2 { + v1: cabi::wasip3_task { + ptr: shared_raw.cast_mut().cast(), + version: cabi::WASIP3_TASK_V2, + waitable_register: SharedTaskState::CABI_VTABLE.waitable_register, + waitable_unregister: SharedTaskState::CABI_VTABLE.waitable_unregister, + }, + vtable: &SharedTaskState::CABI_VTABLE, + }; + + // Explicitly take a mutable borrow on the entire `wasip3_task` + // structure, and then cast its raw pointer to the "smaller" historical + // version, ensuring the final pointer has provenace over the entire + // structure. + let wasip3_task: *mut cabi::wasip3_task_v2 = &mut wasip3_task; + let prev = unsafe { cabi::wasip3_task_set(wasip3_task.cast::()) }; let _reset = ResetTask(prev); f(self) } } -impl Drop for FutureState<'_> { +impl Drop for TaskState<'_> { fn drop(&mut self) { // If there's an active read of the inter-task stream, go ahead and // cancel it, since we're about to drop the stream anyway. @@ -342,33 +366,80 @@ impl Drop for FutureState<'_> { } } -unsafe extern "C" fn waitable_register( - ptr: *mut c_void, - waitable: u32, - callback: unsafe extern "C" fn(*mut c_void, u32), - callback_ptr: *mut c_void, -) -> *mut c_void { - let ptr = ptr.cast::>(); - assert!(!ptr.is_null()); - unsafe { - (*ptr).add_waitable(waitable); - match (*ptr).waitables.insert(waitable, (callback_ptr, callback)) { - Some((prev, _)) => prev, +impl SharedTaskState { + const CABI_VTABLE: cabi::wasip3_task_vtable = cabi::wasip3_task_vtable { + version: cabi::WASIP3_TASK_V2, + waitable_register: Self::cabi_waitable_register, + waitable_unregister: Self::cabi_waitable_unregister, + drop: Self::cabi_drop, + clone: Self::cabi_clone, + }; + + /// Adds the `waitable` provided to this task's waitable set. + fn add_waitable(&self, waitable: u32) { + let mut set = self.waitable_set.try_lock().unwrap(); + set.get_or_insert_with(WaitableSet::new).join(waitable); + } + + /// Implementation of the CABI `waitable_register` function. + fn waitable_register( + &self, + waitable: u32, + callback: unsafe extern "C" fn(*mut c_void, u32), + callback_ptr: *mut c_void, + ) -> *mut c_void { + self.add_waitable(waitable); + let mut waitables = self.waitables.try_lock().unwrap(); + let c = CabiWaitable { + callback, + callback_ptr, + }; + match waitables.insert(waitable, c) { + Some(prev) => prev.callback_ptr, None => ptr::null_mut(), } } -} -unsafe extern "C" fn waitable_unregister(ptr: *mut c_void, waitable: u32) -> *mut c_void { - let ptr = ptr.cast::>(); - assert!(!ptr.is_null()); - unsafe { - (*ptr).remove_waitable(waitable); - match (*ptr).waitables.remove(&waitable) { - Some((prev, _)) => prev, + /// Implementation of the CABI `waitable_unregister` function. + fn waitable_unregister(&self, waitable: u32) -> *mut c_void { + WaitableSet::remove_waitable_from_all_sets(waitable); + let mut waitables = self.waitables.try_lock().unwrap(); + match waitables.remove(&waitable) { + Some(prev) => prev.callback_ptr, None => ptr::null_mut(), } } + + /// Helper to go from a raw `c_void` FFI pointer to a typed + /// self-representation. + unsafe fn cabi_to_self(ptr: *mut c_void) -> ManuallyDrop> { + unsafe { ManuallyDrop::new(Arc::from_raw(ptr.cast::())) } + } + + unsafe extern "C" fn cabi_waitable_register( + ptr: *mut c_void, + waitable: u32, + callback: unsafe extern "C" fn(*mut c_void, u32), + callback_ptr: *mut c_void, + ) -> *mut c_void { + let me = unsafe { Self::cabi_to_self(ptr) }; + me.waitable_register(waitable, callback, callback_ptr) + } + + unsafe extern "C" fn cabi_waitable_unregister(ptr: *mut c_void, waitable: u32) -> *mut c_void { + let me = unsafe { Self::cabi_to_self(ptr) }; + me.waitable_unregister(waitable) + } + + unsafe extern "C" fn cabi_clone(ptr: *mut c_void) -> *mut c_void { + let me = unsafe { Self::cabi_to_self(ptr) }; + Arc::into_raw(Arc::clone(&me)).cast_mut().cast() + } + + unsafe extern "C" fn cabi_drop(ptr: *mut c_void) { + let mut me = unsafe { Self::cabi_to_self(ptr) }; + unsafe { ManuallyDrop::drop(&mut me) } + } } /// Status for "this task is actively being polled" @@ -380,14 +451,7 @@ const SLEEP_STATE_WOKEN: u32 = 1; /// Wakeups on this status signal the inter-task stream. const SLEEP_STATE_SLEEPING: u32 = 2; -#[derive(Default)] -struct FutureWaker { - /// One of `SLEEP_STATE_*` indicating the current status. - sleep_state: AtomicU32, - inter_task_stream: inter_task_wakeup::WakerState, -} - -impl Wake for FutureWaker { +impl Wake for SharedTaskState { fn wake(self: Arc) { Self::wake_by_ref(&self) } @@ -483,11 +547,11 @@ impl ReturnCode { /// immediately with its result. #[doc(hidden)] pub fn start_task(task: impl Future + 'static) -> i32 { - // Allocate a new `FutureState` which will track all state necessary for + // Allocate a new `TaskState` which will track all state necessary for // our exported task. - let state = Box::into_raw(Box::new(FutureState::new(Box::pin(task)))); + let state = Box::into_raw(Box::new(TaskState::new(Box::pin(task)))); - // Store our `FutureState` into our context-local-storage slot and then + // Store our `TaskState` into our context-local-storage slot and then // pretend we got EVENT_NONE to kick off everything. // // SAFETY: we should own `context.set` as we're the root level exported @@ -505,13 +569,13 @@ pub fn start_task(task: impl Future + 'static) -> i32 { /// /// # Unsafety /// -/// This function assumes that `context_get()` returns a `FutureState`. +/// This function assumes that `context_get()` returns a `TaskState`. #[doc(hidden)] pub unsafe fn callback(event0: u32, event1: u32, event2: u32) -> u32 { // Acquire our context-local state, assert it's not-null, and then reset // the state to null while we're running to help prevent any unintended // usage. - let state = context_get().cast::>(); + let state = context_get().cast::>(); assert!(!state.is_null()); unsafe { context_set(ptr::null_mut()); @@ -540,7 +604,7 @@ pub unsafe fn callback(event0: u32, event1: u32, event2: u32) -> u32 { // TODO: refactor so `'static` bounds aren't necessary pub fn block_on(future: impl Future) -> T { let mut result = None; - let mut state = FutureState::new(Box::pin(async { + let mut state = TaskState::new(Box::pin(async { result = Some(future.await); })); let mut event = (EVENT_NONE, 0, 0); @@ -550,8 +614,14 @@ pub fn block_on(future: impl Future) -> T { drop(state); break result.unwrap(); } - CallbackCode::Yield => event = state.waitable_set.as_ref().unwrap().poll(), - CallbackCode::Wait(_) => event = state.waitable_set.as_ref().unwrap().wait(), + CallbackCode::Yield => { + let set = state.shared.waitable_set.try_lock().unwrap(); + event = set.as_ref().unwrap().poll() + } + CallbackCode::Wait(_) => { + let set = state.shared.waitable_set.try_lock().unwrap(); + event = set.as_ref().unwrap().wait() + } } } } diff --git a/crates/guest-rust/src/rt/async_support/cabi.rs b/crates/guest-rust/src/rt/async_support/cabi.rs index d696d1e55..67ebbc569 100644 --- a/crates/guest-rust/src/rt/async_support/cabi.rs +++ b/crates/guest-rust/src/rt/async_support/cabi.rs @@ -43,6 +43,29 @@ //! //! Additionally for now this file is serving as documentation of this //! interface. +//! +//! # Revisions +//! +//! This interface is intended to be evolvable over time if needed. Notably the +//! original task structure, `wasip3_task`, has a `version` field where certain +//! version levels imply the existence of certain fields. The historical +//! revisions are: +//! +//! ### V1 +//! +//! This was the original version. This is the original specification of +//! `wasip3_task_set` and `wasip3_task`. +//! +//! ### V2 +//! +//! This was added 2026-06-17 in response to #1618. This added +//! `wasip3_task_v2` and `wasip3_task_vtable`. This version enables cloning a +//! task to create a strong reference to it independent of the stack lifetime +//! that `wasip3_task_set` is required to uphold. This necessitated introducing +//! `clone` and `drop` callbacks to manage the lifetime of this reference. +//! While doing this everything was moved into a vtable structure instead of +//! inline in `wasip3_task` to make it easier to add more function pointers +//! in the future if necessary. use core::ffi::c_void; @@ -66,6 +89,7 @@ extern_wasm! { /// The first version of `wasip3_task` which implies the existence of the /// fields `ptr`, `waitable_register`, and `waitable_unregister`. pub const WASIP3_TASK_V1: u32 = 1; +pub const WASIP3_TASK_V2: u32 = 2; /// Indirect "vtable" used to connect imported functions and exported tasks. /// Executors (e.g. exported functions) define and manage this while imports @@ -80,6 +104,41 @@ pub struct wasip3_task { /// below as the first argument. pub ptr: *mut c_void, + /// See `wasip3_task_vtable::waitable_register`. + pub waitable_register: unsafe extern "C" fn( + ptr: *mut c_void, + waitable: u32, + callback: unsafe extern "C" fn(callback_ptr: *mut c_void, code: u32), + callback_ptr: *mut c_void, + ) -> *mut c_void, + + /// See `wasip3_task_vtable::waitable_unregister`. + pub waitable_unregister: unsafe extern "C" fn(ptr: *mut c_void, waitable: u32) -> *mut c_void, +} + +unsafe impl Send for wasip3_task {} +unsafe impl Sync for wasip3_task {} + +/// Representation when `wasip3_task::version` is `WASIP3_TASK_V2`. +#[repr(C)] +pub struct wasip3_task_v2 { + /// The original task structure. + pub v1: wasip3_task, + + /// An always-valid pointer to a list of function pointers, described + /// below. + pub vtable: &'static wasip3_task_vtable, +} + +/// Function pointer operations that can operate on `wasip3_task::ptr`. +/// +/// This was introduced in the "v2" ABI and is a member of `wasip3_task_v2`. +#[repr(C)] +pub struct wasip3_task_vtable { + /// Currently `WASIP3_TASK_V2` as that was the first version that specified + /// vtables. + pub version: u32, + /// Register a new `waitable` for this exported task. /// /// This exported task will add `waitable` to its `waitable-set`. When it @@ -104,4 +163,18 @@ pub struct wasip3_task { /// Returns the `callback_ptr` passed to `waitable_register` if present, or /// `NULL` if it's not present. pub waitable_unregister: unsafe extern "C" fn(ptr: *mut c_void, waitable: u32) -> *mut c_void, + + /// Clones this task's pointer to create a separately owned pointer which + /// can be persisted outside the stack frame that this is being used + /// within. + /// + /// Cloned values must be dropped/deallocated with `drop` below. + pub clone: unsafe extern "C" fn(ptr: *mut c_void) -> *mut c_void, + + /// Drops and deallocates the provided pointer previously created by a + /// call to the `clone` callback above. + /// + /// This must not be called on the `ptr` value within `wasip3_task::ptr` as + /// that's not managed with this lifetime. + pub drop: unsafe extern "C" fn(ptr: *mut c_void), } diff --git a/crates/guest-rust/src/rt/async_support/inter_task_wakeup.rs b/crates/guest-rust/src/rt/async_support/inter_task_wakeup.rs index d36d34bd6..f827a5233 100644 --- a/crates/guest-rust/src/rt/async_support/inter_task_wakeup.rs +++ b/crates/guest-rust/src/rt/async_support/inter_task_wakeup.rs @@ -1,6 +1,6 @@ -use super::FutureState; +use super::TaskState; use crate::rt::async_support::try_lock::TryLock; -use crate::rt::async_support::{BLOCKED, COMPLETED}; +use crate::rt::async_support::{BLOCKED, COMPLETED, WaitableSet}; use crate::{RawStreamReader, RawStreamWriter, StreamOps, UnitStreamOps}; use core::ptr; @@ -18,7 +18,7 @@ pub struct State { stream_reading: bool, } -impl FutureState<'_> { +impl TaskState<'_> { pub(super) fn read_inter_task_stream(&mut self) { // Lazily allocate the inter-task stream now that we're actually going // to sleep. We don't know where the wakeup notification will come from @@ -27,7 +27,7 @@ impl FutureState<'_> { assert!(!self.inter_task_wakeup.stream_reading); let (writer, reader) = UnitStreamOps::new(); self.inter_task_wakeup.stream = Some(reader); - let mut waker_stream = self.waker.inter_task_stream.lock.try_lock().unwrap(); + let mut waker_stream = self.shared.inter_task_stream.lock.try_lock().unwrap(); assert!(waker_stream.is_none()); *waker_stream = Some(writer); } @@ -44,7 +44,7 @@ impl FutureState<'_> { let rc = unsafe { UnitStreamOps.start_read(handle, ptr::null_mut(), 1) }; assert_eq!(rc, BLOCKED); self.inter_task_wakeup.stream_reading = true; - self.add_waitable(handle); + self.shared.add_waitable(handle); } } @@ -63,7 +63,7 @@ impl FutureState<'_> { unsafe { UnitStreamOps.cancel_read(handle); } - self.remove_waitable(handle); + WaitableSet::remove_waitable_from_all_sets(handle); } } diff --git a/crates/guest-rust/src/rt/async_support/inter_task_wakeup_disabled.rs b/crates/guest-rust/src/rt/async_support/inter_task_wakeup_disabled.rs index 3f7462cfb..6cd1fd140 100644 --- a/crates/guest-rust/src/rt/async_support/inter_task_wakeup_disabled.rs +++ b/crates/guest-rust/src/rt/async_support/inter_task_wakeup_disabled.rs @@ -1,9 +1,9 @@ -use super::FutureState; +use super::TaskState; #[derive(Default)] pub struct State; -impl FutureState<'_> { +impl TaskState<'_> { pub(super) fn read_inter_task_stream(&mut self) { assert!( self.remaining_work(), diff --git a/crates/guest-rust/src/rt/async_support/waitable.rs b/crates/guest-rust/src/rt/async_support/waitable.rs index 4e45f911c..d0b8c97f0 100644 --- a/crates/guest-rust/src/rt/async_support/waitable.rs +++ b/crates/guest-rust/src/rt/async_support/waitable.rs @@ -19,11 +19,59 @@ use core::task::{Context, Poll, Waker}; pub struct WaitableOperation { op: S, state: WaitableOperationState, + task: Option, /// Storage for the final result of this asynchronous operation, if it's /// completed asynchronously. completion_status: CompletionStatus, } +struct CabiTask { + ptr: *mut c_void, + registered: Option, + vtable: &'static cabi::wasip3_task_vtable, +} + +impl CabiTask { + /// Creates a new task from the raw C ABI representation provided. + /// + /// # Safety + /// + /// The `task` provided must be valid and adhere to C ABI conventions. + unsafe fn new(task: *mut cabi::wasip3_task_v2) -> CabiTask { + // SAFETY: the validity of `task` is up to the caller. + unsafe { + CabiTask { + ptr: ((*task).vtable.clone)((*task).v1.ptr), + registered: None, + vtable: (*task).vtable, + } + } + } + + fn unregister(&mut self, waitable: u32) -> *mut c_void { + self.registered = None; + // SAFETY: this was created from a valid task, so this should be safe + // to invoke. + unsafe { (self.vtable.waitable_unregister)(self.ptr, waitable) } + } +} + +unsafe impl Send for CabiTask {} +unsafe impl Sync for CabiTask {} + +impl Drop for CabiTask { + fn drop(&mut self) { + if let Some(waitable) = self.registered { + self.unregister(waitable); + } + // SAFETY: this was created from a valid atask, so this should be safe + // to invoke. + unsafe { + (self.vtable.drop)(self.ptr); + } + } +} + /// Structure used to store the `u32` return code from the canonical ABI about /// an asynchronous operation. /// @@ -140,6 +188,7 @@ where WaitableOperation { op, state: WaitableOperationState::Start(state), + task: None, completion_status: CompletionStatus { code: None, waker: None, @@ -153,6 +202,7 @@ where ) -> ( &mut S, &mut WaitableOperationState, + &mut Option, Pin<&mut CompletionStatus>, ) { // SAFETY: this is the one method used to project from `Pin<&mut Self>` @@ -165,6 +215,7 @@ where ( &mut me.op, &mut me.state, + &mut me.task, Pin::new_unchecked(&mut me.completion_status), ) } @@ -175,7 +226,7 @@ where /// * Fill in `completion_status` with the result of a completion event. /// * Call `cx.waker().wake()`. pub fn register_waker(self: Pin<&mut Self>, waitable: u32, cx: &mut Context) { - let (_, _, mut completion_status) = self.pin_project(); + let (_, _, last_task, mut completion_status) = self.pin_project(); debug_assert!(completion_status.as_mut().code_mut().is_none()); *completion_status.as_mut().waker_mut() = Some(cx.waker().clone()); @@ -193,13 +244,33 @@ where assert!(!task.is_null()); assert!((*task).version >= cabi::WASIP3_TASK_V1); let ptr: *mut CompletionStatus = completion_status.get_unchecked_mut(); + + // For the v2+ ABI clone the task structure to store internally + // within this waitable operation, if we're not already storing + // this task. This ensures that `unregister_waker` below works + // correctly in cross-task situations. + // + // Note that this must happen before the `waitable_register` below + // to ensure we're fully removed from the previous task, if + // applicable, before registering with another task. + if (*task).version >= cabi::WASIP3_TASK_V2 { + let task = task.cast::(); + let last_task = match last_task { + Some(prev) if prev.ptr == (*task).v1.ptr => prev, + _ => last_task.insert(CabiTask::new(task)), + }; + last_task.registered = Some(waitable); + } + let prev = ((*task).waitable_register)((*task).ptr, waitable, cabi_wake, ptr.cast()); + // We might be inserting a waker for the first time or overwriting // the previous waker. Only assert the expected value here if the // previous value was non-null. if !prev.is_null() { assert_eq!(ptr, prev.cast()); } + cabi::wasip3_task_set(task); } @@ -216,42 +287,59 @@ where /// This relinquishes control of the original `completion_status` pointer /// passed to `register_waker` after this call has completed. pub fn unregister_waker(self: Pin<&mut Self>, waitable: u32) { - // SAFETY: the contract of `wasip3_task_set` is that the returned - // pointer is valid for the lifetime of our entire task, so it's valid - // for this stack frame. Additionally we assert it's non-null to - // double-check it's initialized and additionally check the version for - // the fields that we access. - // - // Otherwise the `waitable_unregister` callback should be safe because: - // - // * We're fulfilling the contract where the first argument must be - // `(*task).ptr` - // * We own the `waitable` that we're passing in, so we're fulfilling - // the contract that arbitrary waitables for other units of work - // aren't being manipulated. - unsafe { - let task = cabi::wasip3_task_set(ptr::null_mut()); - assert!(!task.is_null()); - assert!((*task).version >= cabi::WASIP3_TASK_V1); - let prev = ((*task).waitable_unregister)((*task).ptr, waitable); - - // Note that `_prev` here is not guaranteed to be either `NULL` or - // not. A racy completion notification may have come in and - // removed our waitable from the map even though we're in the - // `InProgress` state, meaning it may not be present. + let (_, _, task, completion) = self.pin_project(); + + let prev = match task { + // Note that in this case we leave `task` as-is to avoid re-cloning + // it in the future if we're re-registered with it, so this only + // unregisters. + Some(prev) => prev.unregister(waitable), + + // If we don't have a previous task listed then that means that we + // are registered with a "v1" task that can't be cloned out. Assume + // blindly that we're still under the same task and unregister from + // it. // - // The main thing is that after this method is called the - // internal `completion_status` is guaranteed to no longer be in - // `task`. + // SAFETY: the contract of `wasip3_task_set` is that the returned + // pointer is valid for the lifetime of our entire task, so it's + // valid for this stack frame. Additionally we assert it's non-null + // to double-check it's initialized and additionally check the + // version for the fields that we access. // - // Note, though, that if present this must be our `CompletionStatus` - // pointer. - if !prev.is_null() { - let ptr: *mut CompletionStatus = self.pin_project().2.get_unchecked_mut(); - assert_eq!(ptr, prev.cast()); - } + // Otherwise the `waitable_unregister` callback should be safe + // because: + // + // * We're fulfilling the contract where the first argument must be + // `(*task).ptr` + // * We own the `waitable` that we're passing in, so we're + // fulfilling the contract that arbitrary waitables for other + // units of work aren't being manipulated. + None => unsafe { + let task = cabi::wasip3_task_set(ptr::null_mut()); + assert!(!task.is_null()); + assert!((*task).version >= cabi::WASIP3_TASK_V1); + let prev = ((*task).waitable_unregister)((*task).ptr, waitable); + cabi::wasip3_task_set(task); + prev + }, + }; - cabi::wasip3_task_set(task); + // Note that `prev` here may be null or may be valid. A racy completion + // notification may have come in and removed our waitable from the map + // even though we're in the `InProgress` state, meaning it may not be + // present. + // + // The main thing is that after this method is called the + // internal `completion_status` is guaranteed to no longer be in + // `task`. + // + // Note, though, that if present this must be our `CompletionStatus` + // pointer. If it's not then something's been corrupted and this is + // intended to catch that early. + if !prev.is_null() { + // SAFETY: only used for a comparison, not mutated. + let ptr: *mut CompletionStatus = unsafe { completion.get_unchecked_mut() }; + assert_eq!(ptr, prev.cast()); } } @@ -261,7 +349,7 @@ where pub fn poll_complete(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { use WaitableOperationState::*; - let (op, state, completion_status) = self.as_mut().pin_project(); + let (op, state, _task, completion_status) = self.as_mut().pin_project(); // First up, determine the completion status, if any, that's available. let optional_code = match state { @@ -305,7 +393,7 @@ where ) -> Poll { use WaitableOperationState::*; - let (op, state, _completion_status) = self.as_mut().pin_project(); + let (op, state, task, _completion_status) = self.as_mut().pin_project(); // If a status code is provided, then extract the in-progress state and // see what it thinks about this code. If we're done, yay! If not then @@ -315,6 +403,9 @@ where // If no status code is available then that means we were polled before // the status came back, so just re-register the waker. if let Some(code) = optional_code { + if let Some(task) = task { + task.registered = None; + } let InProgress(in_progress) = mem::replace(state, Done) else { unreachable!() }; @@ -354,7 +445,7 @@ where pub fn cancel(mut self: Pin<&mut Self>) -> S::Cancel { use WaitableOperationState::*; - let (op, state, mut completion_status) = self.as_mut().pin_project(); + let (op, state, _task, mut completion_status) = self.as_mut().pin_project(); let in_progress = match state { // This operation was never actually started, so there's no need to // cancel anything, just pull out the value and return it. @@ -432,7 +523,7 @@ where // this to be sound. Rust doesn't currently have linear types or async // destructors for example to ensure otherwise that if this were to // proceed asynchronously that we could rely on it being invoked. - let (op, InProgress(in_progress), _) = self.as_mut().pin_project() else { + let (op, InProgress(in_progress), _, _) = self.as_mut().pin_project() else { unreachable!() }; let code = op.in_progress_cancel(in_progress); diff --git a/tests/runtime-async/async/rust-async-and-block-on/runner.rs b/tests/runtime-async/async/rust-async-and-block-on/runner.rs new file mode 100644 index 000000000..e9347009a --- /dev/null +++ b/tests/runtime-async/async/rust-async-and-block-on/runner.rs @@ -0,0 +1,49 @@ +//@ wasmtime-flags = '-Wcomponent-model-async' + +include!(env!("BINDINGS")); + +use std::future::Future; +use std::pin::pin; +use std::task::{Context, Poll, Waker}; +use wit_bindgen::block_on; + +struct Component; + +export!(Component); + +impl Guest for Component { + async fn run() { + let (writer, reader) = wit_stream::new::(); + let reader = a::b::i::launder(reader); + let noop_cx = &mut Context::from_waker(Waker::noop()); + + let mut w1 = pin!(async { + let mut w = writer; + let _ = w.write(vec![1u8]).await; + w + }); + + // Step 1 — register &w1.completion_status in export_task.waitables[H]. + assert!(matches!(w1.as_mut().poll(noop_cx), Poll::Pending)); + + // Step 2 — block_on completes w1; export_task.waitables[H] goes stale. + // _reader must stay alive so the step-3 write blocks (Dropped skips + // register_waker and hides the bug). + let (writer, _reader) = block_on(async move { + let mut reader = reader; + let (w, _) = futures::join!(w1, reader.read(Vec::with_capacity(1))); + (w, reader) + }); + + // Step 3 — register &w2.completion_status; gets freed + // &w1.completion_status back as prev → assert_eq!(ptr, prev.cast()) + // panics at waitable.rs:201. + let mut w2 = pin!(async move { + let mut w = writer; + let pad = [0u64; 16]; + let _ = w.write(vec![2u8]).await; + let _ = pad; // explicit use after await keeps pad in the state machine + }); + let _ = w2.as_mut().poll(noop_cx); // panics + } +} diff --git a/tests/runtime-async/async/rust-async-and-block-on/test.rs b/tests/runtime-async/async/rust-async-and-block-on/test.rs new file mode 100644 index 000000000..2011d1ae0 --- /dev/null +++ b/tests/runtime-async/async/rust-async-and-block-on/test.rs @@ -0,0 +1,13 @@ +include!(env!("BINDINGS")); + +use wit_bindgen::StreamReader; + +struct Component; + +export!(Component); + +impl crate::exports::a::b::i::Guest for Component { + fn launder(x: StreamReader) -> StreamReader { + x + } +} diff --git a/tests/runtime-async/async/rust-async-and-block-on/test.wit b/tests/runtime-async/async/rust-async-and-block-on/test.wit new file mode 100644 index 000000000..3e66f4289 --- /dev/null +++ b/tests/runtime-async/async/rust-async-and-block-on/test.wit @@ -0,0 +1,15 @@ +package a:b; + +interface i { + launder: func(x: stream) -> stream; +} + +world test { + export i; +} + +world runner { + import i; + + export run: async func(); +}