Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 105 additions & 0 deletions nex-packet/src/checksum.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
//! Utilities for tracking checksum recalculation state.

use std::net::{Ipv4Addr, Ipv6Addr};

/// Controls how and when checksum recalculation happens for a packet.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ChecksumMode {
/// Checksum updates are handled manually by the caller.
Manual,
/// Checksum updates happen automatically whenever a tracked field changes.
Automatic,
}

impl Default for ChecksumMode {
fn default() -> Self {
ChecksumMode::Manual
}
}

/// Tracks whether a packet's checksum needs to be recomputed.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct ChecksumState {
mode: ChecksumMode,
dirty: bool,
}

impl ChecksumState {
/// Creates a new checksum state with manual recalculation enabled.
pub fn new() -> Self {
Self::default()
}

/// Returns the current mode controlling checksum updates.
pub fn mode(&self) -> ChecksumMode {
self.mode
}

/// Sets how checksum updates should be handled.
pub fn set_mode(&mut self, mode: ChecksumMode) {
self.mode = mode;
}

/// Enables automatic checksum recomputation.
pub fn enable_automatic(&mut self) {
self.mode = ChecksumMode::Automatic;
}

/// Disables automatic checksum recomputation.
pub fn disable_automatic(&mut self) {
self.mode = ChecksumMode::Manual;
}

/// Returns true if checksum recomputation is automatic.
pub fn automatic(&self) -> bool {
matches!(self.mode, ChecksumMode::Automatic)
}

/// Marks the checksum as stale due to a field mutation.
pub fn mark_dirty(&mut self) {
self.dirty = true;
}

/// Clears the dirty flag after a successful recomputation.
pub fn clear_dirty(&mut self) {
self.dirty = false;
}

/// Returns true if the checksum needs to be recomputed.
pub fn is_dirty(&self) -> bool {
self.dirty
}
}

/// Captures the pseudo-header inputs required for transport checksum calculations.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum TransportChecksumContext {
/// Transport checksum associated with an IPv4 pseudo-header.
Ipv4 {
source: Ipv4Addr,
destination: Ipv4Addr,
},
/// Transport checksum associated with an IPv6 pseudo-header.
Ipv6 {
source: Ipv6Addr,
destination: Ipv6Addr,
},
}

impl TransportChecksumContext {
/// Builds an IPv4 checksum context.
pub fn ipv4(source: Ipv4Addr, destination: Ipv4Addr) -> Self {
TransportChecksumContext::Ipv4 {
source,
destination,
}
}

/// Builds an IPv6 checksum context.
pub fn ipv6(source: Ipv6Addr, destination: Ipv6Addr) -> Self {
TransportChecksumContext::Ipv6 {
source,
destination,
}
}
}
188 changes: 180 additions & 8 deletions nex-packet/src/icmp.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
//! An ICMP packet abstraction.
use crate::checksum::{ChecksumMode, ChecksumState};
use crate::ipv4::IPV4_HEADER_LEN;
use crate::{
ethernet::ETHERNET_HEADER_LEN,
packet::{GenericMutablePacket, Packet},
packet::{MutablePacket, Packet},
};
use bytes::{BufMut, Bytes, BytesMut};
use nex_core::bitfield::u16be;
Expand Down Expand Up @@ -247,7 +248,154 @@ impl IcmpPacket {
}

/// Represents a mutable ICMP packet.
pub type MutableIcmpPacket<'a> = GenericMutablePacket<'a, IcmpPacket>;
pub struct MutableIcmpPacket<'a> {
buffer: &'a mut [u8],
checksum: ChecksumState,
}

impl<'a> MutablePacket<'a> for MutableIcmpPacket<'a> {
type Packet = IcmpPacket;

fn new(buffer: &'a mut [u8]) -> Option<Self> {
IcmpPacket::from_buf(buffer)?;
Some(Self {
buffer,
checksum: ChecksumState::new(),
})
}

fn packet(&self) -> &[u8] {
&*self.buffer
}

fn packet_mut(&mut self) -> &mut [u8] {
&mut *self.buffer
}

fn header(&self) -> &[u8] {
&self.packet()[..ICMP_COMMON_HEADER_LEN]
}

fn header_mut(&mut self) -> &mut [u8] {
let (header, _) = (&mut *self.buffer).split_at_mut(ICMP_COMMON_HEADER_LEN);
header
}

fn payload(&self) -> &[u8] {
&self.packet()[ICMP_COMMON_HEADER_LEN..]
}

fn payload_mut(&mut self) -> &mut [u8] {
let (_, payload) = (&mut *self.buffer).split_at_mut(ICMP_COMMON_HEADER_LEN);
payload
}
}

impl<'a> MutableIcmpPacket<'a> {
/// Create a mutable ICMP packet without performing validation.
pub fn new_unchecked(buffer: &'a mut [u8]) -> Self {
Self {
buffer,
checksum: ChecksumState::new(),
}
}

fn raw(&self) -> &[u8] {
&*self.buffer
}

fn raw_mut(&mut self) -> &mut [u8] {
&mut *self.buffer
}

fn after_field_mutation(&mut self) {
self.checksum.mark_dirty();
if self.checksum.automatic() {
let _ = self.recompute_checksum();
}
}

fn write_checksum(&mut self, value: u16) {
self.raw_mut()[2..4].copy_from_slice(&value.to_be_bytes());
}

/// Returns the checksum recalculation mode.
pub fn checksum_mode(&self) -> ChecksumMode {
self.checksum.mode()
}

/// Sets how checksum updates should be handled.
pub fn set_checksum_mode(&mut self, mode: ChecksumMode) {
self.checksum.set_mode(mode);
if self.checksum.automatic() && self.checksum.is_dirty() {
let _ = self.recompute_checksum();
}
}

/// Enables automatic checksum recomputation.
pub fn enable_auto_checksum(&mut self) {
self.set_checksum_mode(ChecksumMode::Automatic);
}

/// Disables automatic checksum recomputation.
pub fn disable_auto_checksum(&mut self) {
self.set_checksum_mode(ChecksumMode::Manual);
}

/// Returns true if the checksum needs to be recomputed.
pub fn is_checksum_dirty(&self) -> bool {
self.checksum.is_dirty()
}

/// Marks the checksum as dirty and recomputes it when automatic mode is enabled.
pub fn mark_checksum_dirty(&mut self) {
self.checksum.mark_dirty();
if self.checksum.automatic() {
let _ = self.recompute_checksum();
}
}

/// Recomputes the checksum for the current packet contents.
pub fn recompute_checksum(&mut self) -> Option<u16> {
let checksum = crate::util::checksum(self.raw(), 1) as u16;
self.write_checksum(checksum);
self.checksum.clear_dirty();
Some(checksum)
}

/// Returns the current ICMP type field.
pub fn get_type(&self) -> IcmpType {
IcmpType::new(self.raw()[0])
}

/// Sets the ICMP type field and marks the checksum as dirty.
pub fn set_type(&mut self, icmp_type: IcmpType) {
self.raw_mut()[0] = icmp_type.value();
self.after_field_mutation();
}

/// Returns the current ICMP code field.
pub fn get_code(&self) -> IcmpCode {
IcmpCode::new(self.raw()[1])
}

/// Sets the ICMP code field and marks the checksum as dirty.
pub fn set_code(&mut self, icmp_code: IcmpCode) {
self.raw_mut()[1] = icmp_code.value();
self.after_field_mutation();
}

/// Returns the serialized checksum value.
pub fn get_checksum(&self) -> u16 {
u16::from_be_bytes([self.raw()[2], self.raw()[3]])
}

/// Sets the serialized checksum value and clears the dirty flag.
pub fn set_checksum(&mut self, checksum: u16) {
self.write_checksum(checksum);
self.checksum.clear_dirty();
}
}

/// Calculates a checksum of an ICMP packet.
pub fn checksum(packet: &IcmpPacket) -> u16be {
Expand Down Expand Up @@ -639,19 +787,43 @@ mod tests {
}

#[test]
fn test_mutable_icmp_packet_alias() {
fn test_mutable_icmp_packet_manual_checksum() {
let mut raw = [
8, 0, 0, 0, // type, code, checksum
0, 1, 0, 1, // identifier, sequence
b'p', b'i',
];

let mut packet = MutableIcmpPacket::new(&mut raw).expect("mutable icmp");
packet.set_type(IcmpType::EchoReply);
assert!(packet.is_checksum_dirty());

let updated = packet.recompute_checksum().expect("checksum");
assert_eq!(packet.get_checksum(), updated);

let frozen = packet.freeze().expect("freeze");
let expected: u16 = checksum(&frozen).into();
assert_eq!(packet.get_checksum(), expected);
}

#[test]
fn test_mutable_icmp_packet_auto_checksum() {
let mut raw = [
8, 0, 0, 0, // type, code, checksum
0, 1, 0, 1, // identifier, sequence
b'p', b'i',
];

let mut packet = <MutableIcmpPacket as MutablePacket>::new(&mut raw).expect("mutable icmp");
packet.header_mut()[0] = IcmpType::EchoReply.value();
packet.payload_mut()[0] = b'x';
let mut packet = MutableIcmpPacket::new(&mut raw).expect("mutable icmp");
let baseline = packet.recompute_checksum().expect("checksum");
packet.enable_auto_checksum();
packet.set_code(IcmpCode::new(1));

assert!(!packet.is_checksum_dirty());

let frozen = packet.freeze().expect("freeze");
assert_eq!(frozen.header.icmp_type, IcmpType::EchoReply);
assert_eq!(frozen.payload[0], b'x');
let expected: u16 = checksum(&frozen).into();
assert_ne!(baseline, expected);
assert_eq!(packet.get_checksum(), expected);
}
}
Loading