Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions crates/wasi-http/src/p3/body.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ use core::task::{Context, Poll, ready};
use http_body::Body as _;
use http_body_util::combinators::UnsyncBoxBody;
use std::any::{Any, TypeId};
use std::io::Cursor;
use std::sync::Arc;
use tokio::sync::{mpsc, oneshot};
use tokio_util::sync::PollSender;
Expand Down Expand Up @@ -466,7 +465,7 @@ where
D: 'static,
{
type Item = u8;
type Buffer = Cursor<Bytes>;
type Buffer = Bytes;

fn poll_produce<'a>(
mut self: Pin<&mut Self>,
Expand Down Expand Up @@ -513,7 +512,7 @@ where
let cap = cap.into();
if n > cap {
// data frame does not fit in destination, fill it and buffer the rest
dst.set_buffer(Cursor::new(frame.split_off(cap)));
dst.set_buffer(frame.split_off(cap));
let mut dst = dst.as_direct(store, cap);
dst.remaining().copy_from_slice(&frame);
dst.mark_written(cap);
Expand All @@ -524,7 +523,7 @@ where
dst.mark_written(n);
}
} else {
dst.set_buffer(Cursor::new(frame));
dst.set_buffer(frame);
}
return Poll::Ready(Ok(StreamResult::Completed));
}
Expand Down
3 changes: 1 addition & 2 deletions crates/wasi-tls/src/p3/util/closed.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use crate::Error;
use bytes::BytesMut;
use std::io::Cursor;
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
Expand Down Expand Up @@ -38,7 +37,7 @@ impl AsyncWrite for Closed {
}
impl<D> StreamProducer<D> for Closed {
type Item = u8;
type Buffer = Cursor<BytesMut>;
type Buffer = BytesMut;

fn poll_produce<'a>(
self: Pin<&mut Self>,
Expand Down
3 changes: 1 addition & 2 deletions crates/wasi-tls/src/p3/util/tokio_streams.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use bytes::BytesMut;
use std::io::Cursor;
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
Expand Down Expand Up @@ -37,7 +36,7 @@ where
IO: AsyncRead + Send + Unpin + 'static,
{
type Item = u8;
type Buffer = Cursor<BytesMut>;
type Buffer = BytesMut;

fn poll_produce<'a>(
mut self: Pin<&mut Self>,
Expand Down
4 changes: 2 additions & 2 deletions crates/wasi/src/p3/cli/host.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use crate::p3::cli::{TerminalInput, TerminalOutput};
use bytes::BytesMut;
use core::pin::Pin;
use core::task::{Context, Poll};
use std::io::{self, Cursor};
use std::io;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio::sync::oneshot;
use wasmtime::component::{
Expand All @@ -36,7 +36,7 @@ fn io_error_to_error_code(err: io::Error) -> ErrorCode {

impl<D> StreamProducer<D> for InputStreamProducer {
type Item = u8;
type Buffer = Cursor<BytesMut>;
type Buffer = BytesMut;

fn poll_produce<'a>(
mut self: Pin<&mut Self>,
Expand Down
8 changes: 4 additions & 4 deletions crates/wasi/src/p3/filesystem/host.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use bytes::BytesMut;
use core::pin::Pin;
use core::task::{Context, Poll, ready};
use core::{iter, mem};
use std::io::{self, Cursor};
use std::io;
use std::sync::Arc;
use tokio::sync::{mpsc, oneshot};
use tokio::task::{JoinHandle, spawn_blocking};
Expand Down Expand Up @@ -160,7 +160,7 @@ impl ReadStreamProducer {

impl<D> StreamProducer<D> for ReadStreamProducer {
type Item = u8;
type Buffer = Cursor<BytesMut>;
type Buffer = BytesMut;

fn poll_produce<'a>(
mut self: Pin<&mut Self>,
Expand Down Expand Up @@ -196,7 +196,7 @@ impl<D> StreamProducer<D> for ReadStreamProducer {
// Lazily spawn a read task if one hasn't already been spawned yet.
let me = &mut *self;
let task = me.task.get_or_insert_with(|| {
let mut buf = dst.take_buffer().into_inner();
let mut buf = dst.take_buffer();
buf.resize(DEFAULT_BUFFER_CAPACITY, 0);
let file = Arc::clone(me.file.as_file());
let offset = me.offset;
Expand Down Expand Up @@ -229,7 +229,7 @@ impl<D> StreamProducer<D> for ReadStreamProducer {
}
Ok(Ok(buf)) => {
let n = buf.len();
dst.set_buffer(Cursor::new(buf));
dst.set_buffer(buf);
Poll::Ready(Ok(self.complete_read(n)))
}
Ok(Err(err)) => {
Expand Down
3 changes: 1 addition & 2 deletions crates/wasi/src/p3/sockets/host/types/tcp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ use core::iter;
use core::pin::Pin;
use core::task::{Context, Poll};
use io_lifetimes::AsSocketlike as _;
use std::io::Cursor;
use std::net::{Shutdown, SocketAddr};
use std::sync::Arc;
use tokio::net::{TcpListener, TcpStream};
Expand Down Expand Up @@ -115,7 +114,7 @@ impl ReceiveStreamProducer {

impl<D> StreamProducer<D> for ReceiveStreamProducer {
type Item = u8;
type Buffer = Cursor<BytesMut>;
type Buffer = BytesMut;

fn poll_produce<'a>(
mut self: Pin<&mut Self>,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ use futures::channel::oneshot;
use futures::{FutureExt as _, stream};
use std::any::{Any, TypeId};
use std::boxed::Box;
use std::io::Cursor;
use std::string::String;
use std::sync::{Arc, Mutex, MutexGuard};
use std::vec::Vec;
Expand Down Expand Up @@ -337,11 +336,25 @@ pub(super) struct FlatAbi {
pub(super) align: u32,
}

struct HostBuffer<'a> {
dst: &'a mut Vec<u8>,
marked_written: &'a mut usize,
}

impl HostBuffer<'_> {
fn reborrow(&mut self) -> HostBuffer<'_> {
HostBuffer {
dst: &mut *self.dst,
marked_written: &mut *self.marked_written,
}
}
}

/// Represents the buffer for a host- or guest-initiated stream read.
pub struct Destination<'a, T, B> {
id: TableId<TransmitState>,
buffer: &'a mut B,
host_buffer: Option<&'a mut Cursor<Vec<u8>>>,
host_buffer: Option<HostBuffer<'a>>,
_phantom: PhantomData<fn() -> T>,
}

Expand All @@ -351,7 +364,7 @@ impl<'a, T, B> Destination<'a, T, B> {
Destination {
id: self.id,
buffer: &mut *self.buffer,
host_buffer: self.host_buffer.as_deref_mut(),
host_buffer: self.host_buffer.as_mut().map(|b| b.reborrow()),
_phantom: PhantomData,
}
}
Expand Down Expand Up @@ -435,11 +448,9 @@ impl<'a, B> Destination<'a, u8, B> {
store: StoreContextMut<'a, D>,
capacity: usize,
) -> DirectDestination<'a, D> {
if let Some(buffer) = self.host_buffer.as_deref_mut() {
buffer.set_position(0);
if buffer.get_mut().is_empty() {
buffer.get_mut().resize(capacity, 0);
}
if let Some(buffer) = &mut self.host_buffer {
*buffer.marked_written = 0;
buffer.dst.resize(capacity, 0);
}

DirectDestination {
Expand All @@ -454,7 +465,7 @@ impl<'a, B> Destination<'a, u8, B> {
/// writer's buffer.
pub struct DirectDestination<'a, D: 'static> {
id: TableId<TransmitState>,
host_buffer: Option<&'a mut Cursor<Vec<u8>>>,
host_buffer: Option<HostBuffer<'a>>,
store: StoreContextMut<'a, D>,
}

Expand Down Expand Up @@ -482,8 +493,8 @@ impl<D: 'static> DirectDestination<'_, D> {
}

fn remaining_(&mut self) -> Result<&mut [u8]> {
if let Some(buffer) = self.host_buffer.as_deref_mut() {
return Ok(buffer.get_mut());
if let Some(buffer) = self.host_buffer.as_mut() {
return Ok(buffer.dst);
}
let transmit = self
.store
Expand Down Expand Up @@ -531,15 +542,10 @@ impl<D: 'static> DirectDestination<'_, D> {
}

fn mark_written_(&mut self, count: usize) -> Result<()> {
if let Some(buffer) = self.host_buffer.as_deref_mut() {
buffer.set_position(
// Note that these `.unwrap`s are documented panic conditions of
// `mark_written`.
buffer
.position()
.checked_add(u64::try_from(count).unwrap())
.unwrap(),
);
if let Some(buffer) = self.host_buffer.as_mut() {
// Note that this `.unwrap` is a documented panic condition of
// `mark_written`.
*buffer.marked_written = buffer.marked_written.checked_add(count).unwrap();
} else {
let transmit = self
.store
Expand Down Expand Up @@ -827,7 +833,7 @@ where
#[cfg(feature = "component-model-async-bytes")]
impl<D> StreamProducer<D> for bytes::Bytes {
type Item = u8;
type Buffer = Cursor<Self>;
type Buffer = Self;

fn poll_produce<'a>(
self: Pin<&mut Self>,
Expand All @@ -836,15 +842,15 @@ impl<D> StreamProducer<D> for bytes::Bytes {
mut dst: Destination<'a, Self::Item, Self::Buffer>,
_: bool,
) -> Poll<Result<StreamResult>> {
dst.set_buffer(Cursor::new(mem::take(self.get_mut())));
dst.set_buffer(mem::take(self.get_mut()));
Poll::Ready(Ok(StreamResult::Dropped))
}
}

#[cfg(feature = "component-model-async-bytes")]
impl<D> StreamProducer<D> for bytes::BytesMut {
type Item = u8;
type Buffer = Cursor<Self>;
type Buffer = Self;

fn poll_produce<'a>(
self: Pin<&mut Self>,
Expand All @@ -853,7 +859,7 @@ impl<D> StreamProducer<D> for bytes::BytesMut {
mut dst: Destination<'a, Self::Item, Self::Buffer>,
_: bool,
) -> Poll<Result<StreamResult>> {
dst.set_buffer(Cursor::new(mem::take(self.get_mut())));
dst.set_buffer(mem::take(self.get_mut()));
Poll::Ready(Ok(StreamResult::Dropped))
}
}
Expand Down Expand Up @@ -2641,9 +2647,10 @@ impl<T> StoreContextMut<'_, T> {
bail_bug!("expected WriteState::HostReady")
};

let mut host_written = 0;
let mut host_buffer =
if let ReadState::HostToHost { buffer, .. } = &mut transmit.read {
Some(Cursor::new(mem::take(buffer)))
Some(mem::take(buffer))
} else {
None
};
Expand All @@ -2654,7 +2661,12 @@ impl<T> StoreContextMut<'_, T> {
Destination {
id,
buffer,
host_buffer: host_buffer.as_mut(),
host_buffer: host_buffer.as_mut().map(|b| {
HostBuffer {
dst: b,
marked_written: &mut host_written,
}
}),
_phantom: PhantomData,
},
cancel,
Expand All @@ -2667,8 +2679,8 @@ impl<T> StoreContextMut<'_, T> {
ReadState::HostToHost { buffer, limit, .. },
) = (host_buffer, &mut transmit.read)
{
*limit = usize::try_from(host_buffer.position())?;
*buffer = host_buffer.into_inner();
*limit = host_written;
*buffer = host_buffer;
*limit
} else {
0
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
#[cfg(feature = "component-model-async-bytes")]
use bytes::{Bytes, BytesMut};
#[cfg(feature = "component-model-async-bytes")]
use std::io::Cursor;
use std::mem::{self, MaybeUninit};
use std::slice;
use std::vec::Vec;
Expand Down Expand Up @@ -377,52 +375,36 @@ impl<T: Send + Sync + 'static> ReadBuffer<T> for Vec<T> {
// SAFETY: the `take` implementation below guarantees that the `fun` closure is
// provided with fully initialized items.
#[cfg(feature = "component-model-async-bytes")]
unsafe impl WriteBuffer<u8> for Cursor<Bytes> {
unsafe impl WriteBuffer<u8> for Bytes {
fn remaining(&self) -> &[u8] {
&self.get_ref()[usize::try_from(self.position()).unwrap()..]
self
}

fn skip(&mut self, count: usize) {
assert!(
count <= self.remaining().len(),
"tried to skip {count} with {} remaining",
self.remaining().len()
);
self.set_position(
self.position()
.checked_add(u64::try_from(count).unwrap())
.unwrap(),
);
let _prefix = self.split_to(count);
}

fn take(&mut self, count: usize, fun: &mut dyn FnMut(&[MaybeUninit<u8>])) {
assert!(count <= self.remaining().len());
fun(unsafe_byte_slice(&self.remaining()[..count]));
self.skip(count);
let prefix = self.split_to(count);
fun(unsafe_byte_slice(&prefix));
}
}

// SAFETY: the `take` implementation below guarantees that the `fun` closure is
// provided with fully initialized items.
#[cfg(feature = "component-model-async-bytes")]
unsafe impl WriteBuffer<u8> for Cursor<BytesMut> {
unsafe impl WriteBuffer<u8> for BytesMut {
fn remaining(&self) -> &[u8] {
&self.get_ref()[usize::try_from(self.position()).unwrap()..]
self
}

fn skip(&mut self, count: usize) {
assert!(count <= self.remaining().len());
self.set_position(
self.position()
.checked_add(u64::try_from(count).unwrap())
.unwrap(),
);
let _prefix = self.split_to(count);
}

fn take(&mut self, count: usize, fun: &mut dyn FnMut(&[MaybeUninit<u8>])) {
assert!(count <= self.remaining().len());
fun(unsafe_byte_slice(&self.remaining()[..count]));
self.skip(count);
let prefix = self.split_to(count);
fun(unsafe_byte_slice(&prefix));
}
}

Expand Down Expand Up @@ -486,7 +468,7 @@ mod tests {
#[test]
#[cfg(feature = "component-model-async-bytes")]
fn test_cursor_bytes_take() {
let mut buf = Cursor::new(Bytes::from(&b"123"[..]));
let mut buf = Bytes::from(&b"123"[..]);
let mut dst = Vec::new();
dst.reserve(1);
dst.move_from(&mut buf, 1);
Expand All @@ -503,7 +485,7 @@ mod tests {
#[test]
#[cfg(feature = "component-model-async-bytes")]
fn test_cursor_bytes_mut_take() {
let mut buf = Cursor::new(BytesMut::from(&b"123"[..]));
let mut buf = BytesMut::from(&b"123"[..]);
let mut dst = Vec::new();
dst.reserve(1);
dst.move_from(&mut buf, 1);
Expand Down
Loading