diff --git a/examples/async_icmp_socket.rs b/examples/async_icmp_socket.rs index d948357..93fa3cc 100644 --- a/examples/async_icmp_socket.rs +++ b/examples/async_icmp_socket.rs @@ -98,7 +98,6 @@ async fn main() -> std::io::Result<()> { .icmp_code(icmp::echo_request::IcmpCodes::NoCode) .echo_fields(id, seq) .payload(Bytes::from_static(b"ping")) - .calculate_checksum() .to_bytes(); let target = SocketAddr::new(IpAddr::V4(addr), 0); let _ = socket.send_to(&pkt, target).await; diff --git a/examples/icmp_ping.rs b/examples/icmp_ping.rs index 076fe5e..431b280 100644 --- a/examples/icmp_ping.rs +++ b/examples/icmp_ping.rs @@ -70,7 +70,6 @@ fn main() { .icmp_code(icmp::echo_request::IcmpCodes::NoCode) .echo_fields(0x1234, 0x1) .payload(Bytes::from_static(b"hello")) - .calculate_checksum() .build() .to_bytes(), (IpAddr::V6(src), IpAddr::V6(dst)) => Icmpv6PacketBuilder::new(src, dst) @@ -78,7 +77,6 @@ fn main() { .icmpv6_code(icmpv6::echo_request::Icmpv6Codes::NoCode) .echo_fields(0x1234, 0x1) .payload(Bytes::from_static(b"hello")) - .calculate_checksum() .build() .to_bytes(), _ => panic!("Source and destination IP version mismatch"), diff --git a/examples/icmp_socket.rs b/examples/icmp_socket.rs index aa33741..a666057 100644 --- a/examples/icmp_socket.rs +++ b/examples/icmp_socket.rs @@ -56,14 +56,12 @@ fn main() -> std::io::Result<()> { .icmp_code(icmp::echo_request::IcmpCodes::NoCode) .echo_fields(0x1234, 1) .payload(Bytes::from_static(b"hello")) - .calculate_checksum() .to_bytes(), (IpAddr::V6(src), IpAddr::V6(dst)) => Icmpv6PacketBuilder::new(src, dst) .icmpv6_type(nex_packet::icmpv6::Icmpv6Type::EchoRequest) .icmpv6_code(icmpv6::echo_request::Icmpv6Codes::NoCode) .echo_fields(0x1234, 1) .payload(Bytes::from_static(b"hello")) - .calculate_checksum() .to_bytes(), _ => unreachable!(), }; diff --git a/examples/tcp_ping.rs b/examples/tcp_ping.rs index 7e05b43..1bd4810 100644 --- a/examples/tcp_ping.rs +++ b/examples/tcp_ping.rs @@ -106,7 +106,7 @@ fn main() { } // Packet builder for TCP SYN - let tcp_packet = TcpPacketBuilder::new() + let tcp_packet = TcpPacketBuilder::new(src_ip, dst_ip) .source(53443) .destination(target_socket.port()) .flags(TcpFlags::SYN) @@ -118,7 +118,6 @@ fn main() { TcpOptionPacket::nop(), TcpOptionPacket::wscale(7), ]) - .calculate_checksum(&src_ip, &dst_ip) .build(); let ip_packet: Bytes; diff --git a/examples/udp_ping.rs b/examples/udp_ping.rs index 929acf2..580e785 100644 --- a/examples/udp_ping.rs +++ b/examples/udp_ping.rs @@ -65,10 +65,9 @@ fn main() { .expect("No global IPv6 address on interface"), }; - let udp_packet = UdpPacketBuilder::new() + let udp_packet = UdpPacketBuilder::new(src_ip, target_ip) .source(SRC_PORT) .destination(DST_PORT) - .calculate_checksum(&src_ip, &target_ip) .build(); let ip_packet: Bytes = match (src_ip, target_ip) { diff --git a/nex-packet/src/builder/icmp.rs b/nex-packet/src/builder/icmp.rs index 1d2caff..690f75b 100644 --- a/nex-packet/src/builder/icmp.rs +++ b/nex-packet/src/builder/icmp.rs @@ -1,7 +1,7 @@ use std::net::Ipv4Addr; use crate::{ - icmp::{self, checksum, IcmpCode, IcmpHeader, IcmpPacket, IcmpType}, + icmp::{self, IcmpCode, IcmpHeader, IcmpPacket, IcmpType}, packet::Packet, }; use bytes::{BufMut, Bytes, BytesMut}; @@ -62,15 +62,15 @@ impl IcmpPacketBuilder { self } + /// Calculate the checksum and set it in the header pub fn calculate_checksum(mut self) -> Self { - // Calculate the checksum and set it in the header - self.packet.header.checksum = checksum(&self.packet); + self.packet.header.checksum = icmp::checksum(&self.packet); self } /// Return an `IcmpPacket` with checksum computed pub fn build(mut self) -> IcmpPacket { - self.packet.header.checksum = checksum(&self.packet); + self.packet.header.checksum = icmp::checksum(&self.packet); self.packet } diff --git a/nex-packet/src/builder/icmpv6.rs b/nex-packet/src/builder/icmpv6.rs index 7d613d7..636517e 100644 --- a/nex-packet/src/builder/icmpv6.rs +++ b/nex-packet/src/builder/icmpv6.rs @@ -1,7 +1,7 @@ use std::net::Ipv6Addr; use crate::{ - icmpv6::{self, checksum, Icmpv6Code, Icmpv6Header, Icmpv6Packet, Icmpv6Type}, + icmpv6::{self, Icmpv6Code, Icmpv6Header, Icmpv6Packet, Icmpv6Type}, packet::Packet, }; use bytes::{BufMut, Bytes, BytesMut}; @@ -58,15 +58,15 @@ impl Icmpv6PacketBuilder { self } + /// Calculate the checksum and set it in the header pub fn calculate_checksum(mut self) -> Self { - // Calculate the checksum and set it in the header - self.packet.header.checksum = checksum(&self.packet, &self.source, &self.destination); + self.packet.header.checksum = icmpv6::checksum(&self.packet, &self.source, &self.destination); self } /// Return an `Icmpv6Packet` with checksum computed pub fn build(mut self) -> Icmpv6Packet { - self.packet.header.checksum = checksum(&self.packet, &self.source, &self.destination); + self.packet.header.checksum = icmpv6::checksum(&self.packet, &self.source, &self.destination); self.packet } diff --git a/nex-packet/src/builder/ipv4.rs b/nex-packet/src/builder/ipv4.rs index 4395d4f..526cfbc 100644 --- a/nex-packet/src/builder/ipv4.rs +++ b/nex-packet/src/builder/ipv4.rs @@ -100,7 +100,6 @@ impl Ipv4PacketBuilder { pub fn build(mut self) -> Ipv4Packet { let total_length = self.packet.header_len() + self.packet.payload_len(); self.packet.header.total_length = total_length as u16be; - self.packet.header.checksum = 0; self.packet.header.checksum = crate::ipv4::checksum(&self.packet); self.packet } diff --git a/nex-packet/src/builder/tcp.rs b/nex-packet/src/builder/tcp.rs index c18627e..302336f 100644 --- a/nex-packet/src/builder/tcp.rs +++ b/nex-packet/src/builder/tcp.rs @@ -7,13 +7,17 @@ use bytes::Bytes; /// Builder for constructing TCP packets #[derive(Debug, Clone)] pub struct TcpPacketBuilder { + src_ip: IpAddr, + dst_ip: IpAddr, packet: TcpPacket, } impl TcpPacketBuilder { /// Create a new builder - pub fn new() -> Self { + pub fn new(src_ip: IpAddr, dst_ip: IpAddr) -> Self { Self { + src_ip, + dst_ip, packet: TcpPacket { header: TcpHeader { source: 0, @@ -89,29 +93,33 @@ impl TcpPacketBuilder { self } - pub fn calculate_checksum(mut self, src_ip: &IpAddr, dst_ip: &IpAddr) -> Self { - // Calculate the checksum and set it in the header - self.packet.header.checksum = crate::tcp::checksum(&self.packet, src_ip, dst_ip); + /// Calculate the checksum and set it in the header + pub fn calculate_checksum(mut self) -> Self { + self.packet.header.checksum = crate::tcp::checksum(&self.packet, &self.src_ip, &self.dst_ip); self } - pub fn build(self) -> TcpPacket { + /// Build the packet with checksum computed + pub fn build(mut self) -> TcpPacket { + self.packet.header.checksum = crate::tcp::checksum(&self.packet, &self.src_ip, &self.dst_ip); self.packet } - + /// Serialize the packet into bytes with checksum computed pub fn to_bytes(self) -> Bytes { - self.packet.to_bytes() + self.build().to_bytes() } } #[cfg(test)] mod tests { + use std::net::Ipv4Addr; + use super::*; use crate::tcp::TcpFlags; use bytes::Bytes; #[test] fn tcp_builder_basic() { - let pkt = TcpPacketBuilder::new() + let pkt = TcpPacketBuilder::new(IpAddr::V4(Ipv4Addr::new(192, 168, 1, 100)), IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1))) .source(1234) .destination(80) .sequence(1) diff --git a/nex-packet/src/builder/udp.rs b/nex-packet/src/builder/udp.rs index f55d11d..d7d807e 100644 --- a/nex-packet/src/builder/udp.rs +++ b/nex-packet/src/builder/udp.rs @@ -7,13 +7,17 @@ use bytes::Bytes; /// Builder for constructing UDP packets #[derive(Debug, Clone)] pub struct UdpPacketBuilder { + src_ip: IpAddr, + dst_ip: IpAddr, packet: UdpPacket, } impl UdpPacketBuilder { /// Create a new builder - pub fn new() -> Self { + pub fn new(src_ip: IpAddr, dst_ip: IpAddr) -> Self { Self { + src_ip, + dst_ip, packet: UdpPacket { header: UdpHeader { source: 0, @@ -50,21 +54,24 @@ impl UdpPacketBuilder { self } - pub fn calculate_checksum(mut self, src_ip: &IpAddr, dst_ip: &IpAddr) -> Self { + /// Calculate the checksum and set it in the header + pub fn calculate_checksum(mut self) -> Self { // Calculate the checksum and set it in the header - self.packet.header.checksum = crate::udp::checksum(&self.packet, src_ip, dst_ip); + self.packet.header.checksum = crate::udp::checksum(&self.packet, &self.src_ip, &self.dst_ip); self } - /// Build the packet + /// Build the packet with checksum computed pub fn build(mut self) -> UdpPacket { // Automatically compute the length let total_len = UDP_HEADER_LEN + self.packet.payload.len(); self.packet.header.length = (total_len as u16).into(); + // Calculate the checksum + self.packet.header.checksum = crate::udp::checksum(&self.packet, &self.src_ip, &self.dst_ip); self.packet } - /// Serialize the packet into bytes + /// Serialize the packet into bytes with checksum computed pub fn to_bytes(self) -> Bytes { self.build().to_bytes() } @@ -79,12 +86,14 @@ impl UdpPacketBuilder { #[cfg(test)] mod tests { + use std::net::Ipv4Addr; + use super::*; use bytes::Bytes; #[test] fn udp_builder_sets_length() { - let pkt = UdpPacketBuilder::new() + let pkt = UdpPacketBuilder::new(IpAddr::V4(Ipv4Addr::new(192, 168, 1, 100)), IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1))) .source(1) .destination(2) .payload(Bytes::from_static(&[1, 2, 3]))