Skip to content

Commit 2a6d845

Browse files
authored
Refactor streams with a StreamOps trait (#1435)
* Refactor streams with a `StreamOps` trait This is similar to #1409 except applied to streams now as well as futures. A new `StreamOps` trait defines all the operations for streams to help make it easier to hook up low-level functionality into higher-level bindings. * Fix tests * Fix debug impl
1 parent 355ab39 commit 2a6d845

File tree

2 files changed

+283
-150
lines changed

2 files changed

+283
-150
lines changed

crates/guest-rust/src/rt/async_support/abi_buffer.rs

Lines changed: 35 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use crate::rt::async_support::StreamVtable;
1+
use crate::rt::async_support::StreamOps;
22
use crate::rt::Cleanup;
33
use std::alloc::Layout;
44
use std::mem::{self, MaybeUninit};
@@ -16,25 +16,23 @@ use std::vec::Vec;
1616
///
1717
/// This value is created through the [`StreamWrite`](super::StreamWrite)
1818
/// future's return value.
19-
pub struct AbiBuffer<T: 'static> {
20-
rust_storage: Vec<MaybeUninit<T>>,
21-
vtable: &'static StreamVtable<T>,
19+
pub struct AbiBuffer<O: StreamOps> {
20+
rust_storage: Vec<MaybeUninit<O::Payload>>,
21+
ops: O,
2222
alloc: Option<Cleanup>,
2323
cursor: usize,
2424
}
2525

26-
impl<T: 'static> AbiBuffer<T> {
27-
pub(crate) fn new(mut vec: Vec<T>, vtable: &'static StreamVtable<T>) -> AbiBuffer<T> {
28-
assert_eq!(vtable.lower.is_some(), vtable.lift.is_some());
29-
26+
impl<O: StreamOps> AbiBuffer<O> {
27+
pub(crate) fn new(mut vec: Vec<O::Payload>, mut ops: O) -> AbiBuffer<O> {
3028
// SAFETY: We're converting `Vec<T>` to `Vec<MaybeUninit<T>>`, which
3129
// should be safe.
3230
let rust_storage = unsafe {
3331
let ptr = vec.as_mut_ptr();
3432
let len = vec.len();
3533
let cap = vec.capacity();
3634
mem::forget(vec);
37-
Vec::<MaybeUninit<T>>::from_raw_parts(ptr.cast(), len, cap)
35+
Vec::<MaybeUninit<O::Payload>>::from_raw_parts(ptr.cast(), len, cap)
3836
};
3937

4038
// If `lower` is provided then the canonical ABI format is different
@@ -43,31 +41,32 @@ impl<T: 'static> AbiBuffer<T> {
4341
// Note that this is probably pretty inefficient for "big" use cases
4442
// but it's hoped that "big" use cases are using `u8` and therefore
4543
// skip this entirely.
46-
let alloc = vtable.lower.and_then(|lower| {
44+
let alloc = if ops.native_abi_matches_canonical_abi() {
45+
None
46+
} else {
47+
let elem_layout = ops.elem_layout();
4748
let layout = Layout::from_size_align(
48-
vtable.layout.size() * rust_storage.len(),
49-
vtable.layout.align(),
49+
elem_layout.size() * rust_storage.len(),
50+
elem_layout.align(),
5051
)
5152
.unwrap();
5253
let (mut ptr, cleanup) = Cleanup::new(layout);
53-
let cleanup = cleanup?;
5454
// SAFETY: All items in `rust_storage` are already initialized so
5555
// it should be safe to read them and move ownership into the
5656
// canonical ABI format.
5757
unsafe {
5858
for item in rust_storage.iter() {
5959
let item = item.assume_init_read();
60-
lower(item, ptr);
61-
ptr = ptr.add(vtable.layout.size());
60+
ops.lower(item, ptr);
61+
ptr = ptr.add(elem_layout.size());
6262
}
6363
}
64-
65-
Some(cleanup)
66-
});
64+
cleanup
65+
};
6766
AbiBuffer {
6867
rust_storage,
6968
alloc,
70-
vtable,
69+
ops,
7170
cursor: 0,
7271
}
7372
}
@@ -78,7 +77,7 @@ impl<T: 'static> AbiBuffer<T> {
7877
// If there's no `lower` operation then it means that `T`'s layout is
7978
// the same in the canonical ABI so it can be used as-is. In this
8079
// situation the list would have been un-tampered with above.
81-
if self.vtable.lower.is_none() {
80+
if self.ops.native_abi_matches_canonical_abi() {
8281
// SAFETY: this should be in-bounds, so it should be safe.
8382
let ptr = unsafe { self.rust_storage.as_ptr().add(self.cursor).cast() };
8483
let len = self.rust_storage.len() - self.cursor;
@@ -94,7 +93,7 @@ impl<T: 'static> AbiBuffer<T> {
9493
.unwrap_or(ptr::null_mut());
9594
(
9695
// SAFETY: this should be in-bounds, so it should be safe.
97-
unsafe { ptr.add(self.cursor * self.vtable.layout.size()) },
96+
unsafe { ptr.add(self.cursor * self.ops.elem_layout().size()) },
9897
self.rust_storage.len() - self.cursor,
9998
)
10099
}
@@ -111,7 +110,7 @@ impl<T: 'static> AbiBuffer<T> {
111110
/// Also note that this can be an expensive operation if a partial write
112111
/// occurred as this will involve shifting items from the end of the vector
113112
/// to the start of the vector.
114-
pub fn into_vec(mut self) -> Vec<T> {
113+
pub fn into_vec(mut self) -> Vec<O::Payload> {
115114
self.take_vec()
116115
}
117116

@@ -127,25 +126,25 @@ impl<T: 'static> AbiBuffer<T> {
127126
/// necessary for the starting `amt` items in this list.
128127
pub(crate) fn advance(&mut self, amt: usize) {
129128
assert!(amt + self.cursor <= self.rust_storage.len());
130-
let Some(dealloc_lists) = self.vtable.dealloc_lists else {
129+
if !self.ops.contains_lists() {
131130
self.cursor += amt;
132131
return;
133-
};
132+
}
134133
let (mut ptr, len) = self.abi_ptr_and_len();
135134
assert!(amt <= len);
136135
for _ in 0..amt {
137136
// SAFETY: we're managing the pointer passed to `dealloc_lists` and
138137
// it was initialized with a `lower`, and then the pointer
139138
// arithmetic should all be in-bounds.
140139
unsafe {
141-
dealloc_lists(ptr.cast_mut());
142-
ptr = ptr.add(self.vtable.layout.size());
140+
self.ops.dealloc_lists(ptr.cast_mut());
141+
ptr = ptr.add(self.ops.elem_layout().size());
143142
}
144143
}
145144
self.cursor += amt;
146145
}
147146

148-
fn take_vec(&mut self) -> Vec<T> {
147+
fn take_vec(&mut self) -> Vec<O::Payload> {
149148
// First, if necessary, convert remaining values within `self.alloc`
150149
// back into `self.rust_storage`. This is necessary when a lift
151150
// operation is available meaning that the representation of `T` is
@@ -155,15 +154,15 @@ impl<T: 'static> AbiBuffer<T> {
155154
// `AbiBuffer` was created it moved ownership of all values from the
156155
// original vector into the `alloc` value. This is the reverse
157156
// operation, moving all the values back into the vector.
158-
if let Some(lift) = self.vtable.lift {
157+
if !self.ops.native_abi_matches_canonical_abi() {
159158
let (mut ptr, mut len) = self.abi_ptr_and_len();
160159
// SAFETY: this should be safe as `lift` is operating on values that
161160
// were initialized with a previous `lower`, and the pointer
162161
// arithmetic here should all be in-bounds.
163162
unsafe {
164163
for dst in self.rust_storage[self.cursor..].iter_mut() {
165-
dst.write(lift(ptr.cast_mut()));
166-
ptr = ptr.add(self.vtable.layout.size());
164+
dst.write(self.ops.lift(ptr.cast_mut()));
165+
ptr = ptr.add(self.ops.elem_layout().size());
167166
len -= 1;
168167
}
169168
assert_eq!(len, 0);
@@ -187,12 +186,15 @@ impl<T: 'static> AbiBuffer<T> {
187186
let len = storage.len();
188187
let cap = storage.capacity();
189188
mem::forget(storage);
190-
Vec::<T>::from_raw_parts(ptr.cast(), len, cap)
189+
Vec::<O::Payload>::from_raw_parts(ptr.cast(), len, cap)
191190
}
192191
}
193192
}
194193

195-
impl<T> Drop for AbiBuffer<T> {
194+
impl<O> Drop for AbiBuffer<O>
195+
where
196+
O: StreamOps,
197+
{
196198
fn drop(&mut self) {
197199
let _ = self.take_vec();
198200
}
@@ -201,6 +203,7 @@ impl<T> Drop for AbiBuffer<T> {
201203
#[cfg(test)]
202204
mod tests {
203205
use super::*;
206+
use crate::rt::async_support::StreamVtable;
204207
use std::sync::atomic::{AtomicUsize, Ordering::Relaxed};
205208
use std::vec;
206209

0 commit comments

Comments
 (0)