fix(phantun): use the same source IP for UDP packets

This fixes an issue when phantun may choose a source IP different from
the destination IP in the incoming packet.

Closes #177
This commit is contained in:
WGH 2025-01-01 19:27:46 +03:00
parent 62f0278c1a
commit 5b82789bfc
3 changed files with 110 additions and 13 deletions

View File

@ -21,4 +21,4 @@ pretty_env_logger = "0"
tokio-tun = "0" tokio-tun = "0"
num_cpus = "1" num_cpus = "1"
neli = "0" neli = "0"
nix = { version = "0", features = ["net"] } nix = { version = "0", features = ["net", "uio", "socket"] }

View File

@ -2,11 +2,11 @@ use clap::{crate_version, Arg, ArgAction, Command};
use fake_tcp::packet::MAX_PACKET_LEN; use fake_tcp::packet::MAX_PACKET_LEN;
use fake_tcp::{Socket, Stack}; use fake_tcp::{Socket, Stack};
use log::{debug, error, info}; use log::{debug, error, info};
use phantun::utils::{assign_ipv6_address, new_udp_reuseport}; use phantun::utils::{assign_ipv6_address, new_udp_reuseport, udp_recv_pktinfo};
use std::collections::HashMap; use std::collections::HashMap;
use std::fs; use std::fs;
use std::io; use std::io;
use std::net::{Ipv4Addr, SocketAddr}; use std::net::{IpAddr, Ipv4Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
use std::sync::Arc; use std::sync::Arc;
use tokio::sync::{Notify, RwLock}; use tokio::sync::{Notify, RwLock};
use tokio::time; use tokio::time;
@ -175,17 +175,17 @@ async fn main() -> io::Result<()> {
let mut buf_r = [0u8; MAX_PACKET_LEN]; let mut buf_r = [0u8; MAX_PACKET_LEN];
loop { loop {
let (size, addr) = udp_sock.recv_from(&mut buf_r).await?; let (size, udp_remote_addr, udp_local_ip) = udp_recv_pktinfo(&udp_sock, &mut buf_r).await?;
// seen UDP packet to listening socket, this means: // seen UDP packet to listening socket, this means:
// 1. It is a new UDP connection, or // 1. It is a new UDP connection, or
// 2. It is some extra packets not filtered by more specific // 2. It is some extra packets not filtered by more specific
// connected UDP socket yet // connected UDP socket yet
if let Some(sock) = connections.read().await.get(&addr) { if let Some(sock) = connections.read().await.get(&udp_remote_addr) {
sock.send(&buf_r[..size]).await; sock.send(&buf_r[..size]).await;
continue; continue;
} }
info!("New UDP client from {}", addr); info!("New UDP client from {}", udp_remote_addr);
let sock = stack.connect(remote_addr).await; let sock = stack.connect(remote_addr).await;
if sock.is_none() { if sock.is_none() {
error!("Unable to connect to remote {}", remote_addr); error!("Unable to connect to remote {}", remote_addr);
@ -210,7 +210,7 @@ async fn main() -> io::Result<()> {
assert!(connections assert!(connections
.write() .write()
.await .await
.insert(addr, sock.clone()) .insert(udp_remote_addr, sock.clone())
.is_none()); .is_none());
debug!("inserted fake TCP socket into connection table"); debug!("inserted fake TCP socket into connection table");
@ -228,8 +228,34 @@ async fn main() -> io::Result<()> {
tokio::spawn(async move { tokio::spawn(async move {
let mut buf_udp = [0u8; MAX_PACKET_LEN]; let mut buf_udp = [0u8; MAX_PACKET_LEN];
let mut buf_tcp = [0u8; MAX_PACKET_LEN]; let mut buf_tcp = [0u8; MAX_PACKET_LEN];
let udp_sock = new_udp_reuseport(local_addr); // Always reply from the same address that the peer used to communicate with
udp_sock.connect(addr).await.unwrap(); // us. This avoids a frequent problem with IPv6 privacy extensions when we
// erroneously bind to wrong short-lived temporary address even if the peer
// explicitly used a persistent address to communicate to us.
//
// To do so, first bind to (<incoming packet dst_ip>, <local addr port>), and then
// connect to (<incoming packet src_ip>, <incoming packet src_port>).
let bind_addr = match (udp_remote_addr, udp_local_ip) {
(SocketAddr::V4(_), IpAddr::V4(udp_local_ipv4)) => {
SocketAddr::V4(SocketAddrV4::new(
udp_local_ipv4,
local_addr.port(),
))
}
(SocketAddr::V6(udp_remote_addr), IpAddr::V6(udp_local_ipv6)) => {
SocketAddr::V6(SocketAddrV6::new(
udp_local_ipv6,
local_addr.port(),
udp_remote_addr.flowinfo(),
udp_remote_addr.scope_id(),
))
}
(_, _) => {
panic!("unexpected family combination for udp_remote_addr={udp_remote_addr} and udp_local_addr={udp_local_ip}");
}
};
let udp_sock = new_udp_reuseport(bind_addr);
udp_sock.connect(udp_remote_addr).await.unwrap();
loop { loop {
tokio::select! { tokio::select! {
@ -247,7 +273,7 @@ async fn main() -> io::Result<()> {
Some(size) => { Some(size) => {
if size > 0 { if size > 0 {
if let Err(e) = udp_sock.send(&buf_tcp[..size]).await { if let Err(e) = udp_sock.send(&buf_tcp[..size]).await {
error!("Unable to send UDP packet to {}: {}, closing connection", e, addr); error!("Unable to send UDP packet to {}: {}, closing connection", e, udp_remote_addr);
quit.cancel(); quit.cancel();
return; return;
} }
@ -280,14 +306,14 @@ async fn main() -> io::Result<()> {
tokio::select! { tokio::select! {
_ = read_timeout => { _ = read_timeout => {
info!("No traffic seen in the last {:?}, closing connection", UDP_TTL); info!("No traffic seen in the last {:?}, closing connection", UDP_TTL);
connections.write().await.remove(&addr); connections.write().await.remove(&udp_remote_addr);
debug!("removed fake TCP socket from connections table"); debug!("removed fake TCP socket from connections table");
quit.cancel(); quit.cancel();
return; return;
}, },
_ = quit.cancelled() => { _ = quit.cancelled() => {
connections.write().await.remove(&addr); connections.write().await.remove(&udp_remote_addr);
debug!("removed fake TCP socket from connections table"); debug!("removed fake TCP socket from connections table");
return; return;
}, },

View File

@ -9,7 +9,8 @@ use neli::{
socket::NlSocketHandle, socket::NlSocketHandle,
types::RtBuffer, types::RtBuffer,
}; };
use std::net::{Ipv6Addr, SocketAddr}; use nix::sys::socket::{CmsgIterator, ControlMessageOwned, SockaddrLike as _};
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
use tokio::net::UdpSocket; use tokio::net::UdpSocket;
pub fn new_udp_reuseport(local_addr: SocketAddr) -> UdpSocket { pub fn new_udp_reuseport(local_addr: SocketAddr) -> UdpSocket {
@ -27,11 +28,81 @@ pub fn new_udp_reuseport(local_addr: SocketAddr) -> UdpSocket {
// from tokio-rs/mio/blob/master/src/sys/unix/net.rs // from tokio-rs/mio/blob/master/src/sys/unix/net.rs
udp_sock.set_cloexec(true).unwrap(); udp_sock.set_cloexec(true).unwrap();
udp_sock.set_nonblocking(true).unwrap(); udp_sock.set_nonblocking(true).unwrap();
// enable IP_PKTINFO/IPV6_PKTINFO delivery so we know the destination address of incoming
// packets
if local_addr.is_ipv4() {
nix::sys::socket::setsockopt(&udp_sock, nix::sys::socket::sockopt::Ipv4PacketInfo, &true)
.unwrap();
} else {
nix::sys::socket::setsockopt(
&udp_sock,
nix::sys::socket::sockopt::Ipv6RecvPacketInfo,
&true,
)
.unwrap();
}
udp_sock.bind(&socket2::SockAddr::from(local_addr)).unwrap(); udp_sock.bind(&socket2::SockAddr::from(local_addr)).unwrap();
let udp_sock: std::net::UdpSocket = udp_sock.into(); let udp_sock: std::net::UdpSocket = udp_sock.into();
udp_sock.try_into().unwrap() udp_sock.try_into().unwrap()
} }
pub async fn udp_recv_pktinfo(
sock: &UdpSocket,
buf: &mut [u8],
) -> std::io::Result<(usize, std::net::SocketAddr, std::net::IpAddr)> {
use std::os::unix::io::AsRawFd;
use tokio::io::Interest;
sock.async_io(Interest::READABLE, || {
// FIXME this is somewhat excessive, we actually need only
// max(sizeof(in_pktinfo), sizeof(in6_pktinfo))
let mut control_buffer = nix::cmsg_space!(nix::libc::in_pktinfo, nix::libc::in6_pktinfo);
let iov = &mut [std::io::IoSliceMut::new(buf)];
let res = nix::sys::socket::recvmsg::<nix::sys::socket::SockaddrStorage>(
sock.as_raw_fd(),
iov,
Some(&mut control_buffer),
nix::sys::socket::MsgFlags::empty(),
)?;
let src_addr = res.address.expect("missing source address");
let src_addr: SocketAddr = {
if let Some(inaddr) = src_addr.as_sockaddr_in() {
SocketAddrV4::new(inaddr.ip(), inaddr.port()).into()
} else if let Some(in6addr) = src_addr.as_sockaddr_in6() {
SocketAddrV6::new(
in6addr.ip(),
in6addr.port(),
in6addr.flowinfo(),
in6addr.scope_id(),
)
.into()
} else {
panic!("unexpected source address family {:#?}", src_addr.family());
}
};
let dst_addr = dst_addr_from_cmsgs(res.cmsgs()?).expect("didn't receive pktinfo");
Ok((res.bytes, src_addr, dst_addr))
})
.await
}
fn dst_addr_from_cmsgs(cmsgs: CmsgIterator) -> Option<IpAddr> {
for cmsg in cmsgs {
if let ControlMessageOwned::Ipv4PacketInfo(pktinfo) = cmsg {
return Some(Ipv4Addr::from(pktinfo.ipi_addr.s_addr.to_ne_bytes()).into());
}
if let ControlMessageOwned::Ipv6PacketInfo(pktinfo) = cmsg {
return Some(Ipv6Addr::from(pktinfo.ipi6_addr.s6_addr).into());
}
}
None
}
pub fn assign_ipv6_address(device_name: &str, local: Ipv6Addr, peer: Ipv6Addr) { pub fn assign_ipv6_address(device_name: &str, local: Ipv6Addr, peer: Ipv6Addr) {
let index = nix::net::if_::if_nametoindex(device_name).unwrap(); let index = nix::net::if_::if_nametoindex(device_name).unwrap();