From 86c6a3f8018462d22e8486863769d15dc7c82ba1 Mon Sep 17 00:00:00 2001 From: Datong Sun Date: Sat, 18 Sep 2021 11:58:45 -0700 Subject: [PATCH] perf(client) use different UDP sockets for individual UDP connections for better load sharing between threads This removes the bottleneck with a single listening UDP socket. --- Cargo.toml | 1 + src/bin/client.rs | 87 ++++++++++++++++++++++++++++++--------------- src/fake_tcp/mod.rs | 2 +- 3 files changed, 60 insertions(+), 30 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index da00fde..9f45316 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,3 +13,4 @@ rand = { version = "0.8.4", features = ["small_rng"] } clap = "2.33.3" log = "0.4" pretty_env_logger = "0.4.0" +socket2 = { version = "0.4.2", features = ["all"] } diff --git a/src/bin/client.rs b/src/bin/client.rs index 9bd31e7..8f1e241 100644 --- a/src/bin/client.rs +++ b/src/bin/client.rs @@ -1,8 +1,9 @@ use clap::{App, Arg}; use log::{debug, error, info}; -use lru_time_cache::{LruCache, TimedEntry}; use phantom::fake_tcp::packet::MAX_PACKET_LEN; use phantom::fake_tcp::{Socket, Stack}; +use std::collections::HashMap; +use std::convert::TryInto; use std::net::{SocketAddr, SocketAddrV4}; use std::sync::Arc; use std::time::Duration; @@ -13,6 +14,17 @@ use tokio_tun::TunBuilder; const UDP_TTL: Duration = Duration::from_secs(180); +fn new_udp_reuseport(addr: SocketAddrV4) -> UdpSocket { + let udp_sock = socket2::Socket::new(socket2::Domain::IPV4, socket2::Type::DGRAM, None).unwrap(); + udp_sock.set_reuse_port(true).unwrap(); + // from tokio-rs/mio/blob/master/src/sys/unix/net.rs + udp_sock.set_cloexec(true).unwrap(); + udp_sock.set_nonblocking(true).unwrap(); + udp_sock.bind(&socket2::SockAddr::from(addr)).unwrap(); + let udp_sock: std::net::UdpSocket = udp_sock.into(); + udp_sock.try_into().unwrap() +} + #[tokio::main] async fn main() { pretty_env_logger::init(); @@ -63,21 +75,22 @@ async fn main() { info!("Created TUN device {}", tun.name()); - let udp_sock = Arc::new(UdpSocket::bind(local_addr).await.unwrap()); - let connections = Arc::new(RwLock::new( - LruCache::>::with_expiry_duration(UDP_TTL), - )); + let udp_sock = Arc::new(new_udp_reuseport(local_addr)); + let connections = Arc::new(RwLock::new(HashMap::>::new())); let mut stack = Stack::new(tun); let main_loop = tokio::spawn(async move { let mut buf_r = [0u8; MAX_PACKET_LEN]; - let mut cleanup_timer = time::interval(Duration::from_secs(5)); loop { tokio::select! { Ok((size, SocketAddr::V4(addr))) = udp_sock.recv_from(&mut buf_r) => { - if let Some(sock) = connections.read().await.peek(&addr) { + // seen UDP packet to listening socket, this means: + // 1. It is a new UDP connection, or + // 2. It is some extra packets not filtered by more specific + // connected UDP socket yet + if let Some(sock) = connections.read().await.get(&addr) { sock.send(&buf_r[..size]).await; continue; } @@ -90,44 +103,60 @@ async fn main() { } let sock = Arc::new(sock.unwrap()); + // send first packet let res = sock.send(&buf_r[..size]).await; if res.is_none() { continue; } assert!(connections.write().await.insert(addr, sock.clone()).is_none()); - debug!("inserted fake TCP socket into LruCache"); - let udp_sock = udp_sock.clone(); + debug!("inserted fake TCP socket into connection table"); let connections = connections.clone(); + + // spawn "fastpath" UDP socket and task, this will offload main task + // from forwarding UDP packets tokio::spawn(async move { + let mut buf_udp = [0u8; MAX_PACKET_LEN]; + let mut buf_tcp = [0u8; MAX_PACKET_LEN]; + let udp_sock = new_udp_reuseport(local_addr); + udp_sock.connect(addr).await.unwrap(); + loop { - let mut buf_r = [0u8; MAX_PACKET_LEN]; - match sock.recv(&mut buf_r).await { - Some(size) => { - udp_sock.send_to(&buf_r[..size], addr).await.unwrap(); + let read_timeout = time::sleep(UDP_TTL); + + tokio::select! { + Ok(size) = udp_sock.recv(&mut buf_udp) => { + if sock.send(&buf_udp[..size]).await.is_none() { + connections.write().await.remove(&addr); + debug!("removed fake TCP socket from connections table"); + return; + } }, - None => { + res = sock.recv(&mut buf_tcp) => { + match res { + Some(size) => { + if size > 0 { + udp_sock.send(&buf_tcp[..size]).await.unwrap(); + } + }, + None => { + connections.write().await.remove(&addr); + debug!("removed fake TCP socket from connections table"); + return; + }, + } + }, + _ = read_timeout => { + info!("No traffic seen in the last {:?}, closing connection", UDP_TTL); connections.write().await.remove(&addr); - debug!("removed fake TCP socket from LruCache"); + debug!("removed fake TCP socket from connections table"); return; - }, - } + } + }; } }); }, - _ = cleanup_timer.tick() => { - let mut total = 0; - - for c in connections.write().await.notify_iter() { - if let TimedEntry::Expired(_addr, sock) = c { - sock.close(); - total += 1; - } - } - - debug!("Cleaned {} stale connections", total); - }, } } }); diff --git a/src/fake_tcp/mod.rs b/src/fake_tcp/mod.rs index 609bdae..c7d55d6 100644 --- a/src/fake_tcp/mod.rs +++ b/src/fake_tcp/mod.rs @@ -19,7 +19,7 @@ use tokio_tun::Tun; const TIMEOUT: time::Duration = time::Duration::from_secs(1); const RETRIES: usize = 6; -const MPSC_BUFFER_LEN: usize = 128; +const MPSC_BUFFER_LEN: usize = 512; #[derive(Debug, Hash, Eq, PartialEq)] pub struct AddrTuple {