diff --git a/src/lib.rs b/src/lib.rs index c155fec..9bf4f04 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -58,6 +58,8 @@ pub use into_iter::IntoIter; mod drain; pub use drain::Drain; +mod retain; + #[cfg(feature = "serde")] mod serde_impl; #[cfg(all(test, feature = "serde"))] @@ -744,6 +746,42 @@ impl WordVec { } } + /// Retains only the elements specified by the predicate. + /// + /// In other words, remove all elements `e` for which `should_retain(&mut e)` returns `false`. + /// This method operates in place, visiting each element exactly once in the original order, + /// and preserves the order of the retained elements. + pub fn retain(&mut self, mut should_retain: F) + where + F: FnMut(&mut T) -> bool, + { + let mut retain = retain::Retain::new(self); + loop { + if let retain::NextResult::Exhausted = retain.next(&mut should_retain) { + break; + } + } + } + + /// Creates an iterator which uses a closure to determine if an element should be removed. + /// + /// If the closure returns `true`, the element is removed from the vector + /// and yielded. If the closure returns `false`, or panics, the element + /// remains in the vector and will not be yielded. + /// + /// If the returned iterator is not exhausted, e.g. because it is dropped without iterating + /// or the iteration short-circuits, then the remaining elements will be retained. + /// Use [`retain`] with a negated predicate if you do not need the returned iterator. + /// + /// [`retain`]: Self::retain + #[doc(alias = "drain_filter")] + pub fn extract_if( + &mut self, + should_remove: impl FnMut(&mut T) -> bool, + ) -> impl Iterator { + retain::ExtractIf { retain: retain::Retain::new(self), should_remove } + } + /// Resizes the vector so that its length is equal to `len`.
/// /// If `len` is greater than the current length, diff --git a/src/retain.rs b/src/retain.rs new file mode 100644 index 0000000..d54cf19 --- /dev/null +++ b/src/retain.rs @@ -0,0 +1,116 @@ +use core::mem::{self, MaybeUninit}; + +pub(super) struct Retain<'a, T> { + set_len: super::LengthSetter<'a>, + init_slice: &'a mut [MaybeUninit], + read_len: usize, + written_len: usize, +} + +impl<'a, T> Retain<'a, T> { + pub(super) fn new(vec: &'a mut super::WordVec) -> Self { + let (capacity_slice, old_len, mut set_len) = vec.as_uninit_slice_with_length_setter(); + + // SAFETY: length 0 is always safe + unsafe { set_len.set_len(0) }; + + Self { set_len, init_slice: &mut capacity_slice[..old_len], read_len: 0, written_len: 0 } + } +} + +impl Drop for Retain<'_, T> { + fn drop(&mut self) { + // Shift all unvisited elements (init_slice[read_len..]) forward to start at written_len. SAFETY for the copy below: written_len <= read_len <= data_len, so both ranges stay inside init_slice, and `ptr::copy` permits overlapping ranges. + let data_len = self.init_slice.len(); + let data_ptr = self.init_slice.as_mut_ptr(); + let moved_len = data_len - self.read_len; + unsafe { + core::ptr::copy(data_ptr.add(self.read_len), data_ptr.add(self.written_len), moved_len); + } + + // SAFETY: init_slice[..written_len] holds the retained items, immediately followed by the moved_len unvisited items copied above + unsafe { + self.set_len.set_len(self.written_len + moved_len); + } + } +} + +impl Retain<'_, T> { + pub(super) fn next(&mut self, should_retain: impl FnOnce(&mut T) -> bool) -> NextResult { + let Some(item_uninit) = self.init_slice.get_mut(self.read_len) else { + return NextResult::Exhausted; + }; + + // SAFETY: init_slice[read_len..] are always initialized + let item_mut = unsafe { item_uninit.assume_init_mut() }; + + // If `should_retain` panics, `item_mut` is no longer referenced, + // so the state of this struct is just as if the current `next` call never happened. + // Thus the destructor will work as expected. + let retain = should_retain(item_mut); + + if retain { + let src_index = self.read_len; + let dest_index = self.written_len; + + // init_slice[read_len] is moved to init_slice[written_len] after this step.
+ // If read_len == written_len, this just retains the item in place + // and has no safety implications. + // If read_len != written_len, by contract read_len > written_len, + // so init_slice[written_len..read_len] is uninitialized, + // and after this operation, init_slice[written_len] becomes initialized + // while init_slice[read_len] becomes uninitialized. + self.read_len += 1; + self.written_len += 1; + + if src_index != dest_index { + unsafe { + // SAFETY: src_index != dest_index checked in condition, and dest_index < src_index < init_slice.len() + let [src, dest] = + self.init_slice.get_disjoint_unchecked_mut([src_index, dest_index]); + dest.write(mem::replace(src, MaybeUninit::uninit()).assume_init()); + } + } + // If src_index == dest_index, this move would be a no-op + + NextResult::Retained + } else { + // this never overflows because read_len < init_slice.len() <= usize::MAX + self.read_len += 1; + + // SAFETY: item can be safely moved out as an initialized value. + let item = mem::replace(item_uninit, MaybeUninit::uninit()); + let item = unsafe { item.assume_init() }; + + NextResult::Removed(item) + } + } +} + +pub(super) enum NextResult { + Exhausted, + Retained, + Removed(T), +} + +pub(super) struct ExtractIf<'a, T, F> { + pub(super) retain: Retain<'a, T>, + pub(super) should_remove: F, +} + +impl Iterator for ExtractIf<'_, T, F> +where + F: FnMut(&mut T) -> bool, +{ + type Item = T; + + fn next(&mut self) -> Option { + loop { + return match self.retain.next(|elem| !(self.should_remove)(elem)) { + NextResult::Exhausted => None, + NextResult::Retained => continue, + NextResult::Removed(item) => Some(item), + }; + } + } +} diff --git a/src/tests.rs b/src/tests.rs index 61f08cd..6b86f16 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -1,6 +1,8 @@ use alloc::string::{String, ToString}; use alloc::vec::Vec; use core::cell::Cell; +use core::mem; +use core::panic::AssertUnwindSafe; use crate::WordVec; @@ -539,6 +541,161 @@ fn test_drain_long_long_short_early_drop_back() { assert_eq!(wv.as_slice(),
&[0, 1, 2, 6, 7]); } +fn test_retain_with( + initial_len: usize, + mut predicate: impl FnMut(&str) -> bool, + expect_retain_drops: usize, + expect_after_retain: &[&str], + retain_fn: impl FnOnce(&mut WordVec, N>, &mut dyn FnMut(&mut AssertDrop<'_>) -> bool), +) { + let counter = &Cell::new(0); + let mut wv = (0..initial_len) + .map(|i| AssertDrop { string: i.to_string(), counter }) + .collect::>(); + retain_fn(&mut wv, &mut |d| predicate(d.string.as_str())); + assert_eq!(counter.get(), expect_retain_drops); + assert_eq!(wv.iter().map(|d| d.string.as_str()).collect::>(), expect_after_retain); + drop(wv); + assert_eq!(counter.get(), initial_len); +} + +fn test_retain( + initial_len: usize, + mut predicate: impl FnMut(&str) -> bool, + expect_retain_drops: usize, + expect_after_retain: &[&str], +) { + test_retain_with::( + initial_len, + &mut predicate, + expect_retain_drops, + expect_after_retain, + |wv, predicate| wv.retain(|d| predicate(d)), + ); +} + +fn test_extract_if( + initial_len: usize, + mut predicate: impl FnMut(&str) -> bool, + expect_retain_drops: usize, + expect_after_retain: &[&str], +) { + test_retain_with::( + initial_len, + &mut predicate, + expect_retain_drops, + expect_after_retain, + |wv, predicate| wv.extract_if(|d| !predicate(d)).for_each(drop), + ); +} + +#[test] +fn test_retain_empty() { test_retain::<4>(0, |_| unreachable!(), 0, &[]); } + +#[test] +fn test_retain_everything() { test_retain::<4>(3, |_| true, 0, &["0", "1", "2"]); } + +#[test] +fn test_retain_nothing() { test_retain::<4>(3, |_| false, 3, &[]); } + +#[test] +fn test_retain_tft() { + let mut retain_seq = [true, false, true].into_iter(); + test_retain::<4>(3, |_| retain_seq.next().unwrap(), 1, &["0", "2"]); +} + +#[test] +fn test_retain_ftf() { + let mut retain_seq = [false, true, false].into_iter(); + test_retain::<4>(3, |_| retain_seq.next().unwrap(), 2, &["1"]); +} + +fn test_retain_panic(retain_prev: bool, expect_retain_drops: usize, expect_after_retain: &[&str]) { + extern
crate std; + + let counter = &Cell::new(0); + let mut wv = + (0..3).map(|i| AssertDrop { string: i.to_string(), counter }).collect::>(); + + _ = std::panic::catch_unwind({ + let mut wv = AssertUnwindSafe(&mut wv); + move || { + let mut next_index = 0; + wv.retain(|_| { + let index = next_index; + next_index += 1; + + #[expect(clippy::manual_assert, reason = "clarity")] + if index == 1 { + panic!("intentional panic"); + } + + retain_prev + }); + } + }); + + assert_eq!(counter.get(), expect_retain_drops); + assert_eq!(wv.iter().map(|d| d.string.as_str()).collect::>(), expect_after_retain); + drop(wv); + assert_eq!(counter.get(), 3); +} + +#[test] +fn test_retain_shifted_panic() { test_retain_panic(false, 1, &["1", "2"]); } + +#[test] +fn test_retain_unshifted_panic() { test_retain_panic(true, 0, &["0", "1", "2"]); } + +#[test] +fn test_extract_if_empty() { test_extract_if::<4>(0, |_| unreachable!(), 0, &[]); } + +#[test] +fn test_extract_if_everything() { test_extract_if::<4>(3, |_| true, 0, &["0", "1", "2"]); } + +#[test] +fn test_extract_if_nothing() { test_extract_if::<4>(3, |_| false, 3, &[]); } + +#[test] +fn test_extract_if_tft() { + let mut extract_if_seq = [true, false, true].into_iter(); + test_extract_if::<4>(3, |_| extract_if_seq.next().unwrap(), 1, &["0", "2"]); +} + +#[test] +fn test_extract_if_ftf() { + let mut extract_if_seq = [false, true, false].into_iter(); + test_extract_if::<4>(3, |_| extract_if_seq.next().unwrap(), 2, &["1"]); +} + +fn test_extract_if_drop( + retain_first: bool, + expect_extract_result: &str, + expect_retain_drops: usize, + expect_after_retain: &[&str], +) { + let counter = &Cell::new(0); + let mut wv = + (0..3).map(|i| AssertDrop { string: i.to_string(), counter }).collect::>(); + + { + let mut retain = retain_first; + let mut iter = wv.extract_if(|_| !mem::replace(&mut retain, false)); + assert_eq!(iter.next().unwrap().string, expect_extract_result); + } + + assert_eq!(counter.get(), expect_retain_drops); + 
assert_eq!(wv.iter().map(|d| d.string.as_str()).collect::>(), expect_after_retain); + drop(wv); + assert_eq!(counter.get(), 3); +} + +#[test] +fn test_extract_if_shifted_drop() { test_extract_if_drop(true, "1", 1, &["0", "2"]); } + +#[test] +fn test_extract_if_unshifted_drop() { test_extract_if_drop(false, "0", 1, &["1", "2"]); } + fn assert_resize( initial_len: usize, resize_len: usize,