From 6a39e9e9d06a2dd4f1cfeb7333dc0b1fb6263290 Mon Sep 17 00:00:00 2001 From: Datong Sun Date: Fri, 22 Aug 2025 20:51:16 -0700 Subject: [PATCH] perf(phantun): avoid heap allocation with `udp_recv_from_pktinfo()` --- phantun/src/utils.rs | 35 +++++++++++++++++++++-------------- 1 file changed, 21 insertions(+), 14 deletions(-) diff --git a/phantun/src/utils.rs b/phantun/src/utils.rs index 9e027a0..d0617c4 100644 --- a/phantun/src/utils.rs +++ b/phantun/src/utils.rs @@ -10,8 +10,12 @@ use neli::{ types::RtBuffer, utils::Groups, }; -use nix::sys::socket::{CmsgIterator, ControlMessageOwned, SockaddrLike as _}; +use nix::sys::socket::{ + CmsgIterator, ControlMessageOwned, MsgFlags, SockaddrLike, SockaddrStorage, cmsg_space, +}; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6}; +use std::os::unix::io::AsRawFd; +use tokio::io::Interest; use tokio::net::UdpSocket; pub fn new_udp_reuseport(local_addr: SocketAddr) -> UdpSocket { @@ -49,26 +53,28 @@ pub fn new_udp_reuseport(local_addr: SocketAddr) -> UdpSocket { udp_sock.try_into().unwrap() } +/// Similiar to `UdpSocket::recv_from()`, but returns a 3rd value `IPAddr` +/// which corresponds to where the UDP datagram was destined to, this is useful +/// for disambigous when socket can receive on multiple IP address +/// or interfaces. pub async fn udp_recv_pktinfo( sock: &UdpSocket, buf: &mut [u8], -) -> std::io::Result<(usize, std::net::SocketAddr, std::net::IpAddr)> { - use std::os::unix::io::AsRawFd; - use tokio::io::Interest; - use nix::sys::socket::cmsg_space; - +) -> std::io::Result<(usize, SocketAddr, IpAddr)> { sock.async_io(Interest::READABLE, || { - let control_buffer_size = std::cmp::max( - cmsg_space::(), - cmsg_space::(), - ); - let mut control_buffer = vec![0u8; control_buffer_size]; + // according to documented struct definition in RFC 3542, + // sizeof(in6_pktinfo) should always be larger than sizeof(in_pktinfo), + // this assert just double checks that. The goal is to avoid + // a heap allocation with Vec at runtime. + assert!(cmsg_space::() >= cmsg_space::()); + + let mut control_message_buffer = [0u8; cmsg_space::()]; let iov = &mut [std::io::IoSliceMut::new(buf)]; - let res = nix::sys::socket::recvmsg::( + let res = nix::sys::socket::recvmsg::( sock.as_raw_fd(), iov, - Some(&mut control_buffer), - nix::sys::socket::MsgFlags::empty(), + Some(&mut control_message_buffer), + MsgFlags::empty(), )?; let src_addr = res.address.expect("missing source address"); @@ -104,6 +110,7 @@ fn dst_addr_from_cmsgs(cmsgs: CmsgIterator) -> Option { return Some(Ipv6Addr::from(pktinfo.ipi6_addr.s6_addr).into()); } } + None }