perf(fake-tcp) use cached tuples per dispatcher task to avoid RwLock

contentions. Use multi queue Tun. Upgraded tokio to 1.12.0.

This makes the entire Phantun forward process completely lock contention free
This commit is contained in:
Datong Sun 2021-09-22 21:53:52 -07:00
parent 04b0e97c1d
commit 8371256f0b
6 changed files with 123 additions and 77 deletions

View File

@ -1,6 +1,6 @@
[package] [package]
name = "fake-tcp" name = "fake-tcp"
version = "0.1.0" version = "0.1.1"
edition = "2018" edition = "2018"
authors = ["Datong Sun <dndx@idndx.com>"] authors = ["Datong Sun <dndx@idndx.com>"]
license = "MIT OR Apache-2.0" license = "MIT OR Apache-2.0"
@ -17,8 +17,8 @@ benchmark = []
[dependencies] [dependencies]
bytes = "1" bytes = "1"
pnet = "0.28.0" pnet = "0.28.0"
tokio-tun = "0.3.15" tokio = { version = "1.12.0", features = ["full"] }
tokio = { version = "1.11.0", features = ["full"] }
rand = { version = "0.8.4", features = ["small_rng"] } rand = { version = "0.8.4", features = ["small_rng"] }
log = "0.4" log = "0.4"
internet-checksum = "0.2.0" internet-checksum = "0.2.0"
dndx-fork-tokio-tun = "0.3.16"

View File

@ -1,9 +1,10 @@
#![cfg_attr(feature = "benchmark", feature(test))] #![cfg_attr(feature = "benchmark", feature(test))]
pub mod packet; pub mod packet;
extern crate dndx_fork_tokio_tun as tokio_tun;
use bytes::{Bytes, BytesMut}; use bytes::{Bytes, BytesMut};
use log::{info, trace}; use log::{error, info, trace, warn};
use packet::*; use packet::*;
use pnet::packet::{tcp, Packet}; use pnet::packet::{tcp, Packet};
use rand::prelude::*; use rand::prelude::*;
@ -12,18 +13,18 @@ use std::fmt;
use std::net::{Ipv4Addr, SocketAddrV4}; use std::net::{Ipv4Addr, SocketAddrV4};
use std::sync::atomic::{AtomicU32, Ordering}; use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::{Arc, RwLock}; use std::sync::{Arc, RwLock};
use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::sync::broadcast;
use tokio::sync::mpsc::{self, Receiver, Sender}; use tokio::sync::mpsc::{self, Receiver, Sender};
use tokio::sync::watch; use tokio::sync::watch;
use tokio::sync::Mutex as AsyncMutex; use tokio::sync::Mutex as AsyncMutex;
use tokio::{io, time}; use tokio::time;
use tokio_tun::Tun; use tokio_tun::Tun;
const TIMEOUT: time::Duration = time::Duration::from_secs(1); const TIMEOUT: time::Duration = time::Duration::from_secs(1);
const RETRIES: usize = 6; const RETRIES: usize = 6;
const MPSC_BUFFER_LEN: usize = 512; const MPSC_BUFFER_LEN: usize = 512;
#[derive(Debug, Hash, Eq, PartialEq)] #[derive(Hash, Eq, PartialEq, Clone, Debug)]
pub struct AddrTuple { pub struct AddrTuple {
local_addr: SocketAddrV4, local_addr: SocketAddrV4,
remote_addr: SocketAddrV4, remote_addr: SocketAddrV4,
@ -38,12 +39,12 @@ impl AddrTuple {
} }
} }
#[derive(Debug)]
struct Shared { struct Shared {
tuples: RwLock<HashMap<AddrTuple, Arc<Sender<Bytes>>>>, tuples: RwLock<HashMap<AddrTuple, Sender<Bytes>>>,
listening: RwLock<HashSet<u16>>, listening: RwLock<HashSet<u16>>,
outgoing: Sender<Bytes>, tun: Vec<Arc<Tun>>,
ready: Sender<Socket>, ready: Sender<Socket>,
tuples_purge: broadcast::Sender<AddrTuple>,
} }
pub struct Stack { pub struct Stack {
@ -52,7 +53,6 @@ pub struct Stack {
ready: Receiver<Socket>, ready: Receiver<Socket>,
} }
#[derive(Debug)]
pub enum State { pub enum State {
Idle, Idle,
SynSent, SynSent,
@ -60,16 +60,9 @@ pub enum State {
Established, Established,
} }
#[derive(Debug)]
pub enum Mode {
Client,
Server,
}
#[derive(Debug)]
pub struct Socket { pub struct Socket {
mode: Mode,
shared: Arc<Shared>, shared: Arc<Shared>,
tun: Arc<Tun>,
incoming: AsyncMutex<Receiver<Bytes>>, incoming: AsyncMutex<Receiver<Bytes>>,
local_addr: SocketAddrV4, local_addr: SocketAddrV4,
remote_addr: SocketAddrV4, remote_addr: SocketAddrV4,
@ -82,8 +75,8 @@ pub struct Socket {
impl Socket { impl Socket {
fn new( fn new(
mode: Mode,
shared: Arc<Shared>, shared: Arc<Shared>,
tun: Arc<Tun>,
local_addr: SocketAddrV4, local_addr: SocketAddrV4,
remote_addr: SocketAddrV4, remote_addr: SocketAddrV4,
ack: Option<u32>, ack: Option<u32>,
@ -94,8 +87,8 @@ impl Socket {
( (
Socket { Socket {
mode,
shared, shared,
tun,
incoming: AsyncMutex::new(incoming_rx), incoming: AsyncMutex::new(incoming_rx),
local_addr, local_addr,
remote_addr, remote_addr,
@ -126,10 +119,10 @@ impl Socket {
match self.state { match self.state {
State::Established => { State::Established => {
let buf = self.build_tcp_packet(tcp::TcpFlags::ACK, Some(payload)); let buf = self.build_tcp_packet(tcp::TcpFlags::ACK, Some(payload));
self.seq.fetch_add(buf.len() as u32, Ordering::Relaxed); self.seq.fetch_add(payload.len() as u32, Ordering::Relaxed);
tokio::select! { tokio::select! {
res = self.shared.outgoing.send(buf) => { res = self.tun.send(&buf) => {
res.unwrap(); res.unwrap();
Some(()) Some(())
}, },
@ -186,7 +179,7 @@ impl Socket {
State::Idle => { State::Idle => {
let buf = self.build_tcp_packet(tcp::TcpFlags::SYN | tcp::TcpFlags::ACK, None); let buf = self.build_tcp_packet(tcp::TcpFlags::SYN | tcp::TcpFlags::ACK, None);
// ACK set by constructor // ACK set by constructor
self.shared.outgoing.send(buf).await.unwrap(); self.tun.send(&buf).await.unwrap();
self.state = State::SynReceived; self.state = State::SynReceived;
info!("Sent SYN + ACK to client"); info!("Sent SYN + ACK to client");
} }
@ -210,7 +203,9 @@ impl Socket {
info!("Connection from {:?} established", self.remote_addr); info!("Connection from {:?} established", self.remote_addr);
let ready = self.shared.ready.clone(); let ready = self.shared.ready.clone();
ready.send(self).await.unwrap(); if let Err(e) = ready.send(self).await {
error!("Unable to send accepted socket to ready queue: {}", e);
}
return; return;
} }
} else { } else {
@ -228,7 +223,7 @@ impl Socket {
match self.state { match self.state {
State::Idle => { State::Idle => {
let buf = self.build_tcp_packet(tcp::TcpFlags::SYN, None); let buf = self.build_tcp_packet(tcp::TcpFlags::SYN, None);
self.shared.outgoing.send(buf).await.unwrap(); self.tun.send(&buf).await.unwrap();
self.state = State::SynSent; self.state = State::SynSent;
info!("Sent SYN to server"); info!("Sent SYN to server");
} }
@ -253,7 +248,7 @@ impl Socket {
// send ACK to finish handshake // send ACK to finish handshake
let buf = self.build_tcp_packet(tcp::TcpFlags::ACK, None); let buf = self.build_tcp_packet(tcp::TcpFlags::ACK, None);
self.shared.outgoing.send(buf).await.unwrap(); self.tun.send(&buf).await.unwrap();
self.state = State::Established; self.state = State::Established;
@ -277,17 +272,16 @@ impl Socket {
impl Drop for Socket { impl Drop for Socket {
fn drop(&mut self) { fn drop(&mut self) {
let tuple = AddrTuple::new(self.local_addr, self.remote_addr);
// dissociates ourself from the dispatch map // dissociates ourself from the dispatch map
assert!(self assert!(self.shared.tuples.write().unwrap().remove(&tuple).is_some());
.shared // purge cache
.tuples self.shared.tuples_purge.send(tuple).unwrap();
.write()
.unwrap()
.remove(&AddrTuple::new(self.local_addr, self.remote_addr))
.is_some());
let buf = self.build_tcp_packet(tcp::TcpFlags::RST, None); let buf = self.build_tcp_packet(tcp::TcpFlags::RST, None);
self.shared.outgoing.try_send(buf).unwrap(); if let Err(e) = self.tun.try_send(&buf) {
warn!("Unable to send RST to remote end: {}", e);
}
self.close(); self.close();
info!("Fake TCP connection to {} closed", self); info!("Fake TCP connection to {} closed", self);
} }
@ -304,18 +298,27 @@ impl fmt::Display for Socket {
} }
impl Stack { impl Stack {
pub fn new(tun: Tun) -> Stack { pub fn new(tun: Vec<Tun>) -> Stack {
let (outgoing_tx, outgoing_rx) = mpsc::channel(MPSC_BUFFER_LEN); let tun: Vec<Arc<Tun>> = tun.into_iter().map(Arc::new).collect();
let (ready_tx, ready_rx) = mpsc::channel(MPSC_BUFFER_LEN); let (ready_tx, ready_rx) = mpsc::channel(MPSC_BUFFER_LEN);
let (tuples_purge_tx, _tuples_purge_rx) = broadcast::channel(16);
let shared = Arc::new(Shared { let shared = Arc::new(Shared {
tuples: RwLock::new(HashMap::new()), tuples: RwLock::new(HashMap::new()),
outgoing: outgoing_tx, tun: tun.clone(),
listening: RwLock::new(HashSet::new()), listening: RwLock::new(HashSet::new()),
ready: ready_tx, ready: ready_tx,
tuples_purge: tuples_purge_tx.clone(),
}); });
let local_ip = tun.destination().unwrap(); let local_ip = tun[0].destination().unwrap();
for t in tun {
tokio::spawn(Stack::reader_task(
t,
shared.clone(),
tuples_purge_tx.subscribe(),
));
}
tokio::spawn(Stack::dispatch(tun, outgoing_rx, shared.clone()));
Stack { Stack {
shared, shared,
local_ip, local_ip,
@ -337,8 +340,8 @@ impl Stack {
let local_addr = SocketAddrV4::new(self.local_ip, local_port); let local_addr = SocketAddrV4::new(self.local_ip, local_port);
let tuple = AddrTuple::new(local_addr, addr); let tuple = AddrTuple::new(local_addr, addr);
let (mut sock, incoming) = Socket::new( let (mut sock, incoming) = Socket::new(
Mode::Client,
self.shared.clone(), self.shared.clone(),
self.shared.tun.choose(&mut rng).unwrap().clone(),
local_addr, local_addr,
addr, addr,
None, None,
@ -347,53 +350,83 @@ impl Stack {
{ {
let mut tuples = self.shared.tuples.write().unwrap(); let mut tuples = self.shared.tuples.write().unwrap();
assert!(tuples.insert(tuple, Arc::new(incoming.clone())).is_none()); assert!(tuples.insert(tuple, incoming.clone()).is_none());
} }
sock.connect().await.map(|_| sock) sock.connect().await.map(|_| sock)
} }
async fn dispatch(tun: Tun, mut outgoing: Receiver<Bytes>, shared: Arc<Shared>) { async fn reader_task(
let (mut tun_r, mut tun_w) = io::split(tun); tun: Arc<Tun>,
shared: Arc<Shared>,
mut tuples_purge: broadcast::Receiver<AddrTuple>,
) {
let mut tuples: HashMap<AddrTuple, Sender<Bytes>> = HashMap::new();
loop { loop {
let mut buf = BytesMut::with_capacity(MAX_PACKET_LEN); let mut buf = BytesMut::with_capacity(MAX_PACKET_LEN);
buf.resize(MAX_PACKET_LEN, 0);
tokio::select! { tokio::select! {
buf = outgoing.recv() => { size = tun.recv(&mut buf) => {
let buf = buf.unwrap(); let size = size.unwrap();
tun_w.write_all(&buf).await.unwrap(); buf.truncate(size);
},
s = tun_r.read_buf(&mut buf) => {
s.unwrap();
let buf = buf.freeze(); let buf = buf.freeze();
if buf[0] >> 4 != 4 { if buf[0] >> 4 != 4 {
// not an IPv4 packet // not an IPv4 packet
continue; continue;
} }
let (ip_packet, tcp_packet) = parse_ipv4_packet(&buf); let (ip_packet, tcp_packet) = parse_ipv4_packet(&buf);
let local_addr = SocketAddrV4::new(ip_packet.get_destination(), tcp_packet.get_destination()); let local_addr =
SocketAddrV4::new(ip_packet.get_destination(), tcp_packet.get_destination());
let remote_addr = SocketAddrV4::new(ip_packet.get_source(), tcp_packet.get_source()); let remote_addr = SocketAddrV4::new(ip_packet.get_source(), tcp_packet.get_source());
let tuple = AddrTuple::new(local_addr, remote_addr); let tuple = AddrTuple::new(local_addr, remote_addr);
if let Some(c) = tuples.get(&tuple) {
let sender;
{
let tuples = shared.tuples.read().unwrap();
sender = tuples.get(&tuple).cloned();
}
if let Some(c) = sender {
c.send(buf).await.unwrap(); c.send(buf).await.unwrap();
continue; continue;
} else {
trace!("Cache miss, checking the shared tuples table for connection");
let sender;
{
let tuples = shared.tuples.read().unwrap();
sender = tuples.get(&tuple).cloned();
}
if let Some(c) = sender {
trace!("Storing connection information into local tuples");
tuples.insert(tuple, c.clone());
c.send(buf).await.unwrap();
continue;
}
} }
if tcp_packet.get_flags() == tcp::TcpFlags::SYN && shared.listening.read().unwrap().contains(&tcp_packet.get_destination()) { if tcp_packet.get_flags() == tcp::TcpFlags::SYN
&& shared
.listening
.read()
.unwrap()
.contains(&tcp_packet.get_destination())
{
// SYN seen on listening socket // SYN seen on listening socket
if tcp_packet.get_sequence() == 0 { if tcp_packet.get_sequence() == 0 {
let (sock, incoming) = Socket::new(Mode::Server, shared.clone(), local_addr, remote_addr, Some(tcp_packet.get_sequence() + 1), State::Idle); let (sock, incoming) = Socket::new(
assert!(shared.tuples.write().unwrap().insert(tuple, Arc::new(incoming)).is_none()); shared.clone(),
tun.clone(),
local_addr,
remote_addr,
Some(tcp_packet.get_sequence() + 1),
State::Idle,
);
assert!(shared
.tuples
.write()
.unwrap()
.insert(tuple, incoming)
.is_none());
tokio::spawn(sock.accept()); tokio::spawn(sock.accept());
} else { } else {
trace!("Bad TCP SYN packet from {}, sending RST", remote_addr); trace!("Bad TCP SYN packet from {}, sending RST", remote_addr);
@ -405,7 +438,7 @@ impl Stack {
tcp::TcpFlags::RST, tcp::TcpFlags::RST,
None, None,
); );
shared.outgoing.try_send(buf).unwrap(); shared.tun[0].try_send(&buf).unwrap();
} }
} else if (tcp_packet.get_flags() & tcp::TcpFlags::RST) == 0 { } else if (tcp_packet.get_flags() & tcp::TcpFlags::RST) == 0 {
info!("Unknown TCP packet from {}, sending RST", remote_addr); info!("Unknown TCP packet from {}, sending RST", remote_addr);
@ -417,8 +450,13 @@ impl Stack {
tcp::TcpFlags::RST, tcp::TcpFlags::RST,
None, None,
); );
shared.outgoing.try_send(buf).unwrap(); shared.tun[0].try_send(&buf).unwrap();
} }
},
tuple = tuples_purge.recv() => {
let tuple = tuple.unwrap();
tuples.remove(&tuple);
trace!("Removed cached tuple");
} }
} }
} }

View File

@ -18,8 +18,8 @@ pub fn build_tcp_packet(
payload: Option<&[u8]>, payload: Option<&[u8]>,
) -> Bytes { ) -> Bytes {
let wscale = (flags & tcp::TcpFlags::SYN) != 0; let wscale = (flags & tcp::TcpFlags::SYN) != 0;
let tcp_total_len = TCP_HEADER_LEN + if wscale {4} else {0} // nop + wscale let tcp_header_len = TCP_HEADER_LEN + if wscale { 4 } else { 0 }; // nop + wscale
+ payload.map_or(0, |payload| payload.len()); let tcp_total_len = tcp_header_len + payload.map_or(0, |payload| payload.len());
let total_len = IPV4_HEADER_LEN + tcp_total_len; let total_len = IPV4_HEADER_LEN + tcp_total_len;
let mut buf = BytesMut::with_capacity(total_len); let mut buf = BytesMut::with_capacity(total_len);
buf.resize(total_len, 0); buf.resize(total_len, 0);
@ -62,9 +62,10 @@ pub fn build_tcp_packet(
cksm.add_bytes(&local_addr.ip().octets()); cksm.add_bytes(&local_addr.ip().octets());
cksm.add_bytes(&remote_addr.ip().octets()); cksm.add_bytes(&remote_addr.ip().octets());
let ip::IpNextHeaderProtocol(tcp_protocol) = ip::IpNextHeaderProtocols::Tcp; let ip::IpNextHeaderProtocol(tcp_protocol) = ip::IpNextHeaderProtocols::Tcp;
let pseudo = [0u8, tcp_protocol, 0, tcp_total_len as u8]; let mut pseudo = [0u8, tcp_protocol, 0, 0];
pseudo[2..].copy_from_slice(&(tcp_total_len as u16).to_be_bytes());
cksm.add_bytes(&pseudo); cksm.add_bytes(&pseudo);
cksm.add_bytes(v4.packet()); cksm.add_bytes(tcp.packet());
tcp.set_checksum(u16::from_be_bytes(cksm.checksum())); tcp.set_checksum(u16::from_be_bytes(cksm.checksum()));
v4_buf.unsplit(tcp_buf); v4_buf.unsplit(tcp_buf);

View File

@ -1,20 +1,21 @@
[package] [package]
name = "phantun" name = "phantun"
version = "0.1.0" version = "0.1.1"
edition = "2018" edition = "2018"
authors = ["Datong Sun <dndx@idndx.com>"] authors = ["Datong Sun <dndx@idndx.com>"]
license = "MIT OR Apache-2.0" license = "MIT OR Apache-2.0"
repository = "https://github.com/dndx/phantun" repository = "https://github.com/dndx/phantun"
readme = "README.md" readme = "README.md"
description = """ description = """
Turns transforms UDP stream into (fake) TCP streams that can go through Transforms UDP stream into (fake) TCP streams that can go through
Layer 4 firewalls. Layer 3 & Layer 4 (NAPT) firewalls/NATs.
""" """
[dependencies] [dependencies]
clap = "2.33.3" clap = "2.33.3"
socket2 = { version = "0.4.2", features = ["all"] } socket2 = { version = "0.4.2", features = ["all"] }
fake-tcp = "0.1.0" fake-tcp = { path = "../fake-tcp" }
tokio-tun = "0.3.15" tokio = { version = "1.12.0", features = ["full"] }
tokio = { version = "1.11.0", features = ["full"] }
log = "0.4" log = "0.4"
pretty_env_logger = "0.4.0" pretty_env_logger = "0.4.0"
dndx-fork-tokio-tun = "0.3.16"
num_cpus = "1.13.0"

View File

@ -1,3 +1,5 @@
extern crate dndx_fork_tokio_tun as tokio_tun;
use clap::{App, Arg}; use clap::{App, Arg};
use fake_tcp::packet::MAX_PACKET_LEN; use fake_tcp::packet::MAX_PACKET_LEN;
use fake_tcp::{Socket, Stack}; use fake_tcp::{Socket, Stack};
@ -70,10 +72,10 @@ async fn main() {
.up() // or set it up manually using `sudo ip link set <tun-name> up`. .up() // or set it up manually using `sudo ip link set <tun-name> up`.
.address("192.168.200.1".parse().unwrap()) .address("192.168.200.1".parse().unwrap())
.destination("192.168.200.2".parse().unwrap()) .destination("192.168.200.2".parse().unwrap())
.try_build() .try_build_mq(num_cpus::get())
.unwrap(); .unwrap();
info!("Created TUN device {}", tun.name()); info!("Created TUN device {}", tun[0].name());
let udp_sock = Arc::new(new_udp_reuseport(local_addr)); let udp_sock = Arc::new(new_udp_reuseport(local_addr));
let connections = Arc::new(RwLock::new(HashMap::<SocketAddrV4, Arc<Socket>>::new())); let connections = Arc::new(RwLock::new(HashMap::<SocketAddrV4, Arc<Socket>>::new()));

View File

@ -1,3 +1,5 @@
extern crate dndx_fork_tokio_tun as tokio_tun;
use clap::{App, Arg}; use clap::{App, Arg};
use fake_tcp::packet::MAX_PACKET_LEN; use fake_tcp::packet::MAX_PACKET_LEN;
use fake_tcp::Stack; use fake_tcp::Stack;
@ -53,9 +55,11 @@ async fn main() {
.up() // or set it up manually using `sudo ip link set <tun-name> up`. .up() // or set it up manually using `sudo ip link set <tun-name> up`.
.address("192.168.201.1".parse().unwrap()) .address("192.168.201.1".parse().unwrap())
.destination("192.168.201.2".parse().unwrap()) .destination("192.168.201.2".parse().unwrap())
.try_build() .try_build_mq(num_cpus::get())
.unwrap(); .unwrap();
info!("Created TUN device {}", tun[0].name());
//thread::sleep(time::Duration::from_secs(5)); //thread::sleep(time::Duration::from_secs(5));
let mut stack = Stack::new(tun); let mut stack = Stack::new(tun);
stack.listen(local_port); stack.listen(local_port);