Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 35 additions & 32 deletions crates/guest-rust/src/rt/async_support/abi_buffer.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::rt::async_support::StreamVtable;
use crate::rt::async_support::StreamOps;
use crate::rt::Cleanup;
use std::alloc::Layout;
use std::mem::{self, MaybeUninit};
Expand All @@ -16,25 +16,23 @@ use std::vec::Vec;
///
/// This value is created through the [`StreamWrite`](super::StreamWrite)
/// future's return value.
pub struct AbiBuffer<T: 'static> {
rust_storage: Vec<MaybeUninit<T>>,
vtable: &'static StreamVtable<T>,
pub struct AbiBuffer<O: StreamOps> {
rust_storage: Vec<MaybeUninit<O::Payload>>,
ops: O,
alloc: Option<Cleanup>,
cursor: usize,
}

impl<T: 'static> AbiBuffer<T> {
pub(crate) fn new(mut vec: Vec<T>, vtable: &'static StreamVtable<T>) -> AbiBuffer<T> {
assert_eq!(vtable.lower.is_some(), vtable.lift.is_some());

impl<O: StreamOps> AbiBuffer<O> {
pub(crate) fn new(mut vec: Vec<O::Payload>, mut ops: O) -> AbiBuffer<O> {
// SAFETY: We're converting `Vec<T>` to `Vec<MaybeUninit<T>>`, which
// should be safe.
let rust_storage = unsafe {
let ptr = vec.as_mut_ptr();
let len = vec.len();
let cap = vec.capacity();
mem::forget(vec);
Vec::<MaybeUninit<T>>::from_raw_parts(ptr.cast(), len, cap)
Vec::<MaybeUninit<O::Payload>>::from_raw_parts(ptr.cast(), len, cap)
};

// If `lower` is provided then the canonical ABI format is different
Expand All @@ -43,31 +41,32 @@ impl<T: 'static> AbiBuffer<T> {
// Note that this is probably pretty inefficient for "big" use cases
// but it's hoped that "big" use cases are using `u8` and therefore
// skip this entirely.
let alloc = vtable.lower.and_then(|lower| {
let alloc = if ops.native_abi_matches_canonical_abi() {
None
} else {
let elem_layout = ops.elem_layout();
let layout = Layout::from_size_align(
vtable.layout.size() * rust_storage.len(),
vtable.layout.align(),
elem_layout.size() * rust_storage.len(),
elem_layout.align(),
)
.unwrap();
let (mut ptr, cleanup) = Cleanup::new(layout);
let cleanup = cleanup?;
// SAFETY: All items in `rust_storage` are already initialized so
// it should be safe to read them and move ownership into the
// canonical ABI format.
unsafe {
for item in rust_storage.iter() {
let item = item.assume_init_read();
lower(item, ptr);
ptr = ptr.add(vtable.layout.size());
ops.lower(item, ptr);
ptr = ptr.add(elem_layout.size());
}
}

Some(cleanup)
});
cleanup
};
AbiBuffer {
rust_storage,
alloc,
vtable,
ops,
cursor: 0,
}
}
Expand All @@ -78,7 +77,7 @@ impl<T: 'static> AbiBuffer<T> {
// If there's no `lower` operation then it means that `T`'s layout is
// the same in the canonical ABI so it can be used as-is. In this
// situation the list would have been un-tampered with above.
if self.vtable.lower.is_none() {
if self.ops.native_abi_matches_canonical_abi() {
// SAFETY: this should be in-bounds, so it should be safe.
let ptr = unsafe { self.rust_storage.as_ptr().add(self.cursor).cast() };
let len = self.rust_storage.len() - self.cursor;
Expand All @@ -94,7 +93,7 @@ impl<T: 'static> AbiBuffer<T> {
.unwrap_or(ptr::null_mut());
(
// SAFETY: this should be in-bounds, so it should be safe.
unsafe { ptr.add(self.cursor * self.vtable.layout.size()) },
unsafe { ptr.add(self.cursor * self.ops.elem_layout().size()) },
self.rust_storage.len() - self.cursor,
)
}
Expand All @@ -111,7 +110,7 @@ impl<T: 'static> AbiBuffer<T> {
/// Also note that this can be an expensive operation if a partial write
/// occurred as this will involve shifting items from the end of the vector
/// to the start of the vector.
pub fn into_vec(mut self) -> Vec<T> {
pub fn into_vec(mut self) -> Vec<O::Payload> {
self.take_vec()
}

Expand All @@ -127,25 +126,25 @@ impl<T: 'static> AbiBuffer<T> {
/// necessary for the starting `amt` items in this list.
pub(crate) fn advance(&mut self, amt: usize) {
assert!(amt + self.cursor <= self.rust_storage.len());
let Some(dealloc_lists) = self.vtable.dealloc_lists else {
if !self.ops.contains_lists() {
self.cursor += amt;
return;
};
}
let (mut ptr, len) = self.abi_ptr_and_len();
assert!(amt <= len);
for _ in 0..amt {
// SAFETY: we're managing the pointer passed to `dealloc_lists` and
// it was initialized with a `lower`, and then the pointer
// arithmetic should all be in-bounds.
unsafe {
dealloc_lists(ptr.cast_mut());
ptr = ptr.add(self.vtable.layout.size());
self.ops.dealloc_lists(ptr.cast_mut());
ptr = ptr.add(self.ops.elem_layout().size());
}
}
self.cursor += amt;
}

fn take_vec(&mut self) -> Vec<T> {
fn take_vec(&mut self) -> Vec<O::Payload> {
// First, if necessary, convert remaining values within `self.alloc`
// back into `self.rust_storage`. This is necessary when a lift
// operation is available meaning that the representation of `T` is
Expand All @@ -155,15 +154,15 @@ impl<T: 'static> AbiBuffer<T> {
// `AbiBuffer` was created it moved ownership of all values from the
// original vector into the `alloc` value. This is the reverse
// operation, moving all the values back into the vector.
if let Some(lift) = self.vtable.lift {
if !self.ops.native_abi_matches_canonical_abi() {
let (mut ptr, mut len) = self.abi_ptr_and_len();
// SAFETY: this should be safe as `lift` is operating on values that
// were initialized with a previous `lower`, and the pointer
// arithmetic here should all be in-bounds.
unsafe {
for dst in self.rust_storage[self.cursor..].iter_mut() {
dst.write(lift(ptr.cast_mut()));
ptr = ptr.add(self.vtable.layout.size());
dst.write(self.ops.lift(ptr.cast_mut()));
ptr = ptr.add(self.ops.elem_layout().size());
len -= 1;
}
assert_eq!(len, 0);
Expand All @@ -187,12 +186,15 @@ impl<T: 'static> AbiBuffer<T> {
let len = storage.len();
let cap = storage.capacity();
mem::forget(storage);
Vec::<T>::from_raw_parts(ptr.cast(), len, cap)
Vec::<O::Payload>::from_raw_parts(ptr.cast(), len, cap)
}
}
}

impl<T> Drop for AbiBuffer<T> {
impl<O> Drop for AbiBuffer<O>
where
O: StreamOps,
{
fn drop(&mut self) {
let _ = self.take_vec();
}
Expand All @@ -201,6 +203,7 @@ impl<T> Drop for AbiBuffer<T> {
#[cfg(test)]
mod tests {
use super::*;
use crate::rt::async_support::StreamVtable;
use std::sync::atomic::{AtomicUsize, Ordering::Relaxed};
use std::vec;

Expand Down
Loading