perf(phantun): avoid heap allocation with udp_recv_from_pktinfo()

This commit is contained in:
Datong Sun 2025-08-22 20:51:16 -07:00
parent cedee0c699
commit 6a39e9e9d0

View File

@ -10,8 +10,12 @@ use neli::{
types::RtBuffer, types::RtBuffer,
utils::Groups, utils::Groups,
}; };
use nix::sys::socket::{CmsgIterator, ControlMessageOwned, SockaddrLike as _}; use nix::sys::socket::{
CmsgIterator, ControlMessageOwned, MsgFlags, SockaddrLike, SockaddrStorage, cmsg_space,
};
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6}; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
use std::os::unix::io::AsRawFd;
use tokio::io::Interest;
use tokio::net::UdpSocket; use tokio::net::UdpSocket;
pub fn new_udp_reuseport(local_addr: SocketAddr) -> UdpSocket { pub fn new_udp_reuseport(local_addr: SocketAddr) -> UdpSocket {
@ -49,26 +53,28 @@ pub fn new_udp_reuseport(local_addr: SocketAddr) -> UdpSocket {
udp_sock.try_into().unwrap() udp_sock.try_into().unwrap()
} }
/// Similiar to `UdpSocket::recv_from()`, but returns a 3rd value `IPAddr`
/// which corresponds to where the UDP datagram was destined to, this is useful
/// for disambigous when socket can receive on multiple IP address
/// or interfaces.
pub async fn udp_recv_pktinfo( pub async fn udp_recv_pktinfo(
sock: &UdpSocket, sock: &UdpSocket,
buf: &mut [u8], buf: &mut [u8],
) -> std::io::Result<(usize, std::net::SocketAddr, std::net::IpAddr)> { ) -> std::io::Result<(usize, SocketAddr, IpAddr)> {
use std::os::unix::io::AsRawFd;
use tokio::io::Interest;
use nix::sys::socket::cmsg_space;
sock.async_io(Interest::READABLE, || { sock.async_io(Interest::READABLE, || {
let control_buffer_size = std::cmp::max( // according to documented struct definition in RFC 3542,
cmsg_space::<nix::libc::in_pktinfo>(), // sizeof(in6_pktinfo) should always be larger than sizeof(in_pktinfo),
cmsg_space::<nix::libc::in6_pktinfo>(), // this assert just double checks that. The goal is to avoid
); // a heap allocation with Vec at runtime.
let mut control_buffer = vec![0u8; control_buffer_size]; assert!(cmsg_space::<nix::libc::in6_pktinfo>() >= cmsg_space::<nix::libc::in_pktinfo>());
let mut control_message_buffer = [0u8; cmsg_space::<nix::libc::in6_pktinfo>()];
let iov = &mut [std::io::IoSliceMut::new(buf)]; let iov = &mut [std::io::IoSliceMut::new(buf)];
let res = nix::sys::socket::recvmsg::<nix::sys::socket::SockaddrStorage>( let res = nix::sys::socket::recvmsg::<SockaddrStorage>(
sock.as_raw_fd(), sock.as_raw_fd(),
iov, iov,
Some(&mut control_buffer), Some(&mut control_message_buffer),
nix::sys::socket::MsgFlags::empty(), MsgFlags::empty(),
)?; )?;
let src_addr = res.address.expect("missing source address"); let src_addr = res.address.expect("missing source address");
@ -104,6 +110,7 @@ fn dst_addr_from_cmsgs(cmsgs: CmsgIterator) -> Option<IpAddr> {
return Some(Ipv6Addr::from(pktinfo.ipi6_addr.s6_addr).into()); return Some(Ipv6Addr::from(pktinfo.ipi6_addr.s6_addr).into());
} }
} }
None None
} }