From 6e1f69f4a6c493cf312d2d399c2e60086a6bc103 Mon Sep 17 00:00:00 2001 From: Alex Crichton Date: Mon, 1 Jun 2026 15:48:46 -0700 Subject: [PATCH 1/3] Evolve the wasip3 async C ABI for tasks This commit is an evolution of the C ABI used to managed task-related infrastructure in WASIp3 with the goal of solving #1618. The basic problem of #1618 is that waitables in Rust aren't guaranteed to be polled within the context of the original task. For example by mixing an `async` Rust export and `block_on` it's possible to "cross the wires" and poll in one context while dropping/completing in another context. This can lead to buggy situations where a waitable is left in a set, not added to an appropriate set, or generally mis-managed. The solution here is to enhance the current C ABI of task management with clone/drop operations. Notably this enables waitables to retain a strong reference to the task state as opposed to always consulting what the current task in. This fixes a few situations such as: * When dropping a half-finished waitable it no longer needs to be dropped in the context of the original task. Dropping will unregister the waitable from a task that it was originally registered with. * When a waitable is moved from one task to another it needs to implicitly de-register with the previous task, and this was not previously done. Now with a retained strong reference it's able to clear out previous state upon re-registering with a new task. This change requires some finesse as this needs to be ABI-stable to work with previous versions of the `wit-bindgen` crate. The runtime support additionally can't assume that the new ABI bits are available and instead needs to handle the previous ABI as well. Not too too bad, in the end, though. This additionally did some refactoring of the state associated with async tasks to juggle things around and better represent the raw pointers/`Arc`/etc from before. Closes #1618 --- crates/guest-rust/src/rt/async_support.rs | 272 +++++++++++------- .../guest-rust/src/rt/async_support/cabi.rs | 55 ++++ .../src/rt/async_support/inter_task_wakeup.rs | 12 +- .../src/rt/async_support/waitable.rs | 167 ++++++++--- .../async/rust-async-and-block-on/runner.rs | 49 ++++ .../async/rust-async-and-block-on/test.rs | 13 + .../async/rust-async-and-block-on/test.wit | 15 + 7 files changed, 438 insertions(+), 145 deletions(-) create mode 100644 tests/runtime-async/async/rust-async-and-block-on/runner.rs create mode 100644 tests/runtime-async/async/rust-async-and-block-on/test.rs create mode 100644 tests/runtime-async/async/rust-async-and-block-on/test.wit diff --git a/crates/guest-rust/src/rt/async_support.rs b/crates/guest-rust/src/rt/async_support.rs index 2a5218ed1..baa5b7497 100644 --- a/crates/guest-rust/src/rt/async_support.rs +++ b/crates/guest-rust/src/rt/async_support.rs @@ -1,12 +1,13 @@ #![deny(missing_docs)] +use self::try_lock::TryLock; use alloc::boxed::Box; use alloc::collections::BTreeMap; use alloc::sync::Arc; use alloc::task::Wake; use core::ffi::c_void; use core::future::Future; -use core::mem; +use core::mem::{self, ManuallyDrop}; use core::pin::Pin; use core::ptr; use core::sync::atomic::{AtomicU32, Ordering}; @@ -104,15 +105,28 @@ use spawn_disabled as spawn; /// Represents a task created by either a call to an async-lifted export or a /// future run using `block_on` or `start_task`. -struct FutureState<'a> { +struct TaskState<'a> { /// Remaining work to do (if any) before this task can be considered "done". /// /// Note that we won't tell the host the task is done until this is drained /// and `waitables` is empty. tasks: spawn::Tasks<'a>, - /// The waitable set containing waitables created by this task, if any. - waitable_set: Option, + /// Dual-mode rust-level Waker and C ABI level "task" for wasip3 + /// integration. + shared: Arc, + + /// Clone of `shared` field, but represented as `std::task::Waker`. + waker: Waker, + + /// State related to supporting inter-task wakeup scenarios. + inter_task_wakeup: inter_task_wakeup::State, +} + +struct SharedTaskState { + /// One of `SLEEP_STATE_*` indicating the current status. + sleep_state: AtomicU32, + inter_task_stream: inter_task_wakeup::WakerState, /// State of all waitables in `waitable_set`, and the ptr/callback they're /// associated with. @@ -123,56 +137,44 @@ struct FutureState<'a> { // the `wasi_snapshot_preview1` adapter when targeting `wasm32-wasip2` and // later, and that's expensive enough that we'd prefer to avoid it for apps // which otherwise make no use of the adapter. - waitables: BTreeMap, - - /// Raw structure used to pass to `cabi::wasip3_task_set` - wasip3_task: cabi::wasip3_task, - - /// Rust-level state for the waker, notably a bool as to whether this has - /// been woken. - waker: Arc, + // + // Also note that the `TryLock` here should never be contended, but it's + // used for interior mutability. + waitables: TryLock>, - /// Clone of `waker` field, but represented as `std::task::Waker`. - waker_clone: Waker, + /// The waitable set containing waitables created by this task, if any. + // + // Note the `TryLock` is the same as `waitables` above, it's serving the + // purpose of interior mutability. + waitable_set: TryLock>, +} - /// State related to supporting inter-task wakeup scenarios. - inter_task_wakeup: inter_task_wakeup::State, +/// An entry of `SharedTaskState::waitables` which is added through the C ABI. +struct CabiWaitable { + callback: unsafe extern "C" fn(*mut c_void, u32), + callback_ptr: *mut c_void, } -impl FutureState<'_> { - fn new(future: BoxFuture<'_>) -> FutureState<'_> { - let waker = Arc::new(FutureWaker::default()); - FutureState { - waker_clone: waker.clone().into(), - waker, +unsafe impl Send for CabiWaitable {} + +impl TaskState<'_> { + fn new(future: BoxFuture<'_>) -> TaskState<'_> { + let shared = Arc::new(SharedTaskState { + sleep_state: AtomicU32::new(0), + inter_task_stream: Default::default(), + waitables: Default::default(), + waitable_set: Default::default(), + }); + TaskState { + waker: shared.clone().into(), + shared, tasks: spawn::Tasks::new(future), - waitable_set: None, - waitables: BTreeMap::new(), - wasip3_task: cabi::wasip3_task { - // This pointer is filled in before calling `wasip3_task_set`. - ptr: ptr::null_mut(), - version: cabi::WASIP3_TASK_V1, - waitable_register, - waitable_unregister, - }, inter_task_wakeup: Default::default(), } } - fn get_or_create_waitable_set(&mut self) -> &WaitableSet { - self.waitable_set.get_or_insert_with(WaitableSet::new) - } - - fn add_waitable(&mut self, waitable: u32) { - self.get_or_create_waitable_set().join(waitable) - } - - fn remove_waitable(&mut self, waitable: u32) { - WaitableSet::remove_waitable_from_all_sets(waitable) - } - fn remaining_work(&self) -> bool { - !self.waitables.is_empty() + !self.shared.waitables.try_lock().unwrap().is_empty() } /// Handles the `event{0,1,2}` event codes and returns a corresponding @@ -190,7 +192,7 @@ impl FutureState<'_> { // Cancellation is mapped to destruction in Rust, so return a // code/bool indicating we're done. The caller will then - // appropriately deallocate this `FutureState` which will + // appropriately deallocate this `TaskState` which will // transitively run all destructors. return CallbackCode::Exit; } @@ -200,7 +202,7 @@ impl FutureState<'_> { self.with_p3_task_set(|me| { // Transition our sleep state to ensure that the inter-task stream // isn't used since there's no need to use that here. - me.waker + me.shared .sleep_state .store(SLEEP_STATE_WOKEN, Ordering::Relaxed); @@ -221,13 +223,13 @@ impl FutureState<'_> { me.cancel_inter_task_stream_read(); loop { - let mut context = Context::from_waker(&me.waker_clone); + let mut context = Context::from_waker(&me.waker); // On each turn of this loop reset the state to "polling" // which clears out any pending wakeup if one was sent. This // in theory helps minimize wakeups from previous iterations // happening in this iteration. - me.waker + me.shared .sleep_state .store(SLEEP_STATE_POLLING, Ordering::Relaxed); @@ -242,7 +244,8 @@ impl FutureState<'_> { Poll::Ready(()) => { assert!(me.tasks.is_empty()); if me.remaining_work() { - let waitable = me.waitable_set.as_ref().unwrap().as_raw(); + let set = me.shared.waitable_set.try_lock().unwrap(); + let waitable = set.as_ref().unwrap().as_raw(); break CallbackCode::Wait(waitable); } else { break CallbackCode::Exit; @@ -255,10 +258,12 @@ impl FutureState<'_> { // something. Poll::Pending => { assert!(!me.tasks.is_empty()); - if me.waker.sleep_state.load(Ordering::Relaxed) == SLEEP_STATE_WOKEN { + if me.shared.sleep_state.load(Ordering::Relaxed) == SLEEP_STATE_WOKEN { if me.remaining_work() { - let (event0, event1, event2) = - me.waitable_set.as_ref().unwrap().poll(); + let (event0, event1, event2) = { + let set = me.shared.waitable_set.try_lock().unwrap(); + set.as_ref().unwrap().poll() + }; if event0 != EVENT_NONE { me.deliver_waitable_event(event1, event2); continue; @@ -270,11 +275,12 @@ impl FutureState<'_> { // Transition our state to "sleeping" so wakeup // notifications know that they need to signal the // inter-task stream. - me.waker + me.shared .sleep_state .store(SLEEP_STATE_SLEEPING, Ordering::Relaxed); me.read_inter_task_stream(); - let waitable = me.waitable_set.as_ref().unwrap().as_raw(); + let set = me.shared.waitable_set.try_lock().unwrap(); + let waitable = set.as_ref().unwrap().as_raw(); break CallbackCode::Wait(waitable); } } @@ -286,7 +292,7 @@ impl FutureState<'_> { /// waitable should be present because it's part of the waitable set which /// is kept in-sync with our map. fn deliver_waitable_event(&mut self, waitable: u32, code: u32) { - self.remove_waitable(waitable); + WaitableSet::remove_waitable_from_all_sets(waitable); if self .inter_task_wakeup @@ -295,17 +301,19 @@ impl FutureState<'_> { return; } - let (ptr, callback) = self.waitables.remove(&waitable).unwrap(); + let c = { + let mut waitables = self.shared.waitables.try_lock().unwrap(); + waitables.remove(&waitable).unwrap() + }; unsafe { - callback(ptr, code); + (c.callback)(c.callback_ptr, code); } } fn with_p3_task_set(&mut self, f: impl FnOnce(&mut Self) -> R) -> R { - // Finish our `wasip3_task` by initializing its self-referential pointer, - // and then register it for the duration of this function with - // `wasip3_task_set`. The previous value of `wasip3_task_set` will get - // restored when this function returns. + // Initialize a temporary `wasip3_task` structure on the stack and + // inform `wasip3_task_set` that we're now within that task. Note the + // RAII guard to reset the task back to its previous contents. struct ResetTask(*mut cabi::wasip3_task); impl Drop for ResetTask { fn drop(&mut self) { @@ -314,16 +322,32 @@ impl FutureState<'_> { } } } - let self_raw = self as *mut FutureState<'_>; - self.wasip3_task.ptr = self_raw.cast(); - let prev = unsafe { cabi::wasip3_task_set(&mut self.wasip3_task) }; + // The `ptr` field of `wasip3_task` is to `SharedTaskState` which is + // what's cloned/handed out/etc. + let shared_raw: *const SharedTaskState = &*self.shared; + let mut wasip3_task = cabi::wasip3_task_v2 { + v1: cabi::wasip3_task { + ptr: shared_raw.cast_mut().cast(), + version: cabi::WASIP3_TASK_V2, + waitable_register: SharedTaskState::CABI_VTABLE.waitable_register, + waitable_unregister: SharedTaskState::CABI_VTABLE.waitable_unregister, + }, + vtable: &SharedTaskState::CABI_VTABLE, + }; + + // Explicitly take a mutable borrow on the entire `wasip3_task` + // structure, and then cast its raw pointer to the "smaller" historical + // version, ensuring the final pointer has provenace over the entire + // structure. + let wasip3_task: *mut cabi::wasip3_task_v2 = &mut wasip3_task; + let prev = unsafe { cabi::wasip3_task_set(wasip3_task.cast::()) }; let _reset = ResetTask(prev); f(self) } } -impl Drop for FutureState<'_> { +impl Drop for TaskState<'_> { fn drop(&mut self) { // If there's an active read of the inter-task stream, go ahead and // cancel it, since we're about to drop the stream anyway. @@ -342,33 +366,80 @@ impl Drop for FutureState<'_> { } } -unsafe extern "C" fn waitable_register( - ptr: *mut c_void, - waitable: u32, - callback: unsafe extern "C" fn(*mut c_void, u32), - callback_ptr: *mut c_void, -) -> *mut c_void { - let ptr = ptr.cast::>(); - assert!(!ptr.is_null()); - unsafe { - (*ptr).add_waitable(waitable); - match (*ptr).waitables.insert(waitable, (callback_ptr, callback)) { - Some((prev, _)) => prev, +impl SharedTaskState { + const CABI_VTABLE: cabi::wasip3_task_vtable = cabi::wasip3_task_vtable { + version: cabi::WASIP3_TASK_V2, + waitable_register: Self::cabi_waitable_register, + waitable_unregister: Self::cabi_waitable_unregister, + drop: Self::cabi_drop, + clone: Self::cabi_clone, + }; + + /// Adds the `waitable` provided to this task's waitable set. + fn add_waitable(&self, waitable: u32) { + let mut set = self.waitable_set.try_lock().unwrap(); + set.get_or_insert_with(WaitableSet::new).join(waitable); + } + + /// Implementation of the CABI `waitable_register` function. + fn waitable_register( + &self, + waitable: u32, + callback: unsafe extern "C" fn(*mut c_void, u32), + callback_ptr: *mut c_void, + ) -> *mut c_void { + self.add_waitable(waitable); + let mut waitables = self.waitables.try_lock().unwrap(); + let c = CabiWaitable { + callback, + callback_ptr, + }; + match waitables.insert(waitable, c) { + Some(prev) => prev.callback_ptr, None => ptr::null_mut(), } } -} -unsafe extern "C" fn waitable_unregister(ptr: *mut c_void, waitable: u32) -> *mut c_void { - let ptr = ptr.cast::>(); - assert!(!ptr.is_null()); - unsafe { - (*ptr).remove_waitable(waitable); - match (*ptr).waitables.remove(&waitable) { - Some((prev, _)) => prev, + /// Implementation of the CABI `waitable_unregister` function. + fn waitable_unregister(&self, waitable: u32) -> *mut c_void { + WaitableSet::remove_waitable_from_all_sets(waitable); + let mut waitables = self.waitables.try_lock().unwrap(); + match waitables.remove(&waitable) { + Some(prev) => prev.callback_ptr, None => ptr::null_mut(), } } + + /// Helper to go from a raw `c_void` FFI pointer to a typed + /// self-representation. + unsafe fn cabi_to_self(ptr: *mut c_void) -> ManuallyDrop> { + unsafe { ManuallyDrop::new(Arc::from_raw(ptr.cast::())) } + } + + unsafe extern "C" fn cabi_waitable_register( + ptr: *mut c_void, + waitable: u32, + callback: unsafe extern "C" fn(*mut c_void, u32), + callback_ptr: *mut c_void, + ) -> *mut c_void { + let me = unsafe { Self::cabi_to_self(ptr) }; + me.waitable_register(waitable, callback, callback_ptr) + } + + unsafe extern "C" fn cabi_waitable_unregister(ptr: *mut c_void, waitable: u32) -> *mut c_void { + let me = unsafe { Self::cabi_to_self(ptr) }; + me.waitable_unregister(waitable) + } + + unsafe extern "C" fn cabi_clone(ptr: *mut c_void) -> *mut c_void { + let me = unsafe { Self::cabi_to_self(ptr) }; + Arc::into_raw(Arc::clone(&me)).cast_mut().cast() + } + + unsafe extern "C" fn cabi_drop(ptr: *mut c_void) { + let mut me = unsafe { Self::cabi_to_self(ptr) }; + unsafe { ManuallyDrop::drop(&mut me) } + } } /// Status for "this task is actively being polled" @@ -380,14 +451,7 @@ const SLEEP_STATE_WOKEN: u32 = 1; /// Wakeups on this status signal the inter-task stream. const SLEEP_STATE_SLEEPING: u32 = 2; -#[derive(Default)] -struct FutureWaker { - /// One of `SLEEP_STATE_*` indicating the current status. - sleep_state: AtomicU32, - inter_task_stream: inter_task_wakeup::WakerState, -} - -impl Wake for FutureWaker { +impl Wake for SharedTaskState { fn wake(self: Arc) { Self::wake_by_ref(&self) } @@ -483,11 +547,11 @@ impl ReturnCode { /// immediately with its result. #[doc(hidden)] pub fn start_task(task: impl Future + 'static) -> i32 { - // Allocate a new `FutureState` which will track all state necessary for + // Allocate a new `TaskState` which will track all state necessary for // our exported task. - let state = Box::into_raw(Box::new(FutureState::new(Box::pin(task)))); + let state = Box::into_raw(Box::new(TaskState::new(Box::pin(task)))); - // Store our `FutureState` into our context-local-storage slot and then + // Store our `TaskState` into our context-local-storage slot and then // pretend we got EVENT_NONE to kick off everything. // // SAFETY: we should own `context.set` as we're the root level exported @@ -505,13 +569,13 @@ pub fn start_task(task: impl Future + 'static) -> i32 { /// /// # Unsafety /// -/// This function assumes that `context_get()` returns a `FutureState`. +/// This function assumes that `context_get()` returns a `TaskState`. #[doc(hidden)] pub unsafe fn callback(event0: u32, event1: u32, event2: u32) -> u32 { // Acquire our context-local state, assert it's not-null, and then reset // the state to null while we're running to help prevent any unintended // usage. - let state = context_get().cast::>(); + let state = context_get().cast::>(); assert!(!state.is_null()); unsafe { context_set(ptr::null_mut()); @@ -540,7 +604,7 @@ pub unsafe fn callback(event0: u32, event1: u32, event2: u32) -> u32 { // TODO: refactor so `'static` bounds aren't necessary pub fn block_on(future: impl Future) -> T { let mut result = None; - let mut state = FutureState::new(Box::pin(async { + let mut state = TaskState::new(Box::pin(async { result = Some(future.await); })); let mut event = (EVENT_NONE, 0, 0); @@ -550,8 +614,14 @@ pub fn block_on(future: impl Future) -> T { drop(state); break result.unwrap(); } - CallbackCode::Yield => event = state.waitable_set.as_ref().unwrap().poll(), - CallbackCode::Wait(_) => event = state.waitable_set.as_ref().unwrap().wait(), + CallbackCode::Yield => { + let set = state.shared.waitable_set.try_lock().unwrap(); + event = set.as_ref().unwrap().poll() + } + CallbackCode::Wait(_) => { + let set = state.shared.waitable_set.try_lock().unwrap(); + event = set.as_ref().unwrap().wait() + } } } } diff --git a/crates/guest-rust/src/rt/async_support/cabi.rs b/crates/guest-rust/src/rt/async_support/cabi.rs index d696d1e55..c35222d7b 100644 --- a/crates/guest-rust/src/rt/async_support/cabi.rs +++ b/crates/guest-rust/src/rt/async_support/cabi.rs @@ -66,6 +66,7 @@ extern_wasm! { /// The first version of `wasip3_task` which implies the existence of the /// fields `ptr`, `waitable_register`, and `waitable_unregister`. pub const WASIP3_TASK_V1: u32 = 1; +pub const WASIP3_TASK_V2: u32 = 2; /// Indirect "vtable" used to connect imported functions and exported tasks. /// Executors (e.g. exported functions) define and manage this while imports @@ -105,3 +106,57 @@ pub struct wasip3_task { /// `NULL` if it's not present. pub waitable_unregister: unsafe extern "C" fn(ptr: *mut c_void, waitable: u32) -> *mut c_void, } + +unsafe impl Send for wasip3_task {} +unsafe impl Sync for wasip3_task {} + +/// Indirect "vtable" used to connect imported functions and exported tasks. +/// Executors (e.g. exported functions) define and manage this while imports +/// use it. +#[repr(C)] +pub struct wasip3_task_v2 { + /// TODO + pub v1: wasip3_task, + /// TODO + pub vtable: &'static wasip3_task_vtable, +} + +/// Indirect "vtable" used to connect imported functions and exported tasks. +/// Executors (e.g. exported functions) define and manage this while imports +/// use it. +#[repr(C)] +pub struct wasip3_task_vtable { + /// Currently `WASIP3_TASK_V1`. Indicates what fields are present next + /// depending on the version here. + pub version: u32, + + /// Register a new `waitable` for this exported task. + /// + /// This exported task will add `waitable` to its `waitable-set`. When it + /// becomes ready then `callback` will be invoked with the ready code as + /// well as the `callback_ptr` provided. + /// + /// If `waitable` was previously registered with this task then the + /// previous `callback_ptr` is returned. Otherwise `NULL` is returned. + /// + /// It's the caller's responsibility to ensure that `callback_ptr` is valid + /// until `callback` is invoked, `waitable_unregister` is invoked, or + /// `waitable_register` is called again to overwrite the value. + pub waitable_register: unsafe extern "C" fn( + ptr: *mut c_void, + waitable: u32, + callback: unsafe extern "C" fn(callback_ptr: *mut c_void, code: u32), + callback_ptr: *mut c_void, + ) -> *mut c_void, + + /// Removes the `waitable` from this task's `waitable-set`. + /// + /// Returns the `callback_ptr` passed to `waitable_register` if present, or + /// `NULL` if it's not present. + pub waitable_unregister: unsafe extern "C" fn(ptr: *mut c_void, waitable: u32) -> *mut c_void, + + /// TODO + pub clone: unsafe extern "C" fn(ptr: *mut c_void) -> *mut c_void, + /// TODO + pub drop: unsafe extern "C" fn(ptr: *mut c_void), +} diff --git a/crates/guest-rust/src/rt/async_support/inter_task_wakeup.rs b/crates/guest-rust/src/rt/async_support/inter_task_wakeup.rs index d36d34bd6..f827a5233 100644 --- a/crates/guest-rust/src/rt/async_support/inter_task_wakeup.rs +++ b/crates/guest-rust/src/rt/async_support/inter_task_wakeup.rs @@ -1,6 +1,6 @@ -use super::FutureState; +use super::TaskState; use crate::rt::async_support::try_lock::TryLock; -use crate::rt::async_support::{BLOCKED, COMPLETED}; +use crate::rt::async_support::{BLOCKED, COMPLETED, WaitableSet}; use crate::{RawStreamReader, RawStreamWriter, StreamOps, UnitStreamOps}; use core::ptr; @@ -18,7 +18,7 @@ pub struct State { stream_reading: bool, } -impl FutureState<'_> { +impl TaskState<'_> { pub(super) fn read_inter_task_stream(&mut self) { // Lazily allocate the inter-task stream now that we're actually going // to sleep. We don't know where the wakeup notification will come from @@ -27,7 +27,7 @@ impl FutureState<'_> { assert!(!self.inter_task_wakeup.stream_reading); let (writer, reader) = UnitStreamOps::new(); self.inter_task_wakeup.stream = Some(reader); - let mut waker_stream = self.waker.inter_task_stream.lock.try_lock().unwrap(); + let mut waker_stream = self.shared.inter_task_stream.lock.try_lock().unwrap(); assert!(waker_stream.is_none()); *waker_stream = Some(writer); } @@ -44,7 +44,7 @@ impl FutureState<'_> { let rc = unsafe { UnitStreamOps.start_read(handle, ptr::null_mut(), 1) }; assert_eq!(rc, BLOCKED); self.inter_task_wakeup.stream_reading = true; - self.add_waitable(handle); + self.shared.add_waitable(handle); } } @@ -63,7 +63,7 @@ impl FutureState<'_> { unsafe { UnitStreamOps.cancel_read(handle); } - self.remove_waitable(handle); + WaitableSet::remove_waitable_from_all_sets(handle); } } diff --git a/crates/guest-rust/src/rt/async_support/waitable.rs b/crates/guest-rust/src/rt/async_support/waitable.rs index 4e45f911c..d0b8c97f0 100644 --- a/crates/guest-rust/src/rt/async_support/waitable.rs +++ b/crates/guest-rust/src/rt/async_support/waitable.rs @@ -19,11 +19,59 @@ use core::task::{Context, Poll, Waker}; pub struct WaitableOperation { op: S, state: WaitableOperationState, + task: Option, /// Storage for the final result of this asynchronous operation, if it's /// completed asynchronously. completion_status: CompletionStatus, } +struct CabiTask { + ptr: *mut c_void, + registered: Option, + vtable: &'static cabi::wasip3_task_vtable, +} + +impl CabiTask { + /// Creates a new task from the raw C ABI representation provided. + /// + /// # Safety + /// + /// The `task` provided must be valid and adhere to C ABI conventions. + unsafe fn new(task: *mut cabi::wasip3_task_v2) -> CabiTask { + // SAFETY: the validity of `task` is up to the caller. + unsafe { + CabiTask { + ptr: ((*task).vtable.clone)((*task).v1.ptr), + registered: None, + vtable: (*task).vtable, + } + } + } + + fn unregister(&mut self, waitable: u32) -> *mut c_void { + self.registered = None; + // SAFETY: this was created from a valid task, so this should be safe + // to invoke. + unsafe { (self.vtable.waitable_unregister)(self.ptr, waitable) } + } +} + +unsafe impl Send for CabiTask {} +unsafe impl Sync for CabiTask {} + +impl Drop for CabiTask { + fn drop(&mut self) { + if let Some(waitable) = self.registered { + self.unregister(waitable); + } + // SAFETY: this was created from a valid atask, so this should be safe + // to invoke. + unsafe { + (self.vtable.drop)(self.ptr); + } + } +} + /// Structure used to store the `u32` return code from the canonical ABI about /// an asynchronous operation. /// @@ -140,6 +188,7 @@ where WaitableOperation { op, state: WaitableOperationState::Start(state), + task: None, completion_status: CompletionStatus { code: None, waker: None, @@ -153,6 +202,7 @@ where ) -> ( &mut S, &mut WaitableOperationState, + &mut Option, Pin<&mut CompletionStatus>, ) { // SAFETY: this is the one method used to project from `Pin<&mut Self>` @@ -165,6 +215,7 @@ where ( &mut me.op, &mut me.state, + &mut me.task, Pin::new_unchecked(&mut me.completion_status), ) } @@ -175,7 +226,7 @@ where /// * Fill in `completion_status` with the result of a completion event. /// * Call `cx.waker().wake()`. pub fn register_waker(self: Pin<&mut Self>, waitable: u32, cx: &mut Context) { - let (_, _, mut completion_status) = self.pin_project(); + let (_, _, last_task, mut completion_status) = self.pin_project(); debug_assert!(completion_status.as_mut().code_mut().is_none()); *completion_status.as_mut().waker_mut() = Some(cx.waker().clone()); @@ -193,13 +244,33 @@ where assert!(!task.is_null()); assert!((*task).version >= cabi::WASIP3_TASK_V1); let ptr: *mut CompletionStatus = completion_status.get_unchecked_mut(); + + // For the v2+ ABI clone the task structure to store internally + // within this waitable operation, if we're not already storing + // this task. This ensures that `unregister_waker` below works + // correctly in cross-task situations. + // + // Note that this must happen before the `waitable_register` below + // to ensure we're fully removed from the previous task, if + // applicable, before registering with another task. + if (*task).version >= cabi::WASIP3_TASK_V2 { + let task = task.cast::(); + let last_task = match last_task { + Some(prev) if prev.ptr == (*task).v1.ptr => prev, + _ => last_task.insert(CabiTask::new(task)), + }; + last_task.registered = Some(waitable); + } + let prev = ((*task).waitable_register)((*task).ptr, waitable, cabi_wake, ptr.cast()); + // We might be inserting a waker for the first time or overwriting // the previous waker. Only assert the expected value here if the // previous value was non-null. if !prev.is_null() { assert_eq!(ptr, prev.cast()); } + cabi::wasip3_task_set(task); } @@ -216,42 +287,59 @@ where /// This relinquishes control of the original `completion_status` pointer /// passed to `register_waker` after this call has completed. pub fn unregister_waker(self: Pin<&mut Self>, waitable: u32) { - // SAFETY: the contract of `wasip3_task_set` is that the returned - // pointer is valid for the lifetime of our entire task, so it's valid - // for this stack frame. Additionally we assert it's non-null to - // double-check it's initialized and additionally check the version for - // the fields that we access. - // - // Otherwise the `waitable_unregister` callback should be safe because: - // - // * We're fulfilling the contract where the first argument must be - // `(*task).ptr` - // * We own the `waitable` that we're passing in, so we're fulfilling - // the contract that arbitrary waitables for other units of work - // aren't being manipulated. - unsafe { - let task = cabi::wasip3_task_set(ptr::null_mut()); - assert!(!task.is_null()); - assert!((*task).version >= cabi::WASIP3_TASK_V1); - let prev = ((*task).waitable_unregister)((*task).ptr, waitable); - - // Note that `_prev` here is not guaranteed to be either `NULL` or - // not. A racy completion notification may have come in and - // removed our waitable from the map even though we're in the - // `InProgress` state, meaning it may not be present. + let (_, _, task, completion) = self.pin_project(); + + let prev = match task { + // Note that in this case we leave `task` as-is to avoid re-cloning + // it in the future if we're re-registered with it, so this only + // unregisters. + Some(prev) => prev.unregister(waitable), + + // If we don't have a previous task listed then that means that we + // are registered with a "v1" task that can't be cloned out. Assume + // blindly that we're still under the same task and unregister from + // it. // - // The main thing is that after this method is called the - // internal `completion_status` is guaranteed to no longer be in - // `task`. + // SAFETY: the contract of `wasip3_task_set` is that the returned + // pointer is valid for the lifetime of our entire task, so it's + // valid for this stack frame. Additionally we assert it's non-null + // to double-check it's initialized and additionally check the + // version for the fields that we access. // - // Note, though, that if present this must be our `CompletionStatus` - // pointer. - if !prev.is_null() { - let ptr: *mut CompletionStatus = self.pin_project().2.get_unchecked_mut(); - assert_eq!(ptr, prev.cast()); - } + // Otherwise the `waitable_unregister` callback should be safe + // because: + // + // * We're fulfilling the contract where the first argument must be + // `(*task).ptr` + // * We own the `waitable` that we're passing in, so we're + // fulfilling the contract that arbitrary waitables for other + // units of work aren't being manipulated. + None => unsafe { + let task = cabi::wasip3_task_set(ptr::null_mut()); + assert!(!task.is_null()); + assert!((*task).version >= cabi::WASIP3_TASK_V1); + let prev = ((*task).waitable_unregister)((*task).ptr, waitable); + cabi::wasip3_task_set(task); + prev + }, + }; - cabi::wasip3_task_set(task); + // Note that `prev` here may be null or may be valid. A racy completion + // notification may have come in and removed our waitable from the map + // even though we're in the `InProgress` state, meaning it may not be + // present. + // + // The main thing is that after this method is called the + // internal `completion_status` is guaranteed to no longer be in + // `task`. + // + // Note, though, that if present this must be our `CompletionStatus` + // pointer. If it's not then something's been corrupted and this is + // intended to catch that early. + if !prev.is_null() { + // SAFETY: only used for a comparison, not mutated. + let ptr: *mut CompletionStatus = unsafe { completion.get_unchecked_mut() }; + assert_eq!(ptr, prev.cast()); } } @@ -261,7 +349,7 @@ where pub fn poll_complete(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { use WaitableOperationState::*; - let (op, state, completion_status) = self.as_mut().pin_project(); + let (op, state, _task, completion_status) = self.as_mut().pin_project(); // First up, determine the completion status, if any, that's available. let optional_code = match state { @@ -305,7 +393,7 @@ where ) -> Poll { use WaitableOperationState::*; - let (op, state, _completion_status) = self.as_mut().pin_project(); + let (op, state, task, _completion_status) = self.as_mut().pin_project(); // If a status code is provided, then extract the in-progress state and // see what it thinks about this code. If we're done, yay! If not then @@ -315,6 +403,9 @@ where // If no status code is available then that means we were polled before // the status came back, so just re-register the waker. if let Some(code) = optional_code { + if let Some(task) = task { + task.registered = None; + } let InProgress(in_progress) = mem::replace(state, Done) else { unreachable!() }; @@ -354,7 +445,7 @@ where pub fn cancel(mut self: Pin<&mut Self>) -> S::Cancel { use WaitableOperationState::*; - let (op, state, mut completion_status) = self.as_mut().pin_project(); + let (op, state, _task, mut completion_status) = self.as_mut().pin_project(); let in_progress = match state { // This operation was never actually started, so there's no need to // cancel anything, just pull out the value and return it. @@ -432,7 +523,7 @@ where // this to be sound. Rust doesn't currently have linear types or async // destructors for example to ensure otherwise that if this were to // proceed asynchronously that we could rely on it being invoked. - let (op, InProgress(in_progress), _) = self.as_mut().pin_project() else { + let (op, InProgress(in_progress), _, _) = self.as_mut().pin_project() else { unreachable!() }; let code = op.in_progress_cancel(in_progress); diff --git a/tests/runtime-async/async/rust-async-and-block-on/runner.rs b/tests/runtime-async/async/rust-async-and-block-on/runner.rs new file mode 100644 index 000000000..e9347009a --- /dev/null +++ b/tests/runtime-async/async/rust-async-and-block-on/runner.rs @@ -0,0 +1,49 @@ +//@ wasmtime-flags = '-Wcomponent-model-async' + +include!(env!("BINDINGS")); + +use std::future::Future; +use std::pin::pin; +use std::task::{Context, Poll, Waker}; +use wit_bindgen::block_on; + +struct Component; + +export!(Component); + +impl Guest for Component { + async fn run() { + let (writer, reader) = wit_stream::new::(); + let reader = a::b::i::launder(reader); + let noop_cx = &mut Context::from_waker(Waker::noop()); + + let mut w1 = pin!(async { + let mut w = writer; + let _ = w.write(vec![1u8]).await; + w + }); + + // Step 1 — register &w1.completion_status in export_task.waitables[H]. + assert!(matches!(w1.as_mut().poll(noop_cx), Poll::Pending)); + + // Step 2 — block_on completes w1; export_task.waitables[H] goes stale. + // _reader must stay alive so the step-3 write blocks (Dropped skips + // register_waker and hides the bug). + let (writer, _reader) = block_on(async move { + let mut reader = reader; + let (w, _) = futures::join!(w1, reader.read(Vec::with_capacity(1))); + (w, reader) + }); + + // Step 3 — register &w2.completion_status; gets freed + // &w1.completion_status back as prev → assert_eq!(ptr, prev.cast()) + // panics at waitable.rs:201. + let mut w2 = pin!(async move { + let mut w = writer; + let pad = [0u64; 16]; + let _ = w.write(vec![2u8]).await; + let _ = pad; // explicit use after await keeps pad in the state machine + }); + let _ = w2.as_mut().poll(noop_cx); // panics + } +} diff --git a/tests/runtime-async/async/rust-async-and-block-on/test.rs b/tests/runtime-async/async/rust-async-and-block-on/test.rs new file mode 100644 index 000000000..2011d1ae0 --- /dev/null +++ b/tests/runtime-async/async/rust-async-and-block-on/test.rs @@ -0,0 +1,13 @@ +include!(env!("BINDINGS")); + +use wit_bindgen::StreamReader; + +struct Component; + +export!(Component); + +impl crate::exports::a::b::i::Guest for Component { + fn launder(x: StreamReader) -> StreamReader { + x + } +} diff --git a/tests/runtime-async/async/rust-async-and-block-on/test.wit b/tests/runtime-async/async/rust-async-and-block-on/test.wit new file mode 100644 index 000000000..3e66f4289 --- /dev/null +++ b/tests/runtime-async/async/rust-async-and-block-on/test.wit @@ -0,0 +1,15 @@ +package a:b; + +interface i { + launder: func(x: stream) -> stream; +} + +world test { + export i; +} + +world runner { + import i; + + export run: async func(); +} From 47d1c37967e7c4f8fe95d7d0a1acf5569a160495 Mon Sep 17 00:00:00 2001 From: Alex Crichton Date: Wed, 17 Jun 2026 13:34:48 -0700 Subject: [PATCH 2/3] Document the C ABI --- .../guest-rust/src/rt/async_support/cabi.rs | 74 ++++++++++++------- 1 file changed, 46 insertions(+), 28 deletions(-) diff --git a/crates/guest-rust/src/rt/async_support/cabi.rs b/crates/guest-rust/src/rt/async_support/cabi.rs index c35222d7b..67ebbc569 100644 --- a/crates/guest-rust/src/rt/async_support/cabi.rs +++ b/crates/guest-rust/src/rt/async_support/cabi.rs @@ -43,6 +43,29 @@ //! //! Additionally for now this file is serving as documentation of this //! interface. +//! +//! # Revisions +//! +//! This interface is intended to be evolvable over time if needed. Notably the +//! original task structure, `wasip3_task`, has a `version` field where certain +//! version levels imply the existence of certain fields. The historical +//! revisions are: +//! +//! ### V1 +//! +//! This was the original version. This is the original specification of +//! `wasip3_task_set` and `wasip3_task`. +//! +//! ### V2 +//! +//! This was added 2026-06-17 in response to #1618. This added +//! `wasip3_task_v2` and `wasip3_task_vtable`. This version enables cloning a +//! task to create a strong reference to it independent of the stack lifetime +//! that `wasip3_task_set` is required to uphold. This necessitated introducing +//! `clone` and `drop` callbacks to manage the lifetime of this reference. +//! While doing this everything was moved into a vtable structure instead of +//! inline in `wasip3_task` to make it easier to add more function pointers +//! in the future if necessary. use core::ffi::c_void; @@ -81,18 +104,7 @@ pub struct wasip3_task { /// below as the first argument. pub ptr: *mut c_void, - /// Register a new `waitable` for this exported task. - /// - /// This exported task will add `waitable` to its `waitable-set`. When it - /// becomes ready then `callback` will be invoked with the ready code as - /// well as the `callback_ptr` provided. - /// - /// If `waitable` was previously registered with this task then the - /// previous `callback_ptr` is returned. Otherwise `NULL` is returned. - /// - /// It's the caller's responsibility to ensure that `callback_ptr` is valid - /// until `callback` is invoked, `waitable_unregister` is invoked, or - /// `waitable_register` is called again to overwrite the value. + /// See `wasip3_task_vtable::waitable_register`. pub waitable_register: unsafe extern "C" fn( ptr: *mut c_void, waitable: u32, @@ -100,34 +112,31 @@ pub struct wasip3_task { callback_ptr: *mut c_void, ) -> *mut c_void, - /// Removes the `waitable` from this task's `waitable-set`. - /// - /// Returns the `callback_ptr` passed to `waitable_register` if present, or - /// `NULL` if it's not present. + /// See `wasip3_task_vtable::waitable_unregister`. pub waitable_unregister: unsafe extern "C" fn(ptr: *mut c_void, waitable: u32) -> *mut c_void, } unsafe impl Send for wasip3_task {} unsafe impl Sync for wasip3_task {} -/// Indirect "vtable" used to connect imported functions and exported tasks. -/// Executors (e.g. exported functions) define and manage this while imports -/// use it. +/// Representation when `wasip3_task::version` is `WASIP3_TASK_V2`. #[repr(C)] pub struct wasip3_task_v2 { - /// TODO + /// The original task structure. pub v1: wasip3_task, - /// TODO + + /// An always-valid pointer to a list of function pointers, described + /// below. pub vtable: &'static wasip3_task_vtable, } -/// Indirect "vtable" used to connect imported functions and exported tasks. -/// Executors (e.g. exported functions) define and manage this while imports -/// use it. +/// Function pointer operations that can operate on `wasip3_task::ptr`. +/// +/// This was introduced in the "v2" ABI and is a member of `wasip3_task_v2`. #[repr(C)] pub struct wasip3_task_vtable { - /// Currently `WASIP3_TASK_V1`. Indicates what fields are present next - /// depending on the version here. + /// Currently `WASIP3_TASK_V2` as that was the first version that specified + /// vtables. pub version: u32, /// Register a new `waitable` for this exported task. @@ -155,8 +164,17 @@ pub struct wasip3_task_vtable { /// `NULL` if it's not present. pub waitable_unregister: unsafe extern "C" fn(ptr: *mut c_void, waitable: u32) -> *mut c_void, - /// TODO + /// Clones this task's pointer to create a separately owned pointer which + /// can be persisted outside the stack frame that this is being used + /// within. + /// + /// Cloned values must be dropped/deallocated with `drop` below. pub clone: unsafe extern "C" fn(ptr: *mut c_void) -> *mut c_void, - /// TODO + + /// Drops and deallocates the provided pointer previously created by a + /// call to the `clone` callback above. + /// + /// This must not be called on the `ptr` value within `wasip3_task::ptr` as + /// that's not managed with this lifetime. pub drop: unsafe extern "C" fn(ptr: *mut c_void), } From 7d24e68759ed8e00ac80c994e1cd831e70288194 Mon Sep 17 00:00:00 2001 From: Alex Crichton Date: Wed, 17 Jun 2026 14:57:22 -0700 Subject: [PATCH 3/3] Fix disabled compile --- .../src/rt/async_support/inter_task_wakeup_disabled.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/crates/guest-rust/src/rt/async_support/inter_task_wakeup_disabled.rs b/crates/guest-rust/src/rt/async_support/inter_task_wakeup_disabled.rs index 3f7462cfb..6cd1fd140 100644 --- a/crates/guest-rust/src/rt/async_support/inter_task_wakeup_disabled.rs +++ b/crates/guest-rust/src/rt/async_support/inter_task_wakeup_disabled.rs @@ -1,9 +1,9 @@ -use super::FutureState; +use super::TaskState; #[derive(Default)] pub struct State; -impl FutureState<'_> { +impl TaskState<'_> { pub(super) fn read_inter_task_stream(&mut self) { assert!( self.remaining_work(),