diff --git a/examples/async_datalink.rs b/examples/async_datalink.rs new file mode 100644 index 0000000..0b60f40 --- /dev/null +++ b/examples/async_datalink.rs @@ -0,0 +1,182 @@ +//! Basic demonstration of asynchronous datalink send/receive. + +use bytes::Bytes; +use futures::{future::poll_fn, stream::StreamExt}; +use nex::net::mac::MacAddr; +use nex::packet::builder::ethernet::EthernetPacketBuilder; +use nex::packet::builder::icmp::IcmpPacketBuilder; +use nex::packet::builder::icmpv6::Icmpv6PacketBuilder; +use nex::packet::builder::ipv4::Ipv4PacketBuilder; +use nex::packet::builder::ipv6::Ipv6PacketBuilder; +use nex::packet::ethernet::EtherType; +use nex::packet::frame::{Frame, ParseOption}; +use nex::packet::icmp::IcmpType; +use nex::packet::icmpv6::Icmpv6Type; +use nex_core::interface::Interface; +use nex_datalink::async_io::{async_channel, AsyncChannel}; +use nex_datalink::Config; +use nex_packet::ip::IpNextProtocol; +use nex_packet::ipv4::Ipv4Flags; +use nex_packet::packet::Packet; +use nex_packet::{icmp, icmpv6}; +use std::env; +use std::net::IpAddr; + +fn main() -> std::io::Result<()> { + let interface = match env::args().nth(2) { + Some(name) => nex::net::interface::get_interfaces() + .into_iter() + .find(|i| i.name == name) + .expect("Failed to get interface"), + None => Interface::default().expect("Failed to get default interface"), + }; + let use_tun = interface.is_tun(); + + let target_ip: IpAddr = env::args() + .nth(1) + .expect("Missing target IP") + .parse() + .expect("Failed to parse target IP"); + + let AsyncChannel::Ethernet(mut tx, mut rx) = async_channel(&interface, Config::default())? + else { + unreachable!(); + }; + let src_ip: IpAddr = match target_ip { + IpAddr::V4(_) => interface + .ipv4 + .get(0) + .map(|v| IpAddr::V4(v.addr())) + .expect("No IPv4 address"), + IpAddr::V6(_) => interface + .ipv6 + .iter() + .find(|v| nex::net::ip::is_global_ipv6(&v.addr())) + .map(|v| IpAddr::V6(v.addr())) + .expect("No global IPv6 address"), + }; + + let icmp_packet: Bytes = match (src_ip, target_ip) { + (IpAddr::V4(src), IpAddr::V4(dst)) => IcmpPacketBuilder::new(src, dst) + .icmp_type(IcmpType::EchoRequest) + .icmp_code(icmp::echo_request::IcmpCodes::NoCode) + .echo_fields(0x1234, 0x1) + .payload(Bytes::from_static(b"hello")) + .build() + .to_bytes(), + (IpAddr::V6(src), IpAddr::V6(dst)) => Icmpv6PacketBuilder::new(src, dst) + .icmpv6_type(Icmpv6Type::EchoRequest) + .icmpv6_code(icmpv6::echo_request::Icmpv6Codes::NoCode) + .echo_fields(0x1234, 0x1) + .payload(Bytes::from_static(b"hello")) + .build() + .to_bytes(), + _ => panic!("Source and destination IP version mismatch"), + }; + + let ip_packet = match (src_ip, target_ip) { + (IpAddr::V4(src), IpAddr::V4(dst)) => Ipv4PacketBuilder::new() + .source(src) + .destination(dst) + .protocol(IpNextProtocol::Icmp) + .flags(Ipv4Flags::DontFragment) + .payload(icmp_packet) + .build() + .to_bytes(), + (IpAddr::V6(src), IpAddr::V6(dst)) => Ipv6PacketBuilder::new() + .source(src) + .destination(dst) + .next_header(IpNextProtocol::Icmpv6) + .payload(icmp_packet) + .build() + .to_bytes(), + _ => unreachable!(), + }; + + let ethernet_packet = EthernetPacketBuilder::new() + .source(if use_tun { + MacAddr::zero() + } else { + interface.mac_addr.clone().unwrap() + }) + .destination(if use_tun { + MacAddr::zero() + } else { + interface.gateway.clone().unwrap().mac_addr + }) + .ethertype(match target_ip { + IpAddr::V4(_) => EtherType::Ipv4, + IpAddr::V6(_) => EtherType::Ipv6, + }) + .payload(ip_packet) + .build(); + + let packet = if use_tun { + ethernet_packet.ip_packet().unwrap() + } else { + ethernet_packet.to_bytes() + }; + + futures::executor::block_on(async { + // Send a packet using poll_fn. + match poll_fn(|cx| tx.poll_send(cx, &packet)).await { + Ok(_) => println!("Packet sent"), + Err(e) => println!("Failed to send packet: {}", e), + } + // Receive a packet via StreamExt::next. + println!("Waiting for ICMP Echo Reply..."); + loop { + match rx.next().await { + Some(Ok(packet)) => { + let mut parse_option = ParseOption::default(); + if interface.is_tun() { + parse_option.from_ip_packet = true; + parse_option.offset = if interface.is_loopback() { 14 } else { 0 }; + } + let frame = Frame::from_buf(&packet, parse_option).unwrap(); + + if let Some(ip_layer) = &frame.ip { + if let Some(icmp) = &ip_layer.icmp { + if icmp.icmp_type == IcmpType::EchoReply { + println!( + "Received ICMP Echo Reply from {}", + ip_layer.ipv4.as_ref().unwrap().source + ); + println!( + "---- Interface: {}, Total Length: {} bytes ----", + interface.name, + packet.len() + ); + println!("Frame: {:?}", frame); + break; + } + } + if let Some(icmpv6) = &ip_layer.icmpv6 { + if icmpv6.icmpv6_type == Icmpv6Type::EchoReply { + println!( + "Received ICMPv6 Echo Reply from {}", + ip_layer.ipv6.as_ref().unwrap().source + ); + println!( + "---- Interface: {}, Total Length: {} bytes ----", + interface.name, + packet.len() + ); + println!("Frame: {:?}", frame); + break; + } + } + } + } + Some(Err(e)) => eprintln!("Failed to receive: {}", e), + None => { + eprintln!("Stream ended unexpectedly"); + break; + } + } + } + Ok::<(), std::io::Error>(()) + })?; + + Ok(()) +} diff --git a/examples/async_dump.rs b/examples/async_dump.rs new file mode 100644 index 0000000..ff58247 --- /dev/null +++ b/examples/async_dump.rs @@ -0,0 +1,320 @@ +//! Basic packet capture using asynchronous receive channel. + +use bytes::Bytes; +use futures::stream::StreamExt; +use nex::net::interface::Interface; +use nex::net::mac::MacAddr; +use nex::packet::arp::ArpPacket; +use nex::packet::ethernet::{EtherType, EthernetPacket}; +use nex::packet::icmp::{IcmpPacket, IcmpType}; +use nex::packet::icmpv6::Icmpv6Packet; +use nex::packet::ip::IpNextProtocol; +use nex::packet::ipv4::Ipv4Packet; +use nex::packet::ipv6::Ipv6Packet; +use nex::packet::packet::Packet; +use nex::packet::tcp::TcpPacket; +use nex::packet::udp::UdpPacket; +use nex_datalink::async_io::{async_channel, AsyncChannel}; +use nex_datalink::Config; +use nex_packet::ethernet::EthernetHeader; +use nex_packet::{icmp, icmpv6}; +use std::net::IpAddr; + +fn main() -> std::io::Result<()> { + // Choose the default interface. + let interface = Interface::default().expect("no default interface"); + let AsyncChannel::Ethernet(_tx, mut rx) = async_channel(&interface, Config::default())? else { + unreachable!(); + }; + + futures::executor::block_on(async { + let mut capture_no: usize = 0; + // Receive packets asynchronously. + while let Some(Ok(packet)) = rx.next().await { + capture_no += 1; + println!( + "---- Interface: {}, No.: {}, Total Length: {} bytes ----", + interface.name, + capture_no, + packet.len() + ); + + if interface.is_tun() + || (cfg!(any(target_os = "macos", target_os = "ios")) && interface.is_loopback()) + { + let payload_offset: usize; + if interface.is_loopback() { + payload_offset = 14; + } else { + payload_offset = 0; + } + let payload = Bytes::copy_from_slice(&packet[payload_offset..]); + if packet.len() > payload_offset { + let version = Ipv4Packet::from_buf(&packet).unwrap().header.version; + let fake_eth = EthernetPacket { + header: EthernetHeader { + destination: MacAddr::zero(), + source: MacAddr::zero(), + ethertype: if version == 4 { + EtherType::Ipv4 + } else { + EtherType::Ipv6 + }, + }, + payload, + }; + handle_ethernet_frame(fake_eth); + } + } else { + handle_ethernet_frame(EthernetPacket::from_buf(&packet).unwrap()); + } + } + Ok::<(), std::io::Error>(()) + })?; + + Ok(()) +} + +fn handle_ethernet_frame(ethernet: EthernetPacket) { + let total_len = ethernet.total_len(); + let (header, payload) = ethernet.into_parts(); + match header.ethertype { + EtherType::Ipv4 => handle_ipv4_packet(payload), + EtherType::Ipv6 => handle_ipv6_packet(payload), + EtherType::Arp => handle_arp_packet(payload), + _ => { + println!( + "{} packet: {} > {}; ethertype: {:?} length: {}", + header.ethertype.name(), + header.source, + header.destination, + header.ethertype, + total_len, + ) + } + } +} + +fn handle_arp_packet(packet: Bytes) { + if let Some(arp) = ArpPacket::from_bytes(packet) { + println!( + "ARP packet: {}({}) > {}({}); operation: {:?}", + arp.header.sender_hw_addr, + arp.header.sender_proto_addr, + arp.header.target_hw_addr, + arp.header.target_proto_addr, + arp.header.operation + ); + } else { + println!("Malformed ARP Packet"); + } +} + +fn handle_ipv4_packet(packet: Bytes) { + if let Some(ipv4) = Ipv4Packet::from_bytes(packet) { + handle_transport_protocol( + IpAddr::V4(ipv4.header.source), + IpAddr::V4(ipv4.header.destination), + ipv4.header.next_level_protocol, + ipv4.payload, + ); + } else { + println!("Malformed IPv4 Packet"); + } +} + +fn handle_ipv6_packet(packet: Bytes) { + if let Some(ipv6) = Ipv6Packet::from_bytes(packet) { + handle_transport_protocol( + IpAddr::V6(ipv6.header.source), + IpAddr::V6(ipv6.header.destination), + ipv6.header.next_header, + ipv6.payload, + ); + } else { + println!("Malformed IPv6 Packet"); + } +} + +fn handle_transport_protocol( + source: IpAddr, + destination: IpAddr, + protocol: IpNextProtocol, + packet: Bytes, +) { + match protocol { + IpNextProtocol::Tcp => handle_tcp_packet(source, destination, packet), + IpNextProtocol::Udp => handle_udp_packet(source, destination, packet), + IpNextProtocol::Icmp => handle_icmp_packet(source, destination, packet), + IpNextProtocol::Icmpv6 => handle_icmpv6_packet(source, destination, packet), + _ => println!( + "Unknown {} packet: {} > {}; protocol: {:?} length: {}", + match source { + IpAddr::V4(..) => "IPv4", + _ => "IPv6", + }, + source, + destination, + protocol, + packet.len() + ), + } +} + +fn handle_tcp_packet(source: IpAddr, destination: IpAddr, packet: Bytes) { + if let Some(tcp) = TcpPacket::from_bytes(packet) { + println!( + "TCP Packet: {}:{} > {}:{}; length: {}", + source, + tcp.header.source, + destination, + tcp.header.destination, + tcp.total_len(), + ); + } else { + println!("Malformed TCP Packet"); + } +} + +fn handle_udp_packet(source: IpAddr, destination: IpAddr, packet: Bytes) { + let udp = UdpPacket::from_bytes(packet); + + if let Some(udp) = udp { + println!( + "UDP Packet: {}:{} > {}:{}; length: {}", + source, + udp.header.source, + destination, + udp.header.destination, + udp.total_len(), + ); + } else { + println!("Malformed UDP Packet"); + } +} + +fn handle_icmp_packet(source: IpAddr, destination: IpAddr, packet: Bytes) { + let icmp_packet = IcmpPacket::from_bytes(packet); + if let Some(icmp_packet) = icmp_packet { + let total_len = icmp_packet.total_len(); + match icmp_packet.header.icmp_type { + IcmpType::EchoRequest => { + let echo_request_packet = + icmp::echo_request::EchoRequestPacket::try_from(icmp_packet).unwrap(); + println!( + "ICMP echo request {} -> {} (seq={:?}, id={:?}), length: {}", + source, + destination, + echo_request_packet.sequence_number, + echo_request_packet.identifier, + total_len + ); + } + IcmpType::EchoReply => { + let echo_reply_packet = + icmp::echo_reply::EchoReplyPacket::try_from(icmp_packet).unwrap(); + println!( + "ICMP echo reply {} -> {} (seq={:?}, id={:?}), length: {}", + source, + destination, + echo_reply_packet.sequence_number, + echo_reply_packet.identifier, + total_len, + ); + } + IcmpType::DestinationUnreachable => { + let unreachable_packet = + icmp::destination_unreachable::DestinationUnreachablePacket::try_from( + icmp_packet, + ) + .unwrap(); + println!( + "ICMP destination unreachable {} -> {} (code={:?}), next_hop_mtu={}, length: {}", + source, + destination, + unreachable_packet.header.icmp_code, + unreachable_packet.next_hop_mtu, + total_len + ); + } + IcmpType::TimeExceeded => { + let time_exceeded_packet = + icmp::time_exceeded::TimeExceededPacket::try_from(icmp_packet).unwrap(); + println!( + "ICMP time exceeded {} -> {} (code={:?}), length: {}", + source, destination, time_exceeded_packet.header.icmp_code, total_len + ); + } + _ => { + println!( + "ICMP packet {} -> {} (type={:?}), length: {}", + source, destination, icmp_packet.header.icmp_type, total_len + ) + } + } + } else { + println!("Malformed ICMP Packet"); + } +} + +fn handle_icmpv6_packet(source: IpAddr, destination: IpAddr, packet: Bytes) { + let icmpv6_packet = Icmpv6Packet::from_bytes(packet); + if let Some(icmpv6_packet) = icmpv6_packet { + match icmpv6_packet.header.icmpv6_type { + nex::packet::icmpv6::Icmpv6Type::EchoRequest => { + let echo_request_packet = + icmpv6::echo_request::EchoRequestPacket::try_from(icmpv6_packet).unwrap(); + println!( + "ICMPv6 echo request {} -> {} (type={:?}), length: {}", + source, + destination, + echo_request_packet.header.icmpv6_type, + echo_request_packet.total_len(), + ); + } + nex::packet::icmpv6::Icmpv6Type::EchoReply => { + let echo_reply_packet = + icmpv6::echo_reply::EchoReplyPacket::try_from(icmpv6_packet).unwrap(); + println!( + "ICMPv6 echo reply {} -> {} (type={:?}), length: {}", + source, + destination, + echo_reply_packet.header.icmpv6_type, + echo_reply_packet.total_len(), + ); + } + nex::packet::icmpv6::Icmpv6Type::NeighborSolicitation => { + let ns_packet = + icmpv6::ndp::NeighborSolicitPacket::try_from(icmpv6_packet).unwrap(); + println!( + "ICMPv6 neighbor solicitation {} -> {} (type={:?}), length: {}", + source, + destination, + ns_packet.header.icmpv6_type, + ns_packet.total_len(), + ); + } + nex::packet::icmpv6::Icmpv6Type::NeighborAdvertisement => { + let na_packet = icmpv6::ndp::NeighborAdvertPacket::try_from(icmpv6_packet).unwrap(); + println!( + "ICMPv6 neighbor advertisement {} -> {} (type={:?}), length: {}", + source, + destination, + na_packet.header.icmpv6_type, + na_packet.total_len(), + ); + } + _ => { + println!( + "ICMPv6 packet {} -> {} (type={:?}), length: {}", + source, + destination, + icmpv6_packet.header.icmpv6_type, + icmpv6_packet.total_len(), + ) + } + } + } else { + println!("Malformed ICMPv6 Packet"); + } +} diff --git a/nex-datalink/Cargo.toml b/nex-datalink/Cargo.toml index 34d95e2..1603140 100644 --- a/nex-datalink/Cargo.toml +++ b/nex-datalink/Cargo.toml @@ -18,6 +18,7 @@ serde = { workspace = true, features = ["derive"], optional = true } pcap = { version = "2.0", optional = true } nex-core = { workspace = true } nex-sys = { workspace = true } +futures-core = "0.3" [target.'cfg(windows)'.dependencies.windows-sys] version = "0.59.0" @@ -32,3 +33,6 @@ features = [ [features] serde = ["dep:serde", "netdev/serde"] pcap = ["dep:pcap"] + +[dev-dependencies] +futures = "0.3" diff --git a/nex-datalink/src/async_io/bpf.rs b/nex-datalink/src/async_io/bpf.rs new file mode 100644 index 0000000..bf351ad --- /dev/null +++ b/nex-datalink/src/async_io/bpf.rs @@ -0,0 +1,280 @@ +//! Asynchronous raw datalink support for BSD BPF devices. + +use crate::async_io::{AsyncChannel, AsyncRawSender}; +use crate::bindings::bpf; +use crate::Config; +use futures_core::stream::Stream; +use nex_core::interface::Interface; +use nex_sys; +use std::collections::VecDeque; +use std::ffi::CString; +use std::io; +use std::mem; +use std::os::fd::RawFd; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; + +const ETHERNET_HEADER_SIZE: usize = 14; +const ETHERNET_NULL_HEADER_SIZE: usize = 4; + +#[derive(Debug)] +struct Inner { + fd: RawFd, + loopback: bool, + buffer_offset: usize, +} + +impl Drop for Inner { + fn drop(&mut self) { + unsafe { nex_sys::close(self.fd) }; + } +} + +/// Sender half of an asynchronous BPF socket. +#[derive(Clone, Debug)] +pub struct AsyncBpfSocketSender { + inner: Arc, +} + +impl AsyncRawSender for AsyncBpfSocketSender { + fn poll_send(&mut self, cx: &mut Context<'_>, packet: &[u8]) -> Poll> { + let offset = if self.inner.loopback { + ETHERNET_HEADER_SIZE + } else { + 0 + }; + let ret = unsafe { + libc::write( + self.inner.fd, + packet[offset..].as_ptr() as *const libc::c_void, + (packet.len() - offset) as libc::size_t, + ) + }; + if ret >= 0 { + return Poll::Ready(Ok(())); + } + let err = io::Error::last_os_error(); + if err.kind() == io::ErrorKind::WouldBlock { + let mut pfd = libc::pollfd { + fd: self.inner.fd, + events: libc::POLLOUT, + revents: 0, + }; + unsafe { libc::poll(&mut pfd, 1, 0) }; + cx.waker().wake_by_ref(); + Poll::Pending + } else { + Poll::Ready(Err(err)) + } + } +} + +/// Receiver half of an asynchronous BPF socket. +#[derive(Debug)] +pub struct AsyncBpfSocketReceiver { + inner: Arc, + read_buffer: Vec, + packets: VecDeque<(usize, usize)>, +} + +impl Stream for AsyncBpfSocketReceiver { + type Item = io::Result>; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let me = self.get_mut(); + let header_size = if me.inner.loopback { + ETHERNET_NULL_HEADER_SIZE + } else { + 0 + }; + if me.packets.is_empty() { + let buffer = &mut me.read_buffer[me.inner.buffer_offset..]; + let ret = unsafe { + libc::read( + me.inner.fd, + buffer.as_mut_ptr() as *mut libc::c_void, + buffer.len() as libc::size_t, + ) + }; + if ret >= 0 { + let buflen = ret as usize; + let mut ptr = buffer.as_mut_ptr(); + let end = unsafe { buffer.as_ptr().add(buflen) }; + while (ptr as *const u8) < end { + unsafe { + let packet: *const bpf::bpf_hdr = mem::transmute(ptr); + let start = + ptr as isize + (*packet).bh_hdrlen as isize - buffer.as_ptr() as isize; + me.packets.push_back(( + start as usize + header_size, + (*packet).bh_caplen as usize - header_size, + )); + let offset = (*packet).bh_hdrlen as isize + (*packet).bh_caplen as isize; + ptr = ptr.offset(bpf::BPF_WORDALIGN(offset)); + } + } + } else { + let err = io::Error::last_os_error(); + if err.kind() == io::ErrorKind::WouldBlock { + let mut pfd = libc::pollfd { + fd: me.inner.fd, + events: libc::POLLIN, + revents: 0, + }; + unsafe { libc::poll(&mut pfd, 1, 0) }; + cx.waker().wake_by_ref(); + return Poll::Pending; + } else { + return Poll::Ready(Some(Err(err))); + } + } + } + if let Some((mut start, mut len)) = me.packets.pop_front() { + len += me.inner.buffer_offset; + if me.inner.loopback { + let padding = ETHERNET_HEADER_SIZE - me.inner.buffer_offset; + start -= padding; + } + for i in (&mut me.read_buffer[start..start + me.inner.buffer_offset]).iter_mut() { + *i = 0; + } + let pkt = me.read_buffer[start..start + len].to_vec(); + Poll::Ready(Some(Ok(pkt))) + } else { + cx.waker().wake_by_ref(); + Poll::Pending + } + } +} + +/// Create a new asynchronous BPF socket channel. +pub fn channel(network_interface: &Interface, config: Config) -> io::Result { + #[cfg(any(target_os = "macos", target_os = "ios", target_os = "openbsd"))] + fn get_fd(attempts: usize) -> RawFd { + for i in 0..attempts { + let file_name = format!("/dev/bpf{}", i); + let c_file_name = CString::new(file_name).unwrap(); + let fd = unsafe { libc::open(c_file_name.as_ptr(), libc::O_RDWR, 0) }; + if fd != -1 { + return fd; + } + } + -1 + } + #[cfg(any( + target_os = "freebsd", + target_os = "netbsd", + target_os = "illumos", + target_os = "solaris", + ))] + fn get_fd(_attempts: usize) -> RawFd { + let c_file_name = CString::new("/dev/bpf").unwrap(); + unsafe { libc::open(c_file_name.as_ptr(), libc::O_RDWR, 0) } + } + + let fd = get_fd(config.bpf_fd_attempts); + if fd == -1 { + return Err(io::Error::last_os_error()); + } + + let mut iface: bpf::ifreq = unsafe { mem::zeroed() }; + for (i, c) in network_interface.name.bytes().enumerate() { + iface.ifr_name[i] = c as libc::c_char; + } + + let buflen = config.read_buffer_size as libc::c_uint; + if unsafe { bpf::ioctl(fd, bpf::BIOCSBLEN, &buflen) } == -1 { + let err = io::Error::last_os_error(); + unsafe { + nex_sys::close(fd); + } + return Err(err); + } + + if unsafe { bpf::ioctl(fd, bpf::BIOCSETIF, &iface) } == -1 { + let err = io::Error::last_os_error(); + unsafe { + nex_sys::close(fd); + } + return Err(err); + } + + if unsafe { bpf::ioctl(fd, bpf::BIOCIMMEDIATE, &1) } == -1 { + let err = io::Error::last_os_error(); + unsafe { + nex_sys::close(fd); + } + return Err(err); + } + + let mut dlt: libc::c_uint = 0; + if unsafe { bpf::ioctl(fd, bpf::BIOCGDLT, &mut dlt) } == -1 { + let err = io::Error::last_os_error(); + unsafe { + nex_sys::close(fd); + } + return Err(err); + } + + let mut loopback = false; + let mut buffer_offset = 0usize; + let mut read_buffer_size = config.read_buffer_size; + if dlt == bpf::DLT_NULL { + loopback = true; + let align = mem::align_of::(); + buffer_offset = (ETHERNET_HEADER_SIZE - ETHERNET_NULL_HEADER_SIZE).next_multiple_of(align); + read_buffer_size += buffer_offset; + } else { + if unsafe { bpf::ioctl(fd, bpf::BIOCSHDRCMPLT, &1) } == -1 { + let err = io::Error::last_os_error(); + unsafe { + nex_sys::close(fd); + } + return Err(err); + } + } + + if unsafe { libc::fcntl(fd, libc::F_SETFL, libc::O_NONBLOCK) } == -1 { + let err = io::Error::last_os_error(); + unsafe { + nex_sys::close(fd); + } + return Err(err); + } + + let read_buffer = vec![0u8; read_buffer_size]; + + let inner = Arc::new(Inner { + fd, + loopback, + buffer_offset, + }); + let tx = AsyncBpfSocketSender { + inner: inner.clone(), + }; + let rx = AsyncBpfSocketReceiver { + inner, + read_buffer, + packets: VecDeque::with_capacity(read_buffer_size / 64), + }; + Ok(AsyncChannel::Ethernet(Box::new(tx), Box::new(rx))) +} + +#[cfg(test)] +mod tests { + use super::*; + use futures::future::poll_fn; + + #[test] + #[ignore] + fn async_raw_send() { + let iface = Interface::default().expect("no default interface"); + let AsyncChannel::Ethernet(mut tx, _rx) = + channel(&iface, Config::default()).expect("socket"); + let packet = [0u8; 42]; + futures::executor::block_on(async { + let _ = poll_fn(|cx| tx.poll_send(cx, &packet)).await; + }); + } +} diff --git a/nex-datalink/src/async_io/linux.rs b/nex-datalink/src/async_io/linux.rs new file mode 100644 index 0000000..e7f1ed1 --- /dev/null +++ b/nex-datalink/src/async_io/linux.rs @@ -0,0 +1,208 @@ +//! Asynchronous raw socket support for Linux using epoll. + +use crate::async_io::{AsyncChannel, AsyncRawSender}; +use crate::{ChannelType, Config}; +use futures_core::stream::Stream; +use nex_core::interface::Interface; +use nex_core::mac::MacAddr; +use nex_sys; +use std::io; +use std::mem; +use std::os::fd::RawFd; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; + +fn network_addr_to_sockaddr( + ni: &Interface, + storage: *mut libc::sockaddr_storage, + proto: libc::c_int, +) -> usize { + unsafe { + let sll: *mut libc::sockaddr_ll = mem::transmute(storage); + (*sll).sll_family = libc::AF_PACKET as libc::sa_family_t; + if let Some(MacAddr(a, b, c, d, e, f)) = ni.mac_addr { + (*sll).sll_addr = [a, b, c, d, e, f, 0, 0]; + } + (*sll).sll_protocol = (proto as u16).to_be(); + (*sll).sll_halen = 6; + (*sll).sll_ifindex = ni.index as i32; + mem::size_of::() + } +} + +#[derive(Debug)] +struct Inner { + fd: RawFd, + send_addr: libc::sockaddr_ll, + epfd: RawFd, +} + +impl Drop for Inner { + fn drop(&mut self) { + unsafe { + nex_sys::close(self.fd); + nex_sys::close(self.epfd); + } + } +} + +/// Sender half of an asynchronous raw socket. +#[derive(Clone, Debug)] +pub struct AsyncRawSocketSender { + inner: Arc, +} + +impl AsyncRawSender for AsyncRawSocketSender { + fn poll_send(&mut self, cx: &mut Context<'_>, packet: &[u8]) -> Poll> { + let ret = unsafe { + libc::sendto( + self.inner.fd, + packet.as_ptr() as *const libc::c_void, + packet.len(), + 0, + &self.inner.send_addr as *const libc::sockaddr_ll as *const libc::sockaddr, + mem::size_of::() as libc::socklen_t, + ) + }; + if ret >= 0 { + return Poll::Ready(Ok(())); + } + let err = io::Error::last_os_error(); + if err.kind() == io::ErrorKind::WouldBlock { + unsafe { + let mut events = [mem::zeroed::()]; + libc::epoll_wait(self.inner.epfd, events.as_mut_ptr(), 1, 0); + } + cx.waker().wake_by_ref(); + Poll::Pending + } else { + Poll::Ready(Err(err)) + } + } +} + +/// Receiver half of an asynchronous raw socket. +#[derive(Debug)] +pub struct AsyncRawSocketReceiver { + inner: Arc, + read_buffer: Vec, +} + +impl Stream for AsyncRawSocketReceiver { + type Item = io::Result>; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let me = self.get_mut(); + let ret = unsafe { + libc::recv( + me.inner.fd, + me.read_buffer.as_mut_ptr() as *mut libc::c_void, + me.read_buffer.len(), + libc::MSG_DONTWAIT, + ) + }; + if ret >= 0 { + let n = ret as usize; + let pkt = me.read_buffer[..n].to_vec(); + return Poll::Ready(Some(Ok(pkt))); + } + let err = io::Error::last_os_error(); + if err.kind() == io::ErrorKind::WouldBlock { + unsafe { + let mut events = [mem::zeroed::()]; + libc::epoll_wait(me.inner.epfd, events.as_mut_ptr(), 1, 0); + } + cx.waker().wake_by_ref(); + Poll::Pending + } else { + Poll::Ready(Some(Err(err))) + } + } +} + +/// Create a new asynchronous raw socket channel. +pub fn channel(network_interface: &Interface, config: Config) -> io::Result { + let eth_p_all = 0x0003; + let (typ, proto) = match config.channel_type { + ChannelType::Layer2 => (libc::SOCK_RAW, eth_p_all), + ChannelType::Layer3(proto) => (libc::SOCK_DGRAM, proto as i32), + }; + let fd = unsafe { + libc::socket( + libc::AF_PACKET, + typ | libc::SOCK_NONBLOCK, + (proto as u16).to_be() as i32, + ) + }; + if fd == -1 { + return Err(io::Error::last_os_error()); + } + + let mut addr: libc::sockaddr_storage = unsafe { mem::zeroed() }; + let len = network_addr_to_sockaddr(network_interface, &mut addr, proto); + let send_addr = unsafe { *(&addr as *const _ as *const libc::sockaddr_ll) }; + let bind_addr = (&addr as *const libc::sockaddr_storage) as *const libc::sockaddr; + + if unsafe { libc::bind(fd, bind_addr, len as libc::socklen_t) } == -1 { + let err = io::Error::last_os_error(); + unsafe { + nex_sys::close(fd); + } + return Err(err); + } + + let epfd = unsafe { libc::epoll_create1(0) }; + if epfd == -1 { + let err = io::Error::last_os_error(); + unsafe { + nex_sys::close(fd); + } + return Err(err); + } + + let mut event = libc::epoll_event { + events: (libc::EPOLLIN | libc::EPOLLOUT) as u32, + u64: fd as u64, + }; + if unsafe { libc::epoll_ctl(epfd, libc::EPOLL_CTL_ADD, fd, &mut event) } == -1 { + let err = io::Error::last_os_error(); + unsafe { + nex_sys::close(epfd); + nex_sys::close(fd); + } + return Err(err); + } + + let inner = Arc::new(Inner { + fd, + send_addr, + epfd, + }); + let tx = AsyncRawSocketSender { + inner: inner.clone(), + }; + let rx = AsyncRawSocketReceiver { + inner, + read_buffer: vec![0u8; config.read_buffer_size], + }; + Ok(AsyncChannel::Ethernet(Box::new(tx), Box::new(rx))) +} + +#[cfg(test)] +mod tests { + use super::*; + use futures::future::poll_fn; + + #[test] + #[ignore] + fn async_raw_send() { + let iface = Interface::default().expect("no default interface"); + let AsyncChannel::Ethernet(mut tx, _rx) = + channel(&iface, Config::default()).expect("socket"); + let packet = [0u8; 42]; + futures::executor::block_on(async { + let _ = poll_fn(|cx| tx.poll_send(cx, &packet)).await; + }); + } +} diff --git a/nex-datalink/src/async_io/mod.rs b/nex-datalink/src/async_io/mod.rs new file mode 100644 index 0000000..3d95301 --- /dev/null +++ b/nex-datalink/src/async_io/mod.rs @@ -0,0 +1,75 @@ +#[cfg(any(target_os = "linux", target_os = "android"))] +pub mod linux; + +#[cfg(any( + target_os = "freebsd", + target_os = "netbsd", + target_os = "illumos", + target_os = "solaris", + target_os = "macos", + target_os = "ios", +))] +pub mod bpf; + +#[cfg(windows)] +pub mod wpcap; + +use std::io; +use std::task::{Context, Poll}; + +use futures_core::stream::Stream; + +use crate::Config; + +/// Trait for asynchronously sending raw packets. +pub trait AsyncRawSender: Send { + /// Attempt to send a packet asynchronously. + /// + /// The method returns `Poll::Ready` once the packet has been + /// transmitted or an error has occurred. If the socket is not + /// currently writable, it will return `Poll::Pending` and arrange for + /// the current task to be woken once progress can be made. + fn poll_send(&mut self, cx: &mut Context<'_>, packet: &[u8]) -> Poll>; +} + +/// Trait for asynchronously receiving raw packets. +/// +/// This is implemented for any type implementing [`Stream`] with +/// `Item = io::Result>`. +pub trait AsyncRawReceiver: Stream>> + Send + Unpin {} + +impl AsyncRawReceiver for T where T: Stream>> + Send + Unpin {} + +/// An asynchronous channel for sending and receiving at the data link layer. +#[non_exhaustive] +pub enum AsyncChannel { + /// An asynchronous datalink channel which sends and receives Ethernet packets. + Ethernet(Box, Box), +} + +/// Creates a new asynchronous datalink channel for sending and receiving raw packets. +#[inline] +pub fn async_channel( + network_interface: &nex_core::interface::Interface, + configuration: Config, +) -> io::Result { + #[cfg(all(any(target_os = "linux", target_os = "android")))] + { + linux::channel(network_interface, configuration) + } + #[cfg(all(any( + target_os = "freebsd", + target_os = "netbsd", + target_os = "illumos", + target_os = "solaris", + target_os = "macos", + target_os = "ios", + )))] + { + bpf::channel(network_interface, configuration) + } + #[cfg(windows)] + { + wpcap::channel(network_interface, configuration) + } +} diff --git a/nex-datalink/src/async_io/wpcap.rs b/nex-datalink/src/async_io/wpcap.rs new file mode 100644 index 0000000..37ae0bd --- /dev/null +++ b/nex-datalink/src/async_io/wpcap.rs @@ -0,0 +1,234 @@ +//! Asynchronous raw datalink support for Windows using the Npcap / WinPcap library. + +use crate::async_io::{AsyncChannel, AsyncRawSender}; +use crate::bindings::{bpf, windows}; +use crate::Config; +use futures_core::stream::Stream; +use nex_core::interface::Interface; +use std::cmp; +use std::collections::VecDeque; +use std::ffi::CString; +use std::io; +use std::mem; +use std::pin::Pin; +use std::slice; +use std::sync::{Arc, Mutex}; +use std::task::{Context, Poll, Waker}; +use std::thread; + +#[derive(Debug)] +struct WinPcapAdapter { + adapter: windows::LPADAPTER, +} + +impl Drop for WinPcapAdapter { + fn drop(&mut self) { + unsafe { windows::PacketCloseAdapter(self.adapter) }; + } +} + +unsafe impl Send for WinPcapAdapter {} +unsafe impl Sync for WinPcapAdapter {} + +#[derive(Clone, Debug)] +struct WinPcapPacket { + packet: windows::LPPACKET, +} + +impl Drop for WinPcapPacket { + fn drop(&mut self) { + unsafe { windows::PacketFreePacket(self.packet) }; + } +} + +unsafe impl Send for WinPcapPacket {} + +#[derive(Debug)] +struct Inner { + adapter: Arc, + packets: Arc>>>, + waker: Arc>>, +} + +unsafe impl Send for Inner {} +unsafe impl Sync for Inner {} + +/// Sender half of a WinPcap socket. +#[derive(Clone, Debug)] +pub struct AsyncWpcapSocketSender { + inner: Arc, + write_buffer: Vec, + packet: WinPcapPacket, +} + +impl AsyncRawSender for AsyncWpcapSocketSender { + fn poll_send(&mut self, _cx: &mut Context<'_>, packet: &[u8]) -> Poll> { + let len = cmp::min(packet.len(), self.write_buffer.len()); + self.write_buffer[..len].copy_from_slice(&packet[..len]); + unsafe { + windows::PacketInitPacket( + self.packet.packet, + self.write_buffer.as_mut_ptr() as windows::PVOID, + len as windows::UINT, + ); + } + let ret = + unsafe { windows::PacketSendPacket(self.inner.adapter.adapter, self.packet.packet, 1) }; + if ret == 0 { + Poll::Ready(Err(io::Error::last_os_error())) + } else { + Poll::Ready(Ok(())) + } + } +} + +/// Receiver half of a WinPcap socket. +#[derive(Debug)] +pub struct AsyncWpcapSocketReceiver { + inner: Arc, +} + +impl Stream for AsyncWpcapSocketReceiver { + type Item = io::Result>; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut queue = self.inner.packets.lock().unwrap(); + if let Some(pkt) = queue.pop_front() { + Poll::Ready(Some(Ok(pkt))) + } else { + *self.inner.waker.lock().unwrap() = Some(cx.waker().clone()); + Poll::Pending + } + } +} + +/// Create a new asynchronous WinPcap channel. +pub fn channel(network_interface: &Interface, config: Config) -> io::Result { + let mut write_buffer = vec![0u8; config.write_buffer_size]; + + let adapter = unsafe { + let npf_if_name: String = windows::to_npf_name(&network_interface.name); + let net_if_str = CString::new(npf_if_name.as_bytes()).unwrap(); + windows::PacketOpenAdapter(net_if_str.as_ptr() as *mut libc::c_char) + }; + if adapter.is_null() { + return Err(io::Error::last_os_error()); + } + + let ret = unsafe { windows::PacketSetHwFilter(adapter, windows::NDIS_PACKET_TYPE_PROMISCUOUS) }; + if ret == 0 { + unsafe { windows::PacketCloseAdapter(adapter) }; + return Err(io::Error::last_os_error()); + } + + let ret = unsafe { windows::PacketSetBuff(adapter, config.read_buffer_size as libc::c_int) }; + if ret == 0 { + unsafe { windows::PacketCloseAdapter(adapter) }; + return Err(io::Error::last_os_error()); + } + + let ret = unsafe { windows::PacketSetMinToCopy(adapter, 1) }; + if ret == 0 { + unsafe { windows::PacketCloseAdapter(adapter) }; + return Err(io::Error::last_os_error()); + } + + let write_packet = unsafe { windows::PacketAllocatePacket() }; + if write_packet.is_null() { + unsafe { windows::PacketCloseAdapter(adapter) }; + return Err(io::Error::last_os_error()); + } + unsafe { + windows::PacketInitPacket( + write_packet, + write_buffer.as_mut_ptr() as windows::PVOID, + config.write_buffer_size as windows::UINT, + ); + } + + let adapter = Arc::new(WinPcapAdapter { adapter }); + let packets = Arc::new(Mutex::new(VecDeque::new())); + let waker: Arc>> = Arc::new(Mutex::new(None)); + + { + let adapter = adapter.clone(); + let packets = packets.clone(); + let waker = waker.clone(); + let read_buffer_size = config.read_buffer_size; + thread::spawn(move || { + let mut read_buffer = vec![0u8; read_buffer_size]; + let read_packet = unsafe { windows::PacketAllocatePacket() }; + if read_packet.is_null() { + return; + } + unsafe { + windows::PacketInitPacket( + read_packet, + read_buffer.as_mut_ptr() as windows::PVOID, + read_buffer_size as windows::UINT, + ); + } + loop { + let ret = unsafe { windows::PacketReceivePacket(adapter.adapter, read_packet, 1) }; + if ret == 0 { + continue; + } + let buflen = unsafe { (*read_packet).ulBytesReceived as isize }; + let mut ptr = unsafe { (*read_packet).Buffer as *mut libc::c_char }; + let end = unsafe { ((*read_packet).Buffer as *mut libc::c_char).offset(buflen) }; + while ptr < end { + unsafe { + let hdr: *const bpf::bpf_hdr = mem::transmute(ptr); + let start = ptr as isize + (*hdr).bh_hdrlen as isize + - (*read_packet).Buffer as isize; + let caplen = (*hdr).bh_caplen as usize; + let data_ptr = ((*read_packet).Buffer as isize + start) as *const u8; + let data = slice::from_raw_parts(data_ptr, caplen).to_vec(); + { + let mut queue = packets.lock().unwrap(); + queue.push_back(data); + } + let offset = (*hdr).bh_hdrlen as isize + (*hdr).bh_caplen as isize; + ptr = ptr.offset(bpf::BPF_WORDALIGN(offset)); + } + } + if let Some(w) = waker.lock().unwrap().take() { + w.wake(); + } + } + }); + } + + let inner = Arc::new(Inner { + adapter, + packets, + waker, + }); + let tx = AsyncWpcapSocketSender { + inner: inner.clone(), + write_buffer, + packet: WinPcapPacket { + packet: write_packet, + }, + }; + let rx = AsyncWpcapSocketReceiver { inner }; + Ok(AsyncChannel::Ethernet(Box::new(tx), Box::new(rx))) +} + +#[cfg(test)] +mod tests { + use super::*; + use futures::future::poll_fn; + + #[test] + #[ignore] + fn async_raw_send() { + let iface = Interface::default().expect("no default interface"); + let AsyncChannel::Ethernet(mut tx, _rx) = + channel(&iface, Config::default()).expect("socket"); + let packet = [0u8; 42]; + futures::executor::block_on(async { + let _ = poll_fn(|cx| tx.poll_send(cx, &packet)).await; + }); + } +} diff --git a/nex-datalink/src/lib.rs b/nex-datalink/src/lib.rs index 90f714a..386f769 100644 --- a/nex-datalink/src/lib.rs +++ b/nex-datalink/src/lib.rs @@ -8,6 +8,8 @@ use std::time::Duration; mod bindings; +pub mod async_io; + #[cfg(windows)] #[path = "wpcap.rs"] mod backend; diff --git a/nex-packet/src/builder/icmpv6.rs b/nex-packet/src/builder/icmpv6.rs index 636517e..5d7f6bc 100644 --- a/nex-packet/src/builder/icmpv6.rs +++ b/nex-packet/src/builder/icmpv6.rs @@ -60,13 +60,15 @@ impl Icmpv6PacketBuilder { /// Calculate the checksum and set it in the header pub fn calculate_checksum(mut self) -> Self { - self.packet.header.checksum = icmpv6::checksum(&self.packet, &self.source, &self.destination); + self.packet.header.checksum = + icmpv6::checksum(&self.packet, &self.source, &self.destination); self } /// Return an `Icmpv6Packet` with checksum computed pub fn build(mut self) -> Icmpv6Packet { - self.packet.header.checksum = icmpv6::checksum(&self.packet, &self.source, &self.destination); + self.packet.header.checksum = + icmpv6::checksum(&self.packet, &self.source, &self.destination); self.packet } diff --git a/nex-packet/src/builder/tcp.rs b/nex-packet/src/builder/tcp.rs index 302336f..c199b8c 100644 --- a/nex-packet/src/builder/tcp.rs +++ b/nex-packet/src/builder/tcp.rs @@ -95,12 +95,14 @@ impl TcpPacketBuilder { /// Calculate the checksum and set it in the header pub fn calculate_checksum(mut self) -> Self { - self.packet.header.checksum = crate::tcp::checksum(&self.packet, &self.src_ip, &self.dst_ip); + self.packet.header.checksum = + crate::tcp::checksum(&self.packet, &self.src_ip, &self.dst_ip); self } /// Build the packet with checksum computed pub fn build(mut self) -> TcpPacket { - self.packet.header.checksum = crate::tcp::checksum(&self.packet, &self.src_ip, &self.dst_ip); + self.packet.header.checksum = + crate::tcp::checksum(&self.packet, &self.src_ip, &self.dst_ip); self.packet } /// Serialize the packet into bytes with checksum computed @@ -119,16 +121,19 @@ mod tests { #[test] fn tcp_builder_basic() { - let pkt = TcpPacketBuilder::new(IpAddr::V4(Ipv4Addr::new(192, 168, 1, 100)), IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1))) - .source(1234) - .destination(80) - .sequence(1) - .acknowledgement(2) - .flags(TcpFlags::SYN) - .window(1024) - .urgent_ptr(0) - .payload(Bytes::from_static(b"abc")) - .build(); + let pkt = TcpPacketBuilder::new( + IpAddr::V4(Ipv4Addr::new(192, 168, 1, 100)), + IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1)), + ) + .source(1234) + .destination(80) + .sequence(1) + .acknowledgement(2) + .flags(TcpFlags::SYN) + .window(1024) + .urgent_ptr(0) + .payload(Bytes::from_static(b"abc")) + .build(); assert_eq!(pkt.header.source, 1234); assert_eq!(pkt.header.destination, 80); assert_eq!(pkt.header.sequence, 1); diff --git a/nex-packet/src/builder/udp.rs b/nex-packet/src/builder/udp.rs index d7d807e..9b219f4 100644 --- a/nex-packet/src/builder/udp.rs +++ b/nex-packet/src/builder/udp.rs @@ -57,7 +57,8 @@ impl UdpPacketBuilder { /// Calculate the checksum and set it in the header pub fn calculate_checksum(mut self) -> Self { // Calculate the checksum and set it in the header - self.packet.header.checksum = crate::udp::checksum(&self.packet, &self.src_ip, &self.dst_ip); + self.packet.header.checksum = + crate::udp::checksum(&self.packet, &self.src_ip, &self.dst_ip); self } @@ -67,7 +68,8 @@ impl UdpPacketBuilder { let total_len = UDP_HEADER_LEN + self.packet.payload.len(); self.packet.header.length = (total_len as u16).into(); // Calculate the checksum - self.packet.header.checksum = crate::udp::checksum(&self.packet, &self.src_ip, &self.dst_ip); + self.packet.header.checksum = + crate::udp::checksum(&self.packet, &self.src_ip, &self.dst_ip); self.packet } @@ -93,11 +95,14 @@ mod tests { #[test] fn udp_builder_sets_length() { - let pkt = UdpPacketBuilder::new(IpAddr::V4(Ipv4Addr::new(192, 168, 1, 100)), IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1))) - .source(1) - .destination(2) - .payload(Bytes::from_static(&[1, 2, 3])) - .build(); + let pkt = UdpPacketBuilder::new( + IpAddr::V4(Ipv4Addr::new(192, 168, 1, 100)), + IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1)), + ) + .source(1) + .destination(2) + .payload(Bytes::from_static(&[1, 2, 3])) + .build(); assert_eq!(pkt.header.length, (UDP_HEADER_LEN + 3) as u16); assert_eq!(pkt.payload, Bytes::from_static(&[1, 2, 3])); } diff --git a/nex/Cargo.toml b/nex/Cargo.toml index 20ed3d5..960be4e 100644 --- a/nex/Cargo.toml +++ b/nex/Cargo.toml @@ -82,3 +82,11 @@ path = "../examples/async_tcp_socket.rs" [[example]] name = "async_udp_socket" path = "../examples/async_udp_socket.rs" + +[[example]] +name = "async_datalink" +path = "../examples/async_datalink.rs" + +[[example]] +name = "async_dump" +path = "../examples/async_dump.rs"