diff --git a/phantun/Cargo.toml b/phantun/Cargo.toml index 07130fb..35f70bd 100644 --- a/phantun/Cargo.toml +++ b/phantun/Cargo.toml @@ -21,4 +21,4 @@ pretty_env_logger = "0" tokio-tun = "0" num_cpus = "1" neli = "0" -nix = { version = "0", features = ["net"] } +nix = { version = "0", features = ["net", "uio", "socket"] } diff --git a/phantun/src/bin/client.rs b/phantun/src/bin/client.rs index 2219953..5a541b9 100644 --- a/phantun/src/bin/client.rs +++ b/phantun/src/bin/client.rs @@ -2,11 +2,11 @@ use clap::{crate_version, Arg, ArgAction, Command}; use fake_tcp::packet::MAX_PACKET_LEN; use fake_tcp::{Socket, Stack}; use log::{debug, error, info}; -use phantun::utils::{assign_ipv6_address, new_udp_reuseport}; +use phantun::utils::{assign_ipv6_address, new_udp_reuseport, udp_recv_pktinfo}; use std::collections::HashMap; use std::fs; use std::io; -use std::net::{Ipv4Addr, SocketAddr}; +use std::net::{IpAddr, Ipv4Addr, SocketAddr, SocketAddrV4, SocketAddrV6}; use std::sync::Arc; use tokio::sync::{Notify, RwLock}; use tokio::time; @@ -175,17 +175,17 @@ async fn main() -> io::Result<()> { let mut buf_r = [0u8; MAX_PACKET_LEN]; loop { - let (size, addr) = udp_sock.recv_from(&mut buf_r).await?; + let (size, udp_remote_addr, udp_local_ip) = udp_recv_pktinfo(&udp_sock, &mut buf_r).await?; // seen UDP packet to listening socket, this means: // 1. It is a new UDP connection, or // 2. It is some extra packets not filtered by more specific // connected UDP socket yet - if let Some(sock) = connections.read().await.get(&addr) { + if let Some(sock) = connections.read().await.get(&udp_remote_addr) { sock.send(&buf_r[..size]).await; continue; } - info!("New UDP client from {}", addr); + info!("New UDP client from {}", udp_remote_addr); let sock = stack.connect(remote_addr).await; if sock.is_none() { error!("Unable to connect to remote {}", remote_addr); @@ -210,7 +210,7 @@ async fn main() -> io::Result<()> { assert!(connections .write() .await - .insert(addr, sock.clone()) + .insert(udp_remote_addr, sock.clone()) .is_none()); debug!("inserted fake TCP socket into connection table"); @@ -228,8 +228,34 @@ async fn main() -> io::Result<()> { tokio::spawn(async move { let mut buf_udp = [0u8; MAX_PACKET_LEN]; let mut buf_tcp = [0u8; MAX_PACKET_LEN]; - let udp_sock = new_udp_reuseport(local_addr); - udp_sock.connect(addr).await.unwrap(); + // Always reply from the same address that the peer used to communicate with + // us. This avoids a frequent problem with IPv6 privacy extensions when we + // erroneously bind to wrong short-lived temporary address even if the peer + // explicitly used a persistent address to communicate to us. + // + // To do so, first bind to (, ), and then + // connect to (, ). + let bind_addr = match (udp_remote_addr, udp_local_ip) { + (SocketAddr::V4(_), IpAddr::V4(udp_local_ipv4)) => { + SocketAddr::V4(SocketAddrV4::new( + udp_local_ipv4, + local_addr.port(), + )) + } + (SocketAddr::V6(udp_remote_addr), IpAddr::V6(udp_local_ipv6)) => { + SocketAddr::V6(SocketAddrV6::new( + udp_local_ipv6, + local_addr.port(), + udp_remote_addr.flowinfo(), + udp_remote_addr.scope_id(), + )) + } + (_, _) => { + panic!("unexpected family combination for udp_remote_addr={udp_remote_addr} and udp_local_addr={udp_local_ip}"); + } + }; + let udp_sock = new_udp_reuseport(bind_addr); + udp_sock.connect(udp_remote_addr).await.unwrap(); loop { tokio::select! { @@ -247,7 +273,7 @@ async fn main() -> io::Result<()> { Some(size) => { if size > 0 { if let Err(e) = udp_sock.send(&buf_tcp[..size]).await { - error!("Unable to send UDP packet to {}: {}, closing connection", e, addr); + error!("Unable to send UDP packet to {}: {}, closing connection", e, udp_remote_addr); quit.cancel(); return; } @@ -280,14 +306,14 @@ async fn main() -> io::Result<()> { tokio::select! { _ = read_timeout => { info!("No traffic seen in the last {:?}, closing connection", UDP_TTL); - connections.write().await.remove(&addr); + connections.write().await.remove(&udp_remote_addr); debug!("removed fake TCP socket from connections table"); quit.cancel(); return; }, _ = quit.cancelled() => { - connections.write().await.remove(&addr); + connections.write().await.remove(&udp_remote_addr); debug!("removed fake TCP socket from connections table"); return; }, diff --git a/phantun/src/utils.rs b/phantun/src/utils.rs index b8630e7..f426fdf 100644 --- a/phantun/src/utils.rs +++ b/phantun/src/utils.rs @@ -9,7 +9,8 @@ use neli::{ socket::NlSocketHandle, types::RtBuffer, }; -use std::net::{Ipv6Addr, SocketAddr}; +use nix::sys::socket::{CmsgIterator, ControlMessageOwned, SockaddrLike as _}; +use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6}; use tokio::net::UdpSocket; pub fn new_udp_reuseport(local_addr: SocketAddr) -> UdpSocket { @@ -27,11 +28,81 @@ pub fn new_udp_reuseport(local_addr: SocketAddr) -> UdpSocket { // from tokio-rs/mio/blob/master/src/sys/unix/net.rs udp_sock.set_cloexec(true).unwrap(); udp_sock.set_nonblocking(true).unwrap(); + + // enable IP_PKTINFO/IPV6_PKTINFO delivery so we know the destination address of incoming + // packets + if local_addr.is_ipv4() { + nix::sys::socket::setsockopt(&udp_sock, nix::sys::socket::sockopt::Ipv4PacketInfo, &true) + .unwrap(); + } else { + nix::sys::socket::setsockopt( + &udp_sock, + nix::sys::socket::sockopt::Ipv6RecvPacketInfo, + &true, + ) + .unwrap(); + } + udp_sock.bind(&socket2::SockAddr::from(local_addr)).unwrap(); let udp_sock: std::net::UdpSocket = udp_sock.into(); udp_sock.try_into().unwrap() } +pub async fn udp_recv_pktinfo( + sock: &UdpSocket, + buf: &mut [u8], +) -> std::io::Result<(usize, std::net::SocketAddr, std::net::IpAddr)> { + use std::os::unix::io::AsRawFd; + use tokio::io::Interest; + + sock.async_io(Interest::READABLE, || { + // FIXME this is somewhat excessive, we actually need only + // max(sizeof(in_pktinfo), sizeof(in6_pktinfo)) + let mut control_buffer = nix::cmsg_space!(nix::libc::in_pktinfo, nix::libc::in6_pktinfo); + let iov = &mut [std::io::IoSliceMut::new(buf)]; + let res = nix::sys::socket::recvmsg::( + sock.as_raw_fd(), + iov, + Some(&mut control_buffer), + nix::sys::socket::MsgFlags::empty(), + )?; + + let src_addr = res.address.expect("missing source address"); + let src_addr: SocketAddr = { + if let Some(inaddr) = src_addr.as_sockaddr_in() { + SocketAddrV4::new(inaddr.ip(), inaddr.port()).into() + } else if let Some(in6addr) = src_addr.as_sockaddr_in6() { + SocketAddrV6::new( + in6addr.ip(), + in6addr.port(), + in6addr.flowinfo(), + in6addr.scope_id(), + ) + .into() + } else { + panic!("unexpected source address family {:#?}", src_addr.family()); + } + }; + + let dst_addr = dst_addr_from_cmsgs(res.cmsgs()?).expect("didn't receive pktinfo"); + + Ok((res.bytes, src_addr, dst_addr)) + }) + .await +} + +fn dst_addr_from_cmsgs(cmsgs: CmsgIterator) -> Option { + for cmsg in cmsgs { + if let ControlMessageOwned::Ipv4PacketInfo(pktinfo) = cmsg { + return Some(Ipv4Addr::from(pktinfo.ipi_addr.s_addr.to_ne_bytes()).into()); + } + if let ControlMessageOwned::Ipv6PacketInfo(pktinfo) = cmsg { + return Some(Ipv6Addr::from(pktinfo.ipi6_addr.s6_addr).into()); + } + } + None +} + pub fn assign_ipv6_address(device_name: &str, local: Ipv6Addr, peer: Ipv6Addr) { let index = nix::net::if_::if_nametoindex(device_name).unwrap();