From 7efb8a2d2ad757ad5215aa1a1d155d33d5b065df Mon Sep 17 00:00:00 2001 From: Alex Crichton Date: Mon, 24 Nov 2025 17:05:13 -0800 Subject: [PATCH 1/3] Refactor streams with a `StreamOps` trait This is similar to #1409 except applied to streams now as well as futures. A new `StreamOps` trait defines all the operations for streams to help make it easier to hook up low-level functionality into higher-level bindings. --- .../src/rt/async_support/abi_buffer.rs | 66 ++-- .../src/rt/async_support/stream_support.rs | 366 ++++++++++++------ 2 files changed, 282 insertions(+), 150 deletions(-) diff --git a/crates/guest-rust/src/rt/async_support/abi_buffer.rs b/crates/guest-rust/src/rt/async_support/abi_buffer.rs index fb48b1c3f..f5c55d0c1 100644 --- a/crates/guest-rust/src/rt/async_support/abi_buffer.rs +++ b/crates/guest-rust/src/rt/async_support/abi_buffer.rs @@ -1,4 +1,4 @@ -use crate::rt::async_support::StreamVtable; +use crate::rt::async_support::StreamOps; use crate::rt::Cleanup; use std::alloc::Layout; use std::mem::{self, MaybeUninit}; @@ -16,17 +16,15 @@ use std::vec::Vec; /// /// This value is created through the [`StreamWrite`](super::StreamWrite) /// future's return value. -pub struct AbiBuffer { - rust_storage: Vec>, - vtable: &'static StreamVtable, +pub struct AbiBuffer { + rust_storage: Vec>, + ops: O, alloc: Option, cursor: usize, } -impl AbiBuffer { - pub(crate) fn new(mut vec: Vec, vtable: &'static StreamVtable) -> AbiBuffer { - assert_eq!(vtable.lower.is_some(), vtable.lift.is_some()); - +impl AbiBuffer { + pub(crate) fn new(mut vec: Vec, mut ops: O) -> AbiBuffer { // SAFETY: We're converting `Vec` to `Vec>`, which // should be safe. let rust_storage = unsafe { @@ -34,7 +32,7 @@ impl AbiBuffer { let len = vec.len(); let cap = vec.capacity(); mem::forget(vec); - Vec::>::from_raw_parts(ptr.cast(), len, cap) + Vec::>::from_raw_parts(ptr.cast(), len, cap) }; // If `lower` is provided then the canonical ABI format is different @@ -43,31 +41,32 @@ impl AbiBuffer { // Note that this is probably pretty inefficient for "big" use cases // but it's hoped that "big" use cases are using `u8` and therefore // skip this entirely. - let alloc = vtable.lower.and_then(|lower| { + let alloc = if ops.native_abi_matches_canonical_abi() { + None + } else { + let elem_layout = ops.elem_layout(); let layout = Layout::from_size_align( - vtable.layout.size() * rust_storage.len(), - vtable.layout.align(), + elem_layout.size() * rust_storage.len(), + elem_layout.align(), ) .unwrap(); let (mut ptr, cleanup) = Cleanup::new(layout); - let cleanup = cleanup?; // SAFETY: All items in `rust_storage` are already initialized so // it should be safe to read them and move ownership into the // canonical ABI format. unsafe { for item in rust_storage.iter() { let item = item.assume_init_read(); - lower(item, ptr); - ptr = ptr.add(vtable.layout.size()); + ops.lower(item, ptr); + ptr = ptr.add(elem_layout.size()); } } - - Some(cleanup) - }); + cleanup + }; AbiBuffer { rust_storage, alloc, - vtable, + ops, cursor: 0, } } @@ -78,7 +77,7 @@ impl AbiBuffer { // If there's no `lower` operation then it means that `T`'s layout is // the same in the canonical ABI so it can be used as-is. In this // situation the list would have been un-tampered with above. - if self.vtable.lower.is_none() { + if self.ops.native_abi_matches_canonical_abi() { // SAFETY: this should be in-bounds, so it should be safe. let ptr = unsafe { self.rust_storage.as_ptr().add(self.cursor).cast() }; let len = self.rust_storage.len() - self.cursor; @@ -94,7 +93,7 @@ impl AbiBuffer { .unwrap_or(ptr::null_mut()); ( // SAFETY: this should be in-bounds, so it should be safe. - unsafe { ptr.add(self.cursor * self.vtable.layout.size()) }, + unsafe { ptr.add(self.cursor * self.ops.elem_layout().size()) }, self.rust_storage.len() - self.cursor, ) } @@ -111,7 +110,7 @@ impl AbiBuffer { /// Also note that this can be an expensive operation if a partial write /// occurred as this will involve shifting items from the end of the vector /// to the start of the vector. - pub fn into_vec(mut self) -> Vec { + pub fn into_vec(mut self) -> Vec { self.take_vec() } @@ -127,10 +126,10 @@ impl AbiBuffer { /// necessary for the starting `amt` items in this list. pub(crate) fn advance(&mut self, amt: usize) { assert!(amt + self.cursor <= self.rust_storage.len()); - let Some(dealloc_lists) = self.vtable.dealloc_lists else { + if !self.ops.contains_lists() { self.cursor += amt; return; - }; + } let (mut ptr, len) = self.abi_ptr_and_len(); assert!(amt <= len); for _ in 0..amt { @@ -138,14 +137,14 @@ impl AbiBuffer { // it was initialized with a `lower`, and then the pointer // arithmetic should all be in-bounds. unsafe { - dealloc_lists(ptr.cast_mut()); - ptr = ptr.add(self.vtable.layout.size()); + self.ops.dealloc_lists(ptr.cast_mut()); + ptr = ptr.add(self.ops.elem_layout().size()); } } self.cursor += amt; } - fn take_vec(&mut self) -> Vec { + fn take_vec(&mut self) -> Vec { // First, if necessary, convert remaining values within `self.alloc` // back into `self.rust_storage`. This is necessary when a lift // operation is available meaning that the representation of `T` is @@ -155,15 +154,15 @@ impl AbiBuffer { // `AbiBuffer` was created it moved ownership of all values from the // original vector into the `alloc` value. This is the reverse // operation, moving all the values back into the vector. - if let Some(lift) = self.vtable.lift { + if !self.ops.native_abi_matches_canonical_abi() { let (mut ptr, mut len) = self.abi_ptr_and_len(); // SAFETY: this should be safe as `lift` is operating on values that // were initialized with a previous `lower`, and the pointer // arithmetic here should all be in-bounds. unsafe { for dst in self.rust_storage[self.cursor..].iter_mut() { - dst.write(lift(ptr.cast_mut())); - ptr = ptr.add(self.vtable.layout.size()); + dst.write(self.ops.lift(ptr.cast_mut())); + ptr = ptr.add(self.ops.elem_layout().size()); len -= 1; } assert_eq!(len, 0); @@ -187,12 +186,15 @@ impl AbiBuffer { let len = storage.len(); let cap = storage.capacity(); mem::forget(storage); - Vec::::from_raw_parts(ptr.cast(), len, cap) + Vec::::from_raw_parts(ptr.cast(), len, cap) } } } -impl Drop for AbiBuffer { +impl Drop for AbiBuffer +where + O: StreamOps, +{ fn drop(&mut self) { let _ = self.take_vec(); } diff --git a/crates/guest-rust/src/rt/async_support/stream_support.rs b/crates/guest-rust/src/rt/async_support/stream_support.rs index ff11e6dbc..f46803bbb 100644 --- a/crates/guest-rust/src/rt/async_support/stream_support.rs +++ b/crates/guest-rust/src/rt/async_support/stream_support.rs @@ -9,7 +9,6 @@ use { alloc::Layout, fmt, future::Future, - marker, pin::Pin, ptr, sync::atomic::{AtomicU32, Ordering::Relaxed}, @@ -18,6 +17,52 @@ use { }, }; +/// Operations that a stream requires throughout the implementation. +/// +/// This is generated by `wit_bindgen::generate!` primarily. +#[doc(hidden)] +pub unsafe trait StreamOps: Clone { + /// The Rust type that's sent or received on this stream. + type Payload: 'static; + + /// The `stream.new` intrinsic. + fn new(&mut self) -> u64; + + /// The canonical ABI layout of the type that this stream is + /// sending/receiving. + fn elem_layout(&self) -> Layout; + + /// Returns whether `lift` or `lower` is required to create `Self::Payload`. + /// + /// If this returns `false` then `Self::Payload` is natively in its + /// canonical ABI representation. + fn native_abi_matches_canonical_abi(&self) -> bool; + + /// Returns whether `O::Payload` has lists that need to be deallocated with + /// `dealloc_lists`. + fn contains_lists(&self) -> bool; + + /// Converts a Rust type to its canonical ABI representation. + unsafe fn lower(&mut self, payload: Self::Payload, dst: *mut u8); + /// Used to deallocate any Rust-owned lists in the canonical ABI + /// representation for when a value is successfully sent but needs to be + /// cleaned up. + unsafe fn dealloc_lists(&mut self, dst: *mut u8); + /// Converts from the canonical ABI representation to a Rust value. + unsafe fn lift(&mut self, dst: *mut u8) -> Self::Payload; + /// The `stream.write` intrinsic + unsafe fn start_write(&mut self, stream: u32, val: *const u8, amt: usize) -> u32; + /// The `stream.read` intrinsic + unsafe fn start_read(&mut self, stream: u32, val: *mut u8, amt: usize) -> u32; + /// The `stream.cancel-read` intrinsic + unsafe fn cancel_read(&mut self, stream: u32) -> u32; + /// The `stream.cancel-write` intrinsic + unsafe fn cancel_write(&mut self, stream: u32) -> u32; + /// The `stream.drop-readable` intrinsic + unsafe fn drop_readable(&mut self, stream: u32); + /// The `stream.drop-writable` intrinsic + unsafe fn drop_writable(&mut self, stream: u32); +} /// Operations that a stream requires throughout the implementation. /// /// This is generated by `wit_bindgen::generate!` primarily. @@ -64,36 +109,99 @@ pub struct StreamVtable { pub new: unsafe extern "C" fn() -> u64, } +unsafe impl StreamOps for &StreamVtable { + type Payload = T; + + fn new(&mut self) -> u64 { + unsafe { (self.new)() } + } + fn elem_layout(&self) -> Layout { + self.layout + } + fn native_abi_matches_canonical_abi(&self) -> bool { + self.lift.is_none() + } + fn contains_lists(&self) -> bool { + self.dealloc_lists.is_some() + } + unsafe fn lower(&mut self, payload: Self::Payload, dst: *mut u8) { + if let Some(f) = self.lower { + f(payload, dst) + } + } + unsafe fn dealloc_lists(&mut self, dst: *mut u8) { + if let Some(f) = self.dealloc_lists { + f(dst) + } + } + unsafe fn lift(&mut self, dst: *mut u8) -> Self::Payload { + (self.lift.unwrap())(dst) + } + unsafe fn start_write(&mut self, stream: u32, val: *const u8, amt: usize) -> u32 { + (self.start_write)(stream, val, amt) + } + unsafe fn start_read(&mut self, stream: u32, val: *mut u8, amt: usize) -> u32 { + (self.start_read)(stream, val, amt) + } + unsafe fn cancel_read(&mut self, stream: u32) -> u32 { + (self.cancel_read)(stream) + } + unsafe fn cancel_write(&mut self, stream: u32) -> u32 { + (self.cancel_write)(stream) + } + unsafe fn drop_readable(&mut self, stream: u32) { + (self.drop_readable)(stream) + } + unsafe fn drop_writable(&mut self, stream: u32) { + (self.drop_writable)(stream) + } +} + /// Helper function to create a new read/write pair for a component model /// stream. pub unsafe fn stream_new( vtable: &'static StreamVtable, ) -> (StreamWriter, StreamReader) { + unsafe { raw_stream_new(vtable) } +} + +/// Helper function to create a new read/write pair for a component model +/// stream. +pub unsafe fn raw_stream_new(mut ops: O) -> (RawStreamWriter, RawStreamReader) +where + O: StreamOps + Clone, +{ unsafe { - let handles = (vtable.new)(); + let handles = ops.new(); let reader = handles as u32; let writer = (handles >> 32) as u32; rtdebug!("stream.new() = [{writer}, {reader}]"); ( - StreamWriter::new(writer, vtable), - StreamReader::new(reader, vtable), + RawStreamWriter::new(writer, ops.clone()), + RawStreamReader::new(reader, ops), ) } } /// Represents the writable end of a Component Model `stream`. -pub struct StreamWriter { +pub type StreamWriter = RawStreamWriter<&'static StreamVtable>; + +/// Represents the writable end of a Component Model `stream`. +pub struct RawStreamWriter { handle: u32, - vtable: &'static StreamVtable, + ops: O, done: bool, } -impl StreamWriter { +impl RawStreamWriter +where + O: StreamOps, +{ #[doc(hidden)] - pub unsafe fn new(handle: u32, vtable: &'static StreamVtable) -> Self { + pub unsafe fn new(handle: u32, ops: O) -> Self { Self { handle, - vtable, + ops, done: false, } } @@ -135,15 +243,15 @@ impl StreamWriter { /// no values were sent. It may be possible that values were still sent /// despite being cancelled. Cancelling a write and determining what /// happened must be done with [`StreamWrite::cancel`]. - pub fn write(&mut self, values: Vec) -> StreamWrite<'_, T> { - self.write_buf(AbiBuffer::new(values, self.vtable)) + pub fn write(&mut self, values: Vec) -> RawStreamWrite<'_, O> { + self.write_buf(AbiBuffer::new(values, self.ops.clone())) } /// Same as [`StreamWriter::write`], except this takes [`AbiBuffer`] /// instead of `Vec`. - pub fn write_buf(&mut self, values: AbiBuffer) -> StreamWrite<'_, T> { - StreamWrite { - op: WaitableOperation::new(StreamWriteOp(marker::PhantomData), (self, values)), + pub fn write_buf(&mut self, values: AbiBuffer) -> RawStreamWrite<'_, O> { + RawStreamWrite { + op: WaitableOperation::new(StreamWriteOp { writer: self }, values), } } @@ -154,7 +262,7 @@ impl StreamWriter { /// all of `values` provided into this stream. Upon completion the same /// vector will be returned and any remaining elements in the vector were /// not sent because the stream was dropped. - pub async fn write_all(&mut self, values: Vec) -> Vec { + pub async fn write_all(&mut self, values: Vec) -> Vec { // Perform an initial write which converts `values` into `AbiBuffer`. let (mut status, mut buf) = self.write(values).await; @@ -187,7 +295,7 @@ impl StreamWriter { /// If the other end hangs up then the value is returned back as /// `Some(value)`, otherwise `None` is returned indicating the value was /// sent. - pub async fn write_one(&mut self, value: T) -> Option { + pub async fn write_one(&mut self, value: O::Payload) -> Option { // TODO: can probably be a bit more efficient about this and avoid // moving `value` onto the heap in some situations, but that's left as // an optimization for later. @@ -195,7 +303,10 @@ impl StreamWriter { } } -impl fmt::Debug for StreamWriter { +impl fmt::Debug for RawStreamWriter +where + O: StreamOps, +{ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("StreamWriter") .field("handle", &self.handle) @@ -203,21 +314,29 @@ impl fmt::Debug for StreamWriter { } } -impl Drop for StreamWriter { +impl Drop for RawStreamWriter +where + O: StreamOps, +{ fn drop(&mut self) { rtdebug!("stream.drop-writable({})", self.handle); unsafe { - (self.vtable.drop_writable)(self.handle); + self.ops.drop_writable(self.handle); } } } /// Represents a write operation which may be cancelled prior to completion. -pub struct StreamWrite<'a, T: 'static> { - op: WaitableOperation>, +pub type StreamWrite<'a, T> = RawStreamWrite<'a, &'static StreamVtable>; + +/// Represents a write operation which may be cancelled prior to completion. +pub struct RawStreamWrite<'a, O: StreamOps> { + op: WaitableOperation>, } -struct StreamWriteOp<'a, T: 'static>(marker::PhantomData<(&'a mut StreamWriter, T)>); +struct StreamWriteOp<'a, O: StreamOps> { + writer: &'a mut RawStreamWriter, +} /// Result of a [`StreamWriter::write`] or [`StreamReader::read`] operation, /// yielded by the [`StreamWrite`] or [`StreamRead`] futures. @@ -234,42 +353,42 @@ pub enum StreamResult { Cancelled, } -unsafe impl<'a, T> WaitableOp for StreamWriteOp<'a, T> +unsafe impl<'a, O> WaitableOp for StreamWriteOp<'a, O> where - T: 'static, + O: StreamOps, { - type Start = (&'a mut StreamWriter, AbiBuffer); - type InProgress = (&'a mut StreamWriter, AbiBuffer); - type Result = (StreamResult, AbiBuffer); - type Cancel = (StreamResult, AbiBuffer); - - fn start(&mut self, (writer, buf): Self::Start) -> (u32, Self::InProgress) { - if writer.done { - return (DROPPED, (writer, buf)); + type Start = AbiBuffer; + type InProgress = AbiBuffer; + type Result = (StreamResult, AbiBuffer); + type Cancel = (StreamResult, AbiBuffer); + + fn start(&mut self, buf: Self::Start) -> (u32, Self::InProgress) { + if self.writer.done { + return (DROPPED, buf); } let (ptr, len) = buf.abi_ptr_and_len(); // SAFETY: sure hope this is safe, everything in this module and // `AbiBuffer` is trying to make this safe. - let code = unsafe { (writer.vtable.start_write)(writer.handle, ptr, len) }; + let code = unsafe { self.writer.ops.start_write(self.writer.handle, ptr, len) }; rtdebug!( "stream.write({}, {ptr:?}, {len}) = {code:#x}", - writer.handle + self.writer.handle ); - (code, (writer, buf)) + (code, buf) } - fn start_cancelled(&mut self, (_writer, buf): Self::Start) -> Self::Cancel { + fn start_cancelled(&mut self, buf: Self::Start) -> Self::Cancel { (StreamResult::Cancelled, buf) } fn in_progress_update( &mut self, - (writer, mut buf): Self::InProgress, + mut buf: Self::InProgress, code: u32, ) -> Result { match ReturnCode::decode(code) { - ReturnCode::Blocked => Err((writer, buf)), + ReturnCode::Blocked => Err(buf), ReturnCode::Dropped(0) => Ok((StreamResult::Dropped, buf)), ReturnCode::Cancelled(0) => Ok((StreamResult::Cancelled, buf)), code @ (ReturnCode::Completed(amt) @@ -278,22 +397,22 @@ where let amt = amt.try_into().unwrap(); buf.advance(amt); if let ReturnCode::Dropped(_) = code { - writer.done = true; + self.writer.done = true; } Ok((StreamResult::Complete(amt), buf)) } } } - fn in_progress_waitable(&mut self, (writer, _): &Self::InProgress) -> u32 { - writer.handle + fn in_progress_waitable(&mut self, _: &Self::InProgress) -> u32 { + self.writer.handle } - fn in_progress_cancel(&mut self, (writer, _): &mut Self::InProgress) -> u32 { + fn in_progress_cancel(&mut self, _: &mut Self::InProgress) -> u32 { // SAFETY: we're managing `writer` and all the various operational bits, // so this relies on `WaitableOperation` being safe. - let code = unsafe { (writer.vtable.cancel_write)(writer.handle) }; - rtdebug!("stream.cancel-write({}) = {code:#x}", writer.handle); + let code = unsafe { self.writer.ops.cancel_write(self.writer.handle) }; + rtdebug!("stream.cancel-write({}) = {code:#x}", self.writer.handle); code } @@ -302,16 +421,16 @@ where } } -impl Future for StreamWrite<'_, T> { - type Output = (StreamResult, AbiBuffer); +impl Future for RawStreamWrite<'_, O> { + type Output = (StreamResult, AbiBuffer); fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { self.pin_project().poll_complete(cx) } } -impl<'a, T: 'static> StreamWrite<'a, T> { - fn pin_project(self: Pin<&mut Self>) -> Pin<&mut WaitableOperation>> { +impl<'a, O: StreamOps> RawStreamWrite<'a, O> { + fn pin_project(self: Pin<&mut Self>) -> Pin<&mut WaitableOperation>> { // SAFETY: we've chosen that when `Self` is pinned that it translates to // always pinning the inner field, so that's codified here. unsafe { Pin::new_unchecked(&mut self.get_unchecked_mut().op) } @@ -327,19 +446,22 @@ impl<'a, T: 'static> StreamWrite<'a, T> { /// /// Panics if the operation has already been completed via `Future::poll`, /// or if this method is called twice. - pub fn cancel(self: Pin<&mut Self>) -> (StreamResult, AbiBuffer) { + pub fn cancel(self: Pin<&mut Self>) -> (StreamResult, AbiBuffer) { self.pin_project().cancel() } } /// Represents the readable end of a Component Model `stream`. -pub struct StreamReader { +pub type StreamReader = RawStreamReader<&'static StreamVtable>; + +/// Represents the readable end of a Component Model `stream`. +pub struct RawStreamReader { handle: AtomicU32, - vtable: &'static StreamVtable, + ops: O, done: bool, } -impl fmt::Debug for StreamReader { +impl fmt::Debug for StreamReader { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("StreamReader") .field("handle", &self.handle) @@ -347,12 +469,12 @@ impl fmt::Debug for StreamReader { } } -impl StreamReader { +impl RawStreamReader { #[doc(hidden)] - pub fn new(handle: u32, vtable: &'static StreamVtable) -> Self { + pub fn new(handle: u32, ops: O) -> Self { Self { handle: AtomicU32::new(handle), - vtable, + ops, done: false, } } @@ -390,9 +512,9 @@ impl StreamReader { /// futures, but it does not mean that no values were read. To accurately /// determine if values were read the [`StreamRead::cancel`] method must be /// used. - pub fn read(&mut self, buf: Vec) -> StreamRead<'_, T> { - StreamRead { - op: WaitableOperation::new(StreamReadOp(marker::PhantomData), (self, buf)), + pub fn read(&mut self, buf: Vec) -> RawStreamRead<'_, O> { + RawStreamRead { + op: WaitableOperation::new(StreamReadOp { reader: self }, buf), } } @@ -400,7 +522,7 @@ impl StreamReader { /// /// This is a higher-level method than [`StreamReader::read`] in that it /// reads only a single item and does not expose control over cancellation. - pub async fn next(&mut self) -> Option { + pub async fn next(&mut self) -> Option { // TODO: should amortize this allocation and avoid doing it every time. // Or somehow perhaps make this more optimal. let (_result, mut buf) = self.read(Vec::with_capacity(1)).await; @@ -411,7 +533,7 @@ impl StreamReader { /// /// This method will read all remaining items from this stream into a list /// and await the stream to be dropped. - pub async fn collect(mut self) -> Vec { + pub async fn collect(mut self) -> Vec { let mut ret = Vec::new(); loop { // If there's no more spare capacity then reserve room for one item @@ -432,37 +554,39 @@ impl StreamReader { } } -impl Drop for StreamReader { +impl Drop for RawStreamReader { fn drop(&mut self) { let Some(handle) = self.opt_handle() else { return; }; unsafe { rtdebug!("stream.drop-readable({})", handle); - (self.vtable.drop_readable)(handle); + self.ops.drop_readable(handle); } } } /// Represents a read operation which may be cancelled prior to completion. -pub struct StreamRead<'a, T: 'static> { - op: WaitableOperation>, +pub type StreamRead<'a, T> = RawStreamRead<'a, &'static StreamVtable>; + +/// Represents a read operation which may be cancelled prior to completion. +pub struct RawStreamRead<'a, O: StreamOps> { + op: WaitableOperation>, } -struct StreamReadOp<'a, T: 'static>(marker::PhantomData<(&'a mut StreamReader, T)>); +struct StreamReadOp<'a, O: StreamOps> { + reader: &'a mut RawStreamReader, +} -unsafe impl<'a, T> WaitableOp for StreamReadOp<'a, T> -where - T: 'static, -{ - type Start = (&'a mut StreamReader, Vec); - type InProgress = (&'a mut StreamReader, Vec, Option); - type Result = (StreamResult, Vec); - type Cancel = (StreamResult, Vec); - - fn start(&mut self, (reader, mut buf): Self::Start) -> (u32, Self::InProgress) { - if reader.done { - return (DROPPED, (reader, buf, None)); +unsafe impl<'a, O: StreamOps> WaitableOp for StreamReadOp<'a, O> { + type Start = Vec; + type InProgress = (Vec, Option); + type Result = (StreamResult, Vec); + type Cancel = (StreamResult, Vec); + + fn start(&mut self, mut buf: Self::Start) -> (u32, Self::InProgress) { + if self.reader.done { + return (DROPPED, (buf, None)); } let cap = buf.spare_capacity_mut(); @@ -471,39 +595,42 @@ where // If `T` requires a lifting operation, then allocate a slab of memory // which will store the canonical ABI read. Otherwise we can use the // raw capacity in `buf` itself. - if reader.vtable.lift.is_some() { - let layout = Layout::from_size_align( - reader.vtable.layout.size() * cap.len(), - reader.vtable.layout.align(), - ) - .unwrap(); - (ptr, cleanup) = Cleanup::new(layout); - } else { + if self.reader.ops.native_abi_matches_canonical_abi() { ptr = cap.as_mut_ptr().cast(); cleanup = None; + } else { + let elem_layout = self.reader.ops.elem_layout(); + let layout = + Layout::from_size_align(elem_layout.size() * cap.len(), elem_layout.align()) + .unwrap(); + (ptr, cleanup) = Cleanup::new(layout); } // SAFETY: `ptr` is either in `buf` or in `cleanup`, both of which will // persist with this async operation itself. - let code = unsafe { (reader.vtable.start_read)(reader.handle(), ptr, cap.len()) }; + let code = unsafe { + self.reader + .ops + .start_read(self.reader.handle(), ptr, cap.len()) + }; rtdebug!( "stream.read({}, {ptr:?}, {}) = {code:#x}", - reader.handle(), + self.reader.handle(), cap.len() ); - (code, (reader, buf, cleanup)) + (code, (buf, cleanup)) } - fn start_cancelled(&mut self, (_, buf): Self::Start) -> Self::Cancel { + fn start_cancelled(&mut self, buf: Self::Start) -> Self::Cancel { (StreamResult::Cancelled, buf) } fn in_progress_update( &mut self, - (reader, mut buf, cleanup): Self::InProgress, + (mut buf, cleanup): Self::InProgress, code: u32, ) -> Result { match ReturnCode::decode(code) { - ReturnCode::Blocked => Err((reader, buf, cleanup)), + ReturnCode::Blocked => Err((buf, cleanup)), // Note that the `cleanup`, if any, is discarded here. ReturnCode::Dropped(0) => Ok((StreamResult::Dropped, buf)), @@ -522,48 +649,48 @@ where let cur_len = buf.len(); assert!(amt <= buf.capacity() - cur_len); - match reader.vtable.lift { + if self.reader.ops.native_abi_matches_canonical_abi() { + // If no `lift` was necessary, then the results of this operation + // were read directly into `buf`, so just update its length now that + // values have been initialized. + unsafe { + buf.set_len(cur_len + amt); + } + } else { // With a `lift` operation this now requires reading `amt` items // from `cleanup` and pushing them into `buf`. - Some(lift) => { - let mut ptr = cleanup - .as_ref() - .map(|c| c.ptr.as_ptr()) - .unwrap_or(ptr::null_mut()); - for _ in 0..amt { - unsafe { - buf.push(lift(ptr)); - ptr = ptr.add(reader.vtable.layout.size()); - } + let mut ptr = cleanup + .as_ref() + .map(|c| c.ptr.as_ptr()) + .unwrap_or(ptr::null_mut()); + for _ in 0..amt { + unsafe { + buf.push(self.reader.ops.lift(ptr)); + ptr = ptr.add(self.reader.ops.elem_layout().size()); } } - - // If no `lift` was necessary, then the results of this operation - // were read directly into `buf`, so just update its length now that - // values have been initialized. - None => unsafe { buf.set_len(cur_len + amt) }, } // Intentionally dispose of `cleanup` here as, if it was used, all // allocations have been read from it and appended to `buf`. drop(cleanup); if let ReturnCode::Dropped(_) = code { - reader.done = true; + self.reader.done = true; } Ok((StreamResult::Complete(amt), buf)) } } } - fn in_progress_waitable(&mut self, (reader, ..): &Self::InProgress) -> u32 { - reader.handle() + fn in_progress_waitable(&mut self, _: &Self::InProgress) -> u32 { + self.reader.handle() } - fn in_progress_cancel(&mut self, (reader, ..): &mut Self::InProgress) -> u32 { + fn in_progress_cancel(&mut self, _: &mut Self::InProgress) -> u32 { // SAFETY: we're managing `reader` and all the various operational bits, // so this relies on `WaitableOperation` being safe. - let code = unsafe { (reader.vtable.cancel_read)(reader.handle()) }; - rtdebug!("stream.cancel-read({}) = {code:#x}", reader.handle()); + let code = unsafe { self.reader.ops.cancel_read(self.reader.handle()) }; + rtdebug!("stream.cancel-read({}) = {code:#x}", self.reader.handle()); code } @@ -572,16 +699,19 @@ where } } -impl Future for StreamRead<'_, T> { - type Output = (StreamResult, Vec); +impl Future for RawStreamRead<'_, O> { + type Output = (StreamResult, Vec); fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { self.pin_project().poll_complete(cx) } } -impl<'a, T> StreamRead<'a, T> { - fn pin_project(self: Pin<&mut Self>) -> Pin<&mut WaitableOperation>> { +impl<'a, O> RawStreamRead<'a, O> +where + O: StreamOps, +{ + fn pin_project(self: Pin<&mut Self>) -> Pin<&mut WaitableOperation>> { // SAFETY: we've chosen that when `Self` is pinned that it translates to // always pinning the inner field, so that's codified here. unsafe { Pin::new_unchecked(&mut self.get_unchecked_mut().op) } @@ -600,7 +730,7 @@ impl<'a, T> StreamRead<'a, T> { /// /// Panics if the operation has already been completed via `Future::poll`, /// or if this method is called twice. - pub fn cancel(self: Pin<&mut Self>) -> (StreamResult, Vec) { + pub fn cancel(self: Pin<&mut Self>) -> (StreamResult, Vec) { self.pin_project().cancel() } } From f4bbbe71fa9bc36c6ba5f68b7e20ab4997173767 Mon Sep 17 00:00:00 2001 From: Alex Crichton Date: Mon, 24 Nov 2025 21:11:13 -0800 Subject: [PATCH 2/3] Fix tests --- crates/guest-rust/src/rt/async_support/abi_buffer.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/crates/guest-rust/src/rt/async_support/abi_buffer.rs b/crates/guest-rust/src/rt/async_support/abi_buffer.rs index f5c55d0c1..d9a425e8d 100644 --- a/crates/guest-rust/src/rt/async_support/abi_buffer.rs +++ b/crates/guest-rust/src/rt/async_support/abi_buffer.rs @@ -203,6 +203,7 @@ where #[cfg(test)] mod tests { use super::*; + use crate::rt::async_support::StreamVtable; use std::sync::atomic::{AtomicUsize, Ordering::Relaxed}; use std::vec; From aee01d785a5bfea24562f91d4c6f49fc7d57e130 Mon Sep 17 00:00:00 2001 From: Alex Crichton Date: Mon, 24 Nov 2025 21:26:57 -0800 Subject: [PATCH 3/3] Fix debug impl --- crates/guest-rust/src/rt/async_support/stream_support.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/guest-rust/src/rt/async_support/stream_support.rs b/crates/guest-rust/src/rt/async_support/stream_support.rs index f46803bbb..d2d4e98d6 100644 --- a/crates/guest-rust/src/rt/async_support/stream_support.rs +++ b/crates/guest-rust/src/rt/async_support/stream_support.rs @@ -461,7 +461,7 @@ pub struct RawStreamReader { done: bool, } -impl fmt::Debug for StreamReader { +impl fmt::Debug for RawStreamReader { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("StreamReader") .field("handle", &self.handle)