diff --git a/nex-packet/src/checksum.rs b/nex-packet/src/checksum.rs new file mode 100644 index 0000000..9ae3c5e --- /dev/null +++ b/nex-packet/src/checksum.rs @@ -0,0 +1,105 @@ +//! Utilities for tracking checksum recalculation state. + +use std::net::{Ipv4Addr, Ipv6Addr}; + +/// Controls how and when checksum recalculation happens for a packet. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum ChecksumMode { + /// Checksum updates are handled manually by the caller. + Manual, + /// Checksum updates happen automatically whenever a tracked field changes. + Automatic, +} + +impl Default for ChecksumMode { + fn default() -> Self { + ChecksumMode::Manual + } +} + +/// Tracks whether a packet's checksum needs to be recomputed. +#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)] +pub struct ChecksumState { + mode: ChecksumMode, + dirty: bool, +} + +impl ChecksumState { + /// Creates a new checksum state with manual recalculation enabled. + pub fn new() -> Self { + Self::default() + } + + /// Returns the current mode controlling checksum updates. + pub fn mode(&self) -> ChecksumMode { + self.mode + } + + /// Sets how checksum updates should be handled. + pub fn set_mode(&mut self, mode: ChecksumMode) { + self.mode = mode; + } + + /// Enables automatic checksum recomputation. + pub fn enable_automatic(&mut self) { + self.mode = ChecksumMode::Automatic; + } + + /// Disables automatic checksum recomputation. + pub fn disable_automatic(&mut self) { + self.mode = ChecksumMode::Manual; + } + + /// Returns true if checksum recomputation is automatic. + pub fn automatic(&self) -> bool { + matches!(self.mode, ChecksumMode::Automatic) + } + + /// Marks the checksum as stale due to a field mutation. + pub fn mark_dirty(&mut self) { + self.dirty = true; + } + + /// Clears the dirty flag after a successful recomputation. + pub fn clear_dirty(&mut self) { + self.dirty = false; + } + + /// Returns true if the checksum needs to be recomputed. + pub fn is_dirty(&self) -> bool { + self.dirty + } +} + +/// Captures the pseudo-header inputs required for transport checksum calculations. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum TransportChecksumContext { + /// Transport checksum associated with an IPv4 pseudo-header. + Ipv4 { + source: Ipv4Addr, + destination: Ipv4Addr, + }, + /// Transport checksum associated with an IPv6 pseudo-header. + Ipv6 { + source: Ipv6Addr, + destination: Ipv6Addr, + }, +} + +impl TransportChecksumContext { + /// Builds an IPv4 checksum context. + pub fn ipv4(source: Ipv4Addr, destination: Ipv4Addr) -> Self { + TransportChecksumContext::Ipv4 { + source, + destination, + } + } + + /// Builds an IPv6 checksum context. + pub fn ipv6(source: Ipv6Addr, destination: Ipv6Addr) -> Self { + TransportChecksumContext::Ipv6 { + source, + destination, + } + } +} diff --git a/nex-packet/src/icmp.rs b/nex-packet/src/icmp.rs index ae2154d..d95b6ae 100644 --- a/nex-packet/src/icmp.rs +++ b/nex-packet/src/icmp.rs @@ -1,8 +1,9 @@ //! An ICMP packet abstraction. +use crate::checksum::{ChecksumMode, ChecksumState}; use crate::ipv4::IPV4_HEADER_LEN; use crate::{ ethernet::ETHERNET_HEADER_LEN, - packet::{GenericMutablePacket, Packet}, + packet::{MutablePacket, Packet}, }; use bytes::{BufMut, Bytes, BytesMut}; use nex_core::bitfield::u16be; @@ -247,7 +248,154 @@ impl IcmpPacket { } /// Represents a mutable ICMP packet. -pub type MutableIcmpPacket<'a> = GenericMutablePacket<'a, IcmpPacket>; +pub struct MutableIcmpPacket<'a> { + buffer: &'a mut [u8], + checksum: ChecksumState, +} + +impl<'a> MutablePacket<'a> for MutableIcmpPacket<'a> { + type Packet = IcmpPacket; + + fn new(buffer: &'a mut [u8]) -> Option { + IcmpPacket::from_buf(buffer)?; + Some(Self { + buffer, + checksum: ChecksumState::new(), + }) + } + + fn packet(&self) -> &[u8] { + &*self.buffer + } + + fn packet_mut(&mut self) -> &mut [u8] { + &mut *self.buffer + } + + fn header(&self) -> &[u8] { + &self.packet()[..ICMP_COMMON_HEADER_LEN] + } + + fn header_mut(&mut self) -> &mut [u8] { + let (header, _) = (&mut *self.buffer).split_at_mut(ICMP_COMMON_HEADER_LEN); + header + } + + fn payload(&self) -> &[u8] { + &self.packet()[ICMP_COMMON_HEADER_LEN..] + } + + fn payload_mut(&mut self) -> &mut [u8] { + let (_, payload) = (&mut *self.buffer).split_at_mut(ICMP_COMMON_HEADER_LEN); + payload + } +} + +impl<'a> MutableIcmpPacket<'a> { + /// Create a mutable ICMP packet without performing validation. + pub fn new_unchecked(buffer: &'a mut [u8]) -> Self { + Self { + buffer, + checksum: ChecksumState::new(), + } + } + + fn raw(&self) -> &[u8] { + &*self.buffer + } + + fn raw_mut(&mut self) -> &mut [u8] { + &mut *self.buffer + } + + fn after_field_mutation(&mut self) { + self.checksum.mark_dirty(); + if self.checksum.automatic() { + let _ = self.recompute_checksum(); + } + } + + fn write_checksum(&mut self, value: u16) { + self.raw_mut()[2..4].copy_from_slice(&value.to_be_bytes()); + } + + /// Returns the checksum recalculation mode. + pub fn checksum_mode(&self) -> ChecksumMode { + self.checksum.mode() + } + + /// Sets how checksum updates should be handled. + pub fn set_checksum_mode(&mut self, mode: ChecksumMode) { + self.checksum.set_mode(mode); + if self.checksum.automatic() && self.checksum.is_dirty() { + let _ = self.recompute_checksum(); + } + } + + /// Enables automatic checksum recomputation. + pub fn enable_auto_checksum(&mut self) { + self.set_checksum_mode(ChecksumMode::Automatic); + } + + /// Disables automatic checksum recomputation. + pub fn disable_auto_checksum(&mut self) { + self.set_checksum_mode(ChecksumMode::Manual); + } + + /// Returns true if the checksum needs to be recomputed. + pub fn is_checksum_dirty(&self) -> bool { + self.checksum.is_dirty() + } + + /// Marks the checksum as dirty and recomputes it when automatic mode is enabled. + pub fn mark_checksum_dirty(&mut self) { + self.checksum.mark_dirty(); + if self.checksum.automatic() { + let _ = self.recompute_checksum(); + } + } + + /// Recomputes the checksum for the current packet contents. + pub fn recompute_checksum(&mut self) -> Option { + let checksum = crate::util::checksum(self.raw(), 1) as u16; + self.write_checksum(checksum); + self.checksum.clear_dirty(); + Some(checksum) + } + + /// Returns the current ICMP type field. + pub fn get_type(&self) -> IcmpType { + IcmpType::new(self.raw()[0]) + } + + /// Sets the ICMP type field and marks the checksum as dirty. + pub fn set_type(&mut self, icmp_type: IcmpType) { + self.raw_mut()[0] = icmp_type.value(); + self.after_field_mutation(); + } + + /// Returns the current ICMP code field. + pub fn get_code(&self) -> IcmpCode { + IcmpCode::new(self.raw()[1]) + } + + /// Sets the ICMP code field and marks the checksum as dirty. + pub fn set_code(&mut self, icmp_code: IcmpCode) { + self.raw_mut()[1] = icmp_code.value(); + self.after_field_mutation(); + } + + /// Returns the serialized checksum value. + pub fn get_checksum(&self) -> u16 { + u16::from_be_bytes([self.raw()[2], self.raw()[3]]) + } + + /// Sets the serialized checksum value and clears the dirty flag. + pub fn set_checksum(&mut self, checksum: u16) { + self.write_checksum(checksum); + self.checksum.clear_dirty(); + } +} /// Calculates a checksum of an ICMP packet. pub fn checksum(packet: &IcmpPacket) -> u16be { @@ -639,19 +787,43 @@ mod tests { } #[test] - fn test_mutable_icmp_packet_alias() { + fn test_mutable_icmp_packet_manual_checksum() { + let mut raw = [ + 8, 0, 0, 0, // type, code, checksum + 0, 1, 0, 1, // identifier, sequence + b'p', b'i', + ]; + + let mut packet = MutableIcmpPacket::new(&mut raw).expect("mutable icmp"); + packet.set_type(IcmpType::EchoReply); + assert!(packet.is_checksum_dirty()); + + let updated = packet.recompute_checksum().expect("checksum"); + assert_eq!(packet.get_checksum(), updated); + + let frozen = packet.freeze().expect("freeze"); + let expected: u16 = checksum(&frozen).into(); + assert_eq!(packet.get_checksum(), expected); + } + + #[test] + fn test_mutable_icmp_packet_auto_checksum() { let mut raw = [ 8, 0, 0, 0, // type, code, checksum 0, 1, 0, 1, // identifier, sequence b'p', b'i', ]; - let mut packet = ::new(&mut raw).expect("mutable icmp"); - packet.header_mut()[0] = IcmpType::EchoReply.value(); - packet.payload_mut()[0] = b'x'; + let mut packet = MutableIcmpPacket::new(&mut raw).expect("mutable icmp"); + let baseline = packet.recompute_checksum().expect("checksum"); + packet.enable_auto_checksum(); + packet.set_code(IcmpCode::new(1)); + + assert!(!packet.is_checksum_dirty()); let frozen = packet.freeze().expect("freeze"); - assert_eq!(frozen.header.icmp_type, IcmpType::EchoReply); - assert_eq!(frozen.payload[0], b'x'); + let expected: u16 = checksum(&frozen).into(); + assert_ne!(baseline, expected); + assert_eq!(packet.get_checksum(), expected); } } diff --git a/nex-packet/src/icmpv6.rs b/nex-packet/src/icmpv6.rs index 55142ea..da71766 100644 --- a/nex-packet/src/icmpv6.rs +++ b/nex-packet/src/icmpv6.rs @@ -1,9 +1,11 @@ //! An ICMPv6 packet abstraction. +use crate::checksum::{ChecksumMode, ChecksumState, TransportChecksumContext}; +use crate::ip::IpNextProtocol; use crate::ipv6::IPV6_HEADER_LEN; use crate::{ ethernet::ETHERNET_HEADER_LEN, - packet::{GenericMutablePacket, Packet}, + packet::{MutablePacket, Packet}, }; use std::net::Ipv6Addr; @@ -294,15 +296,207 @@ impl Packet for Icmpv6Packet { } /// Represents a mutable ICMPv6 packet. -pub type MutableIcmpv6Packet<'a> = GenericMutablePacket<'a, Icmpv6Packet>; +pub struct MutableIcmpv6Packet<'a> { + buffer: &'a mut [u8], + checksum: ChecksumState, + checksum_context: Option, +} + +impl<'a> MutablePacket<'a> for MutableIcmpv6Packet<'a> { + type Packet = Icmpv6Packet; + + fn new(buffer: &'a mut [u8]) -> Option { + Icmpv6Packet::from_buf(buffer)?; + Some(Self { + buffer, + checksum: ChecksumState::new(), + checksum_context: None, + }) + } + + fn packet(&self) -> &[u8] { + &*self.buffer + } + + fn packet_mut(&mut self) -> &mut [u8] { + &mut *self.buffer + } + + fn header(&self) -> &[u8] { + &self.packet()[..ICMPV6_COMMON_HEADER_LEN] + } + + fn header_mut(&mut self) -> &mut [u8] { + let (header, _) = (&mut *self.buffer).split_at_mut(ICMPV6_COMMON_HEADER_LEN); + header + } + + fn payload(&self) -> &[u8] { + &self.packet()[ICMPV6_COMMON_HEADER_LEN..] + } + + fn payload_mut(&mut self) -> &mut [u8] { + let (_, payload) = (&mut *self.buffer).split_at_mut(ICMPV6_COMMON_HEADER_LEN); + payload + } +} + +impl<'a> MutableIcmpv6Packet<'a> { + /// Create a mutable ICMPv6 packet without performing validation. + pub fn new_unchecked(buffer: &'a mut [u8]) -> Self { + Self { + buffer, + checksum: ChecksumState::new(), + checksum_context: None, + } + } + + fn raw(&self) -> &[u8] { + &*self.buffer + } + + fn raw_mut(&mut self) -> &mut [u8] { + &mut *self.buffer + } + + fn after_field_mutation(&mut self) { + self.checksum.mark_dirty(); + if self.checksum.automatic() { + let _ = self.recompute_checksum(); + } + } + + fn write_checksum(&mut self, value: u16) { + self.raw_mut()[2..4].copy_from_slice(&value.to_be_bytes()); + } + + /// Returns the checksum recalculation mode. + pub fn checksum_mode(&self) -> ChecksumMode { + self.checksum.mode() + } + + /// Sets how checksum updates should be handled. + pub fn set_checksum_mode(&mut self, mode: ChecksumMode) { + self.checksum.set_mode(mode); + if self.checksum.automatic() && self.checksum.is_dirty() { + let _ = self.recompute_checksum(); + } + } + + /// Enables automatic checksum recomputation. + pub fn enable_auto_checksum(&mut self) { + self.set_checksum_mode(ChecksumMode::Automatic); + } + + /// Disables automatic checksum recomputation. + pub fn disable_auto_checksum(&mut self) { + self.set_checksum_mode(ChecksumMode::Manual); + } + + /// Returns true if the checksum needs to be recomputed. + pub fn is_checksum_dirty(&self) -> bool { + self.checksum.is_dirty() + } + + /// Marks the checksum as dirty and recomputes it when automatic mode is enabled. + pub fn mark_checksum_dirty(&mut self) { + self.checksum.mark_dirty(); + if self.checksum.automatic() { + let _ = self.recompute_checksum(); + } + } + + /// Sets the pseudo-header context required for checksum calculation. + pub fn set_checksum_context(&mut self, context: TransportChecksumContext) { + self.checksum_context = match context { + TransportChecksumContext::Ipv6 { .. } => Some(context), + _ => None, + }; + + if self.checksum.automatic() && self.checksum.is_dirty() { + let _ = self.recompute_checksum(); + } + } + + /// Configures the pseudo-header context for IPv6 checksums. + pub fn set_ipv6_checksum_context(&mut self, source: Ipv6Addr, destination: Ipv6Addr) { + self.set_checksum_context(TransportChecksumContext::ipv6(source, destination)); + } + + /// Clears the configured pseudo-header context. + pub fn clear_checksum_context(&mut self) { + self.checksum_context = None; + } + + /// Returns the configured pseudo-header context. + pub fn checksum_context(&self) -> Option { + self.checksum_context + } + + /// Recomputes the checksum using the configured pseudo-header context. + pub fn recompute_checksum(&mut self) -> Option { + let context = match self.checksum_context? { + TransportChecksumContext::Ipv6 { + source, + destination, + } => (source, destination), + _ => return None, + }; + + let checksum = crate::util::ipv6_checksum( + self.raw(), + 1, + &[], + &context.0, + &context.1, + IpNextProtocol::Icmpv6, + ) as u16; + + self.write_checksum(checksum); + self.checksum.clear_dirty(); + Some(checksum) + } + + /// Returns the ICMPv6 type field. + pub fn get_type(&self) -> Icmpv6Type { + Icmpv6Type::new(self.raw()[0]) + } + + /// Sets the ICMPv6 type field and marks the checksum as dirty. + pub fn set_type(&mut self, icmpv6_type: Icmpv6Type) { + self.raw_mut()[0] = icmpv6_type.value(); + self.after_field_mutation(); + } + + /// Returns the ICMPv6 code field. + pub fn get_code(&self) -> Icmpv6Code { + Icmpv6Code::new(self.raw()[1]) + } + + /// Sets the ICMPv6 code field and marks the checksum as dirty. + pub fn set_code(&mut self, icmpv6_code: Icmpv6Code) { + self.raw_mut()[1] = icmpv6_code.value(); + self.after_field_mutation(); + } + + /// Returns the serialized checksum value. + pub fn get_checksum(&self) -> u16 { + u16::from_be_bytes([self.raw()[2], self.raw()[3]]) + } + + /// Sets the serialized checksum value and clears the dirty flag. + pub fn set_checksum(&mut self, checksum: u16) { + self.write_checksum(checksum); + self.checksum.clear_dirty(); + } +} #[cfg(test)] mod tests { use super::*; - use crate::packet::MutablePacket; #[test] - fn test_mutable_icmpv6_packet_alias() { + fn test_mutable_icmpv6_packet_manual_checksum() { let mut raw = [ Icmpv6Type::EchoRequest.value(), 0, @@ -316,14 +510,50 @@ mod tests { b'i', ]; - let mut packet = - ::new(&mut raw).expect("mutable icmpv6"); - packet.header_mut()[0] = Icmpv6Type::EchoReply.value(); - packet.payload_mut()[0] = b'x'; + let mut packet = MutableIcmpv6Packet::new(&mut raw).expect("mutable icmpv6"); + let addr = Ipv6Addr::LOCALHOST; + packet.set_ipv6_checksum_context(addr, addr); + packet.set_type(Icmpv6Type::EchoReply); + + assert!(packet.is_checksum_dirty()); + + let updated = packet.recompute_checksum().expect("checksum"); + assert_eq!(packet.get_checksum(), updated); + + let frozen = packet.freeze().expect("freeze"); + let expected = checksum(&frozen, &addr, &addr); + assert_eq!(packet.get_checksum(), expected); + } + + #[test] + fn test_mutable_icmpv6_packet_auto_checksum() { + let mut raw = [ + Icmpv6Type::EchoRequest.value(), + 0, + 0, + 0, + 0, + 1, + 0, + 1, + b'p', + b'i', + ]; + + let mut packet = MutableIcmpv6Packet::new(&mut raw).expect("mutable icmpv6"); + let addr = Ipv6Addr::LOCALHOST; + packet.set_ipv6_checksum_context(addr, addr); + let baseline = packet.recompute_checksum().expect("checksum"); + + packet.enable_auto_checksum(); + packet.set_code(Icmpv6Code::new(1)); + + assert!(!packet.is_checksum_dirty()); let frozen = packet.freeze().expect("freeze"); - assert_eq!(frozen.header.icmpv6_type, Icmpv6Type::EchoReply); - assert_eq!(frozen.payload[0], b'x'); + let expected = checksum(&frozen, &addr, &addr); + assert_ne!(baseline, expected); + assert_eq!(packet.get_checksum(), expected); } } diff --git a/nex-packet/src/ipv4.rs b/nex-packet/src/ipv4.rs index ec64165..81ce672 100644 --- a/nex-packet/src/ipv4.rs +++ b/nex-packet/src/ipv4.rs @@ -1,8 +1,10 @@ //! An IPv4 packet abstraction. use crate::{ + checksum::{ChecksumMode, ChecksumState}, ip::IpNextProtocol, packet::{MutablePacket, Packet}, + util, }; use bytes::{BufMut, Bytes, BytesMut}; use nex_core::bitfield::*; @@ -420,6 +422,7 @@ impl Ipv4Packet { /// Represents a mutable IPv4 packet. pub struct MutableIpv4Packet<'a> { buffer: &'a mut [u8], + checksum: ChecksumState, } impl<'a> MutablePacket<'a> for MutableIpv4Packet<'a> { @@ -445,7 +448,10 @@ impl<'a> MutablePacket<'a> for MutableIpv4Packet<'a> { return None; } - Some(Self { buffer }) + Some(Self { + buffer, + checksum: ChecksumState::new(), + }) } fn packet(&self) -> &[u8] { @@ -484,7 +490,10 @@ impl<'a> MutablePacket<'a> for MutableIpv4Packet<'a> { impl<'a> MutableIpv4Packet<'a> { /// Create a mutable packet without validating the header fields. pub fn new_unchecked(buffer: &'a mut [u8]) -> Self { - Self { buffer } + Self { + buffer, + checksum: ChecksumState::new(), + } } fn raw(&self) -> &[u8] { @@ -495,6 +504,66 @@ impl<'a> MutableIpv4Packet<'a> { &mut *self.buffer } + fn after_field_mutation(&mut self) { + self.checksum.mark_dirty(); + if self.checksum.automatic() { + let _ = self.recompute_checksum(); + } + } + + fn write_checksum(&mut self, checksum: u16) { + self.raw_mut()[10..12].copy_from_slice(&checksum.to_be_bytes()); + } + + /// Returns the current checksum recalculation mode. + pub fn checksum_mode(&self) -> ChecksumMode { + self.checksum.mode() + } + + /// Updates the checksum recalculation mode. + pub fn set_checksum_mode(&mut self, mode: ChecksumMode) { + self.checksum.set_mode(mode); + if self.checksum.automatic() && self.checksum.is_dirty() { + let _ = self.recompute_checksum(); + } + } + + /// Enables automatic checksum recalculation. + pub fn enable_auto_checksum(&mut self) { + self.set_checksum_mode(ChecksumMode::Automatic); + } + + /// Disables automatic checksum recalculation. + pub fn disable_auto_checksum(&mut self) { + self.set_checksum_mode(ChecksumMode::Manual); + } + + /// Returns true when the checksum must be recomputed before serialization. + pub fn is_checksum_dirty(&self) -> bool { + self.checksum.is_dirty() + } + + /// Marks the checksum as stale and triggers recomputation when automatic mode is enabled. + pub fn mark_checksum_dirty(&mut self) { + self.checksum.mark_dirty(); + if self.checksum.automatic() { + let _ = self.recompute_checksum(); + } + } + + /// Recomputes the IPv4 header checksum using the current buffer contents. + pub fn recompute_checksum(&mut self) -> Option { + let header_len = self.header_len(); + if header_len > self.raw().len() { + return None; + } + + let checksum = util::checksum(&self.raw()[..header_len], 5) as u16; + self.write_checksum(checksum); + self.checksum.clear_dirty(); + Some(checksum) + } + /// Returns the header length in bytes. pub fn header_len(&self) -> usize { let ihl = (self.raw()[0] & 0x0F) as usize; @@ -527,6 +596,7 @@ impl<'a> MutableIpv4Packet<'a> { pub fn set_version(&mut self, version: u8) { let buffer = self.raw_mut(); buffer[0] = (buffer[0] & 0x0F) | ((version & 0x0F) << 4); + self.after_field_mutation(); } /// Retrieve the header length in 32-bit words. @@ -538,6 +608,7 @@ impl<'a> MutableIpv4Packet<'a> { pub fn set_header_length(&mut self, ihl: u8) { let buffer = self.raw_mut(); buffer[0] = (buffer[0] & 0xF0) | (ihl & 0x0F); + self.after_field_mutation(); } /// Retrieve the DSCP field. @@ -549,6 +620,7 @@ impl<'a> MutableIpv4Packet<'a> { pub fn set_dscp(&mut self, dscp: u8) { let buffer = self.raw_mut(); buffer[1] = (buffer[1] & 0x03) | ((dscp & 0x3F) << 2); + self.after_field_mutation(); } /// Retrieve the ECN field. @@ -560,6 +632,7 @@ impl<'a> MutableIpv4Packet<'a> { pub fn set_ecn(&mut self, ecn: u8) { let buffer = self.raw_mut(); buffer[1] = (buffer[1] & 0xFC) | (ecn & 0x03); + self.after_field_mutation(); } /// Retrieve the total length field. @@ -570,6 +643,7 @@ impl<'a> MutableIpv4Packet<'a> { /// Update the total length field. pub fn set_total_length(&mut self, len: u16) { self.raw_mut()[2..4].copy_from_slice(&len.to_be_bytes()); + self.after_field_mutation(); } /// Retrieve the identification field. @@ -580,6 +654,7 @@ impl<'a> MutableIpv4Packet<'a> { /// Update the identification field. pub fn set_identification(&mut self, id: u16) { self.raw_mut()[4..6].copy_from_slice(&id.to_be_bytes()); + self.after_field_mutation(); } /// Retrieve the flags field. @@ -591,6 +666,7 @@ impl<'a> MutableIpv4Packet<'a> { pub fn set_flags(&mut self, flags: u8) { let buffer = self.raw_mut(); buffer[6] = (buffer[6] & 0x1F) | ((flags & 0x07) << 5); + self.after_field_mutation(); } /// Retrieve the fragment offset field. @@ -603,6 +679,7 @@ impl<'a> MutableIpv4Packet<'a> { let buffer = self.raw_mut(); let combined = (u16::from_be_bytes([buffer[6], buffer[7]]) & 0xE000) | (offset & 0x1FFF); buffer[6..8].copy_from_slice(&combined.to_be_bytes()); + self.after_field_mutation(); } /// Retrieve the TTL field. @@ -613,6 +690,7 @@ impl<'a> MutableIpv4Packet<'a> { /// Update the TTL field. pub fn set_ttl(&mut self, ttl: u8) { self.raw_mut()[8] = ttl; + self.after_field_mutation(); } /// Retrieve the next-level protocol field. @@ -623,6 +701,7 @@ impl<'a> MutableIpv4Packet<'a> { /// Update the next-level protocol field. pub fn set_next_level_protocol(&mut self, proto: IpNextProtocol) { self.raw_mut()[9] = proto.value(); + self.after_field_mutation(); } /// Retrieve the checksum field. @@ -632,7 +711,8 @@ impl<'a> MutableIpv4Packet<'a> { /// Update the checksum field. pub fn set_checksum(&mut self, checksum: u16) { - self.raw_mut()[10..12].copy_from_slice(&checksum.to_be_bytes()); + self.write_checksum(checksum); + self.checksum.clear_dirty(); } /// Retrieve the source address. @@ -648,6 +728,7 @@ impl<'a> MutableIpv4Packet<'a> { /// Update the source address. pub fn set_source(&mut self, addr: Ipv4Addr) { self.raw_mut()[12..16].copy_from_slice(&addr.octets()); + self.after_field_mutation(); } /// Retrieve the destination address. @@ -663,6 +744,7 @@ impl<'a> MutableIpv4Packet<'a> { /// Update the destination address. pub fn set_destination(&mut self, addr: Ipv4Addr) { self.raw_mut()[16..20].copy_from_slice(&addr.octets()); + self.after_field_mutation(); } } @@ -870,4 +952,45 @@ mod tests { assert_eq!(frozen.header.destination, Ipv4Addr::new(192, 0, 2, 1)); assert_eq!(frozen.payload[0], 0x11); } + + #[test] + fn test_ipv4_auto_checksum_updates() { + let mut raw = [ + 0x45, 0x00, 0x00, 0x1c, 0x1c, 0x46, 0x40, 0x00, 0x40, 0x06, 0x00, 0x00, 0xc0, 0xa8, + 0x00, 0x01, 0xc0, 0xa8, 0x00, 0xc7, 0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0xba, 0xbe, + ]; + + let mut packet = MutableIpv4Packet::new(&mut raw).expect("mutable ipv4"); + packet.enable_auto_checksum(); + let baseline = packet.recompute_checksum().expect("checksum"); + let before = packet.get_checksum(); + assert_eq!(baseline, before); + + packet.set_ttl(0x41); + let after = packet.get_checksum(); + assert_ne!(before, after); + assert!(!packet.is_checksum_dirty()); + + let frozen = packet.freeze().expect("freeze"); + let expected = checksum(&frozen); + assert_eq!(after, expected); + } + + #[test] + fn test_ipv4_manual_checksum_tracking() { + let mut raw = [ + 0x45, 0x00, 0x00, 0x1c, 0x1c, 0x46, 0x40, 0x00, 0x40, 0x06, 0xb1, 0xe6, 0xc0, 0xa8, + 0x00, 0x01, 0xc0, 0xa8, 0x00, 0xc7, 0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0xba, 0xbe, + ]; + + let mut packet = MutableIpv4Packet::new(&mut raw).expect("mutable ipv4"); + assert!(!packet.is_checksum_dirty()); + + packet.set_identification(0x1c47); + assert!(packet.is_checksum_dirty()); + + let recomputed = packet.recompute_checksum().expect("checksum"); + assert_eq!(recomputed, packet.get_checksum()); + assert!(!packet.is_checksum_dirty()); + } } diff --git a/nex-packet/src/lib.rs b/nex-packet/src/lib.rs index 85c087c..7903d1c 100644 --- a/nex-packet/src/lib.rs +++ b/nex-packet/src/lib.rs @@ -2,6 +2,7 @@ pub mod arp; pub mod builder; +pub mod checksum; pub mod dhcp; pub mod dns; pub mod ethernet; diff --git a/nex-packet/src/tcp.rs b/nex-packet/src/tcp.rs index 0550c4d..d0e62bc 100644 --- a/nex-packet/src/tcp.rs +++ b/nex-packet/src/tcp.rs @@ -1,5 +1,6 @@ //! A TCP packet abstraction. +use crate::checksum::{ChecksumMode, ChecksumState, TransportChecksumContext}; use crate::ip::IpNextProtocol; use crate::packet::{MutablePacket, Packet}; @@ -668,6 +669,8 @@ impl TcpPacket { /// Represents a mutable TCP packet. pub struct MutableTcpPacket<'a> { buffer: &'a mut [u8], + checksum: ChecksumState, + checksum_context: Option, } impl<'a> MutablePacket<'a> for MutableTcpPacket<'a> { @@ -688,7 +691,11 @@ impl<'a> MutablePacket<'a> for MutableTcpPacket<'a> { return None; } - Some(Self { buffer }) + Some(Self { + buffer, + checksum: ChecksumState::new(), + checksum_context: None, + }) } fn packet(&self) -> &[u8] { @@ -725,7 +732,11 @@ impl<'a> MutablePacket<'a> for MutableTcpPacket<'a> { impl<'a> MutableTcpPacket<'a> { /// Create a packet without validating the header fields. pub fn new_unchecked(buffer: &'a mut [u8]) -> Self { - Self { buffer } + Self { + buffer, + checksum: ChecksumState::new(), + checksum_context: None, + } } fn raw(&self) -> &[u8] { @@ -736,6 +747,115 @@ impl<'a> MutableTcpPacket<'a> { &mut *self.buffer } + fn after_field_mutation(&mut self) { + self.checksum.mark_dirty(); + if self.checksum.automatic() { + let _ = self.recompute_checksum(); + } + } + + fn write_checksum(&mut self, value: u16) { + self.raw_mut()[16..18].copy_from_slice(&value.to_be_bytes()); + } + + /// Returns the checksum recalculation mode for the packet. + pub fn checksum_mode(&self) -> ChecksumMode { + self.checksum.mode() + } + + /// Updates how checksum recalculation should be handled. + pub fn set_checksum_mode(&mut self, mode: ChecksumMode) { + self.checksum.set_mode(mode); + if self.checksum.automatic() && self.checksum.is_dirty() { + let _ = self.recompute_checksum(); + } + } + + /// Enables automatic checksum recomputation after field mutations. + pub fn enable_auto_checksum(&mut self) { + self.set_checksum_mode(ChecksumMode::Automatic); + } + + /// Disables automatic checksum recomputation. + pub fn disable_auto_checksum(&mut self) { + self.set_checksum_mode(ChecksumMode::Manual); + } + + /// Returns true if the checksum needs to be updated before serialization. + pub fn is_checksum_dirty(&self) -> bool { + self.checksum.is_dirty() + } + + /// Marks the checksum as dirty and recomputes it when automatic mode is enabled. + pub fn mark_checksum_dirty(&mut self) { + self.checksum.mark_dirty(); + if self.checksum.automatic() { + let _ = self.recompute_checksum(); + } + } + + /// Configures the pseudo-header context required for checksum calculation. + pub fn set_checksum_context(&mut self, context: TransportChecksumContext) { + self.checksum_context = Some(context); + if self.checksum.automatic() && self.checksum.is_dirty() { + let _ = self.recompute_checksum(); + } + } + + /// Sets an IPv4 pseudo-header context for checksum calculation. + pub fn set_ipv4_checksum_context(&mut self, source: Ipv4Addr, destination: Ipv4Addr) { + self.set_checksum_context(TransportChecksumContext::ipv4(source, destination)); + } + + /// Sets an IPv6 pseudo-header context for checksum calculation. + pub fn set_ipv6_checksum_context(&mut self, source: Ipv6Addr, destination: Ipv6Addr) { + self.set_checksum_context(TransportChecksumContext::ipv6(source, destination)); + } + + /// Clears the configured pseudo-header context. + pub fn clear_checksum_context(&mut self) { + self.checksum_context = None; + } + + /// Returns the currently configured pseudo-header context. + pub fn checksum_context(&self) -> Option { + self.checksum_context + } + + /// Recomputes the checksum using the configured pseudo-header context. + pub fn recompute_checksum(&mut self) -> Option { + let context = self.checksum_context?; + + let checksum = match context { + TransportChecksumContext::Ipv4 { + source, + destination, + } => util::ipv4_checksum( + self.raw(), + 8, + &[], + &source, + &destination, + IpNextProtocol::Tcp, + ) as u16, + TransportChecksumContext::Ipv6 { + source, + destination, + } => util::ipv6_checksum( + self.raw(), + 8, + &[], + &source, + &destination, + IpNextProtocol::Tcp, + ) as u16, + }; + + self.write_checksum(checksum); + self.checksum.clear_dirty(); + Some(checksum) + } + /// Returns the header length in bytes. pub fn header_len(&self) -> usize { let offset = (self.raw()[12] >> 4).max(TCP_MIN_DATA_OFFSET); @@ -754,6 +874,7 @@ impl<'a> MutableTcpPacket<'a> { pub fn set_source(&mut self, value: u16) { self.raw_mut()[0..2].copy_from_slice(&value.to_be_bytes()); + self.after_field_mutation(); } pub fn get_destination(&self) -> u16 { @@ -762,6 +883,7 @@ impl<'a> MutableTcpPacket<'a> { pub fn set_destination(&mut self, value: u16) { self.raw_mut()[2..4].copy_from_slice(&value.to_be_bytes()); + self.after_field_mutation(); } pub fn get_sequence(&self) -> u32 { @@ -770,6 +892,7 @@ impl<'a> MutableTcpPacket<'a> { pub fn set_sequence(&mut self, value: u32) { self.raw_mut()[4..8].copy_from_slice(&value.to_be_bytes()); + self.after_field_mutation(); } pub fn get_acknowledgement(&self) -> u32 { @@ -778,6 +901,7 @@ impl<'a> MutableTcpPacket<'a> { pub fn set_acknowledgement(&mut self, value: u32) { self.raw_mut()[8..12].copy_from_slice(&value.to_be_bytes()); + self.after_field_mutation(); } pub fn get_data_offset(&self) -> u8 { @@ -787,6 +911,7 @@ impl<'a> MutableTcpPacket<'a> { pub fn set_data_offset(&mut self, offset: u8) { let buf = self.raw_mut(); buf[12] = (buf[12] & 0x0F) | ((offset & 0x0F) << 4); + self.after_field_mutation(); } pub fn get_reserved(&self) -> u8 { @@ -796,6 +921,7 @@ impl<'a> MutableTcpPacket<'a> { pub fn set_reserved(&mut self, value: u8) { let buf = self.raw_mut(); buf[12] = (buf[12] & 0xF0) | (value & 0x0F); + self.after_field_mutation(); } pub fn get_flags(&self) -> u8 { @@ -804,6 +930,7 @@ impl<'a> MutableTcpPacket<'a> { pub fn set_flags(&mut self, flags: u8) { self.raw_mut()[13] = flags; + self.after_field_mutation(); } pub fn get_window(&self) -> u16 { @@ -812,6 +939,7 @@ impl<'a> MutableTcpPacket<'a> { pub fn set_window(&mut self, value: u16) { self.raw_mut()[14..16].copy_from_slice(&value.to_be_bytes()); + self.after_field_mutation(); } pub fn get_checksum(&self) -> u16 { @@ -819,7 +947,8 @@ impl<'a> MutableTcpPacket<'a> { } pub fn set_checksum(&mut self, value: u16) { - self.raw_mut()[16..18].copy_from_slice(&value.to_be_bytes()); + self.write_checksum(value); + self.checksum.clear_dirty(); } pub fn get_urgent_ptr(&self) -> u16 { @@ -828,6 +957,7 @@ impl<'a> MutableTcpPacket<'a> { pub fn set_urgent_ptr(&mut self, value: u16) { self.raw_mut()[18..20].copy_from_slice(&value.to_be_bytes()); + self.after_field_mutation(); } pub fn options(&self) -> &[u8] { @@ -1017,4 +1147,50 @@ mod tests { assert_eq!(frozen.header.flags, 0x11); assert_eq!(frozen.payload[0], b'H'); } + + #[test] + fn test_tcp_auto_checksum_with_context() { + let mut raw = [ + 0x00, 0x50, 0x01, 0xbb, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x50, 0x18, + 0x40, 0x00, 0x00, 0x00, 0x00, 0x00, b'h', b'e', b'l', b'l', b'o', + ]; + + let mut packet = MutableTcpPacket::new(&mut raw).expect("mutable tcp"); + let src = Ipv4Addr::new(192, 0, 2, 1); + let dst = Ipv4Addr::new(198, 51, 100, 2); + packet.set_ipv4_checksum_context(src, dst); + packet.enable_auto_checksum(); + + let baseline = packet.recompute_checksum().expect("checksum"); + assert_eq!(baseline, packet.get_checksum()); + + packet.set_window(0x2000); + let updated = packet.get_checksum(); + assert_ne!(baseline, updated); + assert!(!packet.is_checksum_dirty()); + + let frozen = packet.freeze().expect("freeze"); + let expected = ipv4_checksum(&frozen, &src, &dst); + assert_eq!(updated, expected as u16); + } + + #[test] + fn test_tcp_manual_checksum_tracking() { + let mut raw = [ + 0x12, 0x34, 0xab, 0xcd, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x50, 0x02, + 0x10, 0x00, 0x00, 0x00, 0x00, 0x00, + ]; + + let mut packet = MutableTcpPacket::new(&mut raw).expect("mutable tcp"); + let src = Ipv6Addr::LOCALHOST; + let dst = Ipv6Addr::LOCALHOST; + packet.set_ipv6_checksum_context(src, dst); + + packet.set_flags(0x12); + assert!(packet.is_checksum_dirty()); + + let recomputed = packet.recompute_checksum().expect("checksum"); + assert_eq!(recomputed, packet.get_checksum()); + assert!(!packet.is_checksum_dirty()); + } } diff --git a/nex-packet/src/udp.rs b/nex-packet/src/udp.rs index da59aff..7e72003 100644 --- a/nex-packet/src/udp.rs +++ b/nex-packet/src/udp.rs @@ -1,5 +1,6 @@ //! A UDP packet abstraction. +use crate::checksum::{ChecksumMode, ChecksumState, TransportChecksumContext}; use crate::ip::IpNextProtocol; use crate::packet::{MutablePacket, Packet}; @@ -109,6 +110,8 @@ impl Packet for UdpPacket { /// Represents a mutable UDP packet. pub struct MutableUdpPacket<'a> { buffer: &'a mut [u8], + checksum: ChecksumState, + checksum_context: Option, } impl<'a> MutablePacket<'a> for MutableUdpPacket<'a> { @@ -130,7 +133,11 @@ impl<'a> MutablePacket<'a> for MutableUdpPacket<'a> { } } - Some(Self { buffer }) + Some(Self { + buffer, + checksum: ChecksumState::new(), + checksum_context: None, + }) } fn packet(&self) -> &[u8] { @@ -165,7 +172,11 @@ impl<'a> MutablePacket<'a> for MutableUdpPacket<'a> { impl<'a> MutableUdpPacket<'a> { /// Create a new packet without validating length fields. pub fn new_unchecked(buffer: &'a mut [u8]) -> Self { - Self { buffer } + Self { + buffer, + checksum: ChecksumState::new(), + checksum_context: None, + } } fn raw(&self) -> &[u8] { @@ -176,6 +187,115 @@ impl<'a> MutableUdpPacket<'a> { &mut *self.buffer } + fn after_field_mutation(&mut self) { + self.checksum.mark_dirty(); + if self.checksum.automatic() { + let _ = self.recompute_checksum(); + } + } + + fn write_checksum(&mut self, checksum: u16) { + self.raw_mut()[6..8].copy_from_slice(&checksum.to_be_bytes()); + } + + /// Returns the checksum recalculation mode. + pub fn checksum_mode(&self) -> ChecksumMode { + self.checksum.mode() + } + + /// Sets the checksum recalculation mode. + pub fn set_checksum_mode(&mut self, mode: ChecksumMode) { + self.checksum.set_mode(mode); + if self.checksum.automatic() && self.checksum.is_dirty() { + let _ = self.recompute_checksum(); + } + } + + /// Enables automatic checksum recalculation when tracked fields change. + pub fn enable_auto_checksum(&mut self) { + self.set_checksum_mode(ChecksumMode::Automatic); + } + + /// Disables automatic checksum recalculation. + pub fn disable_auto_checksum(&mut self) { + self.set_checksum_mode(ChecksumMode::Manual); + } + + /// Returns true if the checksum needs to be recomputed. + pub fn is_checksum_dirty(&self) -> bool { + self.checksum.is_dirty() + } + + /// Marks the checksum as stale and recomputes it when automatic mode is enabled. + pub fn mark_checksum_dirty(&mut self) { + self.checksum.mark_dirty(); + if self.checksum.automatic() { + let _ = self.recompute_checksum(); + } + } + + /// Defines the pseudo-header context used when recomputing the checksum. + pub fn set_checksum_context(&mut self, context: TransportChecksumContext) { + self.checksum_context = Some(context); + if self.checksum.automatic() && self.checksum.is_dirty() { + let _ = self.recompute_checksum(); + } + } + + /// Sets an IPv4 pseudo-header context used for checksum recomputation. + pub fn set_ipv4_checksum_context(&mut self, source: Ipv4Addr, destination: Ipv4Addr) { + self.set_checksum_context(TransportChecksumContext::ipv4(source, destination)); + } + + /// Sets an IPv6 pseudo-header context used for checksum recomputation. + pub fn set_ipv6_checksum_context(&mut self, source: Ipv6Addr, destination: Ipv6Addr) { + self.set_checksum_context(TransportChecksumContext::ipv6(source, destination)); + } + + /// Clears the configured checksum pseudo-header context. + pub fn clear_checksum_context(&mut self) { + self.checksum_context = None; + } + + /// Provides access to the configured checksum pseudo-header context. + pub fn checksum_context(&self) -> Option { + self.checksum_context + } + + /// Recomputes the UDP checksum if a pseudo-header context is available. + pub fn recompute_checksum(&mut self) -> Option { + let context = self.checksum_context?; + + let checksum = match context { + TransportChecksumContext::Ipv4 { + source, + destination, + } => util::ipv4_checksum( + self.raw(), + 3, + &[], + &source, + &destination, + IpNextProtocol::Udp, + ) as u16, + TransportChecksumContext::Ipv6 { + source, + destination, + } => util::ipv6_checksum( + self.raw(), + 3, + &[], + &source, + &destination, + IpNextProtocol::Udp, + ) as u16, + }; + + self.write_checksum(checksum); + self.checksum.clear_dirty(); + Some(checksum) + } + /// Returns the total length derived from the UDP length field. pub fn total_len(&self) -> usize { let field = u16::from_be_bytes([self.raw()[4], self.raw()[5]]); @@ -197,6 +317,7 @@ impl<'a> MutableUdpPacket<'a> { pub fn set_source(&mut self, port: u16) { self.raw_mut()[0..2].copy_from_slice(&port.to_be_bytes()); + self.after_field_mutation(); } pub fn get_destination(&self) -> u16 { @@ -205,6 +326,7 @@ impl<'a> MutableUdpPacket<'a> { pub fn set_destination(&mut self, port: u16) { self.raw_mut()[2..4].copy_from_slice(&port.to_be_bytes()); + self.after_field_mutation(); } pub fn get_length(&self) -> u16 { @@ -213,6 +335,7 @@ impl<'a> MutableUdpPacket<'a> { pub fn set_length(&mut self, length: u16) { self.raw_mut()[4..6].copy_from_slice(&length.to_be_bytes()); + self.after_field_mutation(); } pub fn get_checksum(&self) -> u16 { @@ -220,7 +343,8 @@ impl<'a> MutableUdpPacket<'a> { } pub fn set_checksum(&mut self, checksum: u16) { - self.raw_mut()[6..8].copy_from_slice(&checksum.to_be_bytes()); + self.write_checksum(checksum); + self.checksum.clear_dirty(); } } @@ -358,4 +482,56 @@ mod tests { assert_eq!(frozen.header.checksum, 0xffff); assert_eq!(&raw[UDP_HEADER_LEN], &b'x'); } + + #[test] + fn test_udp_auto_checksum_with_context() { + let mut raw = [ + 0x12, 0x34, // source + 0xab, 0xcd, // destination + 0x00, 0x0c, // length + 0x00, 0x00, // checksum placeholder + b'd', b'a', b't', b'a', // payload + ]; + + let mut packet = MutableUdpPacket::new(&mut raw).expect("mutable udp"); + let src = Ipv4Addr::new(192, 168, 0, 1); + let dst = Ipv4Addr::new(192, 168, 0, 2); + packet.set_ipv4_checksum_context(src, dst); + packet.enable_auto_checksum(); + + let baseline = packet.recompute_checksum().expect("checksum"); + assert_eq!(baseline, packet.get_checksum()); + + packet.set_destination(0xabce); + let updated = packet.get_checksum(); + assert_ne!(baseline, updated); + assert!(!packet.is_checksum_dirty()); + + let frozen = packet.freeze().expect("freeze"); + let expected = ipv4_checksum(&frozen, &src, &dst); + assert_eq!(updated, expected as u16); + } + + #[test] + fn test_udp_manual_checksum_tracking() { + let mut raw = [ + 0x12, 0x34, // source + 0xab, 0xcd, // destination + 0x00, 0x0c, // length + 0x00, 0x00, // checksum placeholder + b'd', b'a', b't', b'a', // payload + ]; + + let mut packet = MutableUdpPacket::new(&mut raw).expect("mutable udp"); + let src = Ipv4Addr::new(10, 0, 0, 1); + let dst = Ipv4Addr::new(10, 0, 0, 2); + packet.set_ipv4_checksum_context(src, dst); + + packet.set_source(0x2222); + assert!(packet.is_checksum_dirty()); + + let recomputed = packet.recompute_checksum().expect("checksum"); + assert_eq!(recomputed, packet.get_checksum()); + assert!(!packet.is_checksum_dirty()); + } }