diff --git a/rust/ruby-prism/src/lib.rs b/rust/ruby-prism/src/lib.rs index af430e1495..3561761cc4 100644 --- a/rust/ruby-prism/src/lib.rs +++ b/rust/ruby-prism/src/lib.rs @@ -139,6 +139,11 @@ pub struct NodeList<'pr> { } impl<'pr> NodeList<'pr> { + unsafe fn at(&self, index: usize) -> Node<'pr> { + let node: *mut pm_node_t = *(self.pointer.as_ref().nodes.add(index)); + Node::new(self.parser, node) + } + /// Returns an iterator over the nodes. #[must_use] pub const fn iter(&self) -> NodeListIter<'pr> { @@ -161,6 +166,26 @@ impl<'pr> NodeList<'pr> { pub const fn is_empty(&self) -> bool { self.len() == 0 } + + /// Returns the first element of the list, or `None` if it is empty. + #[must_use] + pub fn first(&self) -> Option> { + if self.is_empty() { + None + } else { + Some(unsafe { self.at(0) }) + } + } + + /// Returns the last element of the list, or `None` if it is empty. + #[must_use] + pub fn last(&self) -> Option> { + if self.is_empty() { + None + } else { + Some(unsafe { self.at(self.len() - 1) }) + } + } } impl<'pr> IntoIterator for &NodeList<'pr> { @@ -245,6 +270,11 @@ pub struct ConstantList<'pr> { } impl<'pr> ConstantList<'pr> { + const unsafe fn at(&self, index: usize) -> ConstantId<'pr> { + let constant_id: pm_constant_id_t = *(self.pointer.as_ref().ids.add(index)); + ConstantId::new(self.parser, constant_id) + } + /// Returns an iterator over the constants in the list. #[must_use] pub const fn iter(&self) -> ConstantListIter<'pr> { @@ -267,6 +297,26 @@ impl<'pr> ConstantList<'pr> { pub const fn is_empty(&self) -> bool { self.len() == 0 } + + /// Returns the first element of the list, or `None` if it is empty. + #[must_use] + pub const fn first(&self) -> Option> { + if self.is_empty() { + None + } else { + Some(unsafe { self.at(0) }) + } + } + + /// Returns the last element of the list, or `None` if it is empty. + #[must_use] + pub const fn last(&self) -> Option> { + if self.is_empty() { + None + } else { + Some(unsafe { self.at(self.len() - 1) }) + } + } } impl<'pr> IntoIterator for &ConstantList<'pr> { @@ -815,23 +865,20 @@ mod tests { #[test] fn constant_id_test() { - let source = "module Foo; x = 1; end"; + let source = "module Foo; x = 1; y = 2; end"; let result = parse(source.as_ref()); let node = result.node(); assert_eq!(node.as_program_node().unwrap().statements().body().len(), 1); assert!(!node.as_program_node().unwrap().statements().body().is_empty()); - let module = node.as_program_node().unwrap().statements().body().iter().next().unwrap(); + let module = node.as_program_node().and_then(|pn| pn.statements().body().first()).unwrap(); let module = module.as_module_node().unwrap(); - assert_eq!(module.locals().len(), 1); + assert_eq!(module.locals().len(), 2); assert!(!module.locals().is_empty()); - let locals = module.locals().iter().collect::>(); - - assert_eq!(locals.len(), 1); - - assert_eq!(locals[0].as_slice(), b"x"); + assert_eq!(module.locals().first().unwrap().as_slice(), b"x"); + assert_eq!(module.locals().last().unwrap().as_slice(), b"y"); let source = "module Foo; end"; let result = parse(source.as_ref()); @@ -839,7 +886,7 @@ mod tests { let node = result.node(); assert_eq!(node.as_program_node().unwrap().statements().body().len(), 1); assert!(!node.as_program_node().unwrap().statements().body().is_empty()); - let module = node.as_program_node().unwrap().statements().body().iter().next().unwrap(); + let module = node.as_program_node().and_then(|pn| pn.statements().body().first()).unwrap(); let module = module.as_module_node().unwrap(); assert_eq!(module.locals().len(), 0);