perf(fake-tcp) use cached tuples per dispatcher task to avoid RwLock

contentions. Use multi queue Tun. Upgraded tokio to 1.12.0.

This makes the entire Phantun forward process completely lock contention free
This commit is contained in:
Datong Sun 2021-09-22 21:53:52 -07:00
parent 04b0e97c1d
commit 8371256f0b
6 changed files with 123 additions and 77 deletions

View File

@ -1,6 +1,6 @@
[package]
name = "fake-tcp"
version = "0.1.0"
version = "0.1.1"
edition = "2018"
authors = ["Datong Sun <dndx@idndx.com>"]
license = "MIT OR Apache-2.0"
@ -17,8 +17,8 @@ benchmark = []
[dependencies]
bytes = "1"
pnet = "0.28.0"
tokio-tun = "0.3.15"
tokio = { version = "1.11.0", features = ["full"] }
tokio = { version = "1.12.0", features = ["full"] }
rand = { version = "0.8.4", features = ["small_rng"] }
log = "0.4"
internet-checksum = "0.2.0"
dndx-fork-tokio-tun = "0.3.16"

View File

@ -1,9 +1,10 @@
#![cfg_attr(feature = "benchmark", feature(test))]
pub mod packet;
extern crate dndx_fork_tokio_tun as tokio_tun;
use bytes::{Bytes, BytesMut};
use log::{info, trace};
use log::{error, info, trace, warn};
use packet::*;
use pnet::packet::{tcp, Packet};
use rand::prelude::*;
@ -12,18 +13,18 @@ use std::fmt;
use std::net::{Ipv4Addr, SocketAddrV4};
use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::{Arc, RwLock};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::sync::broadcast;
use tokio::sync::mpsc::{self, Receiver, Sender};
use tokio::sync::watch;
use tokio::sync::Mutex as AsyncMutex;
use tokio::{io, time};
use tokio::time;
use tokio_tun::Tun;
const TIMEOUT: time::Duration = time::Duration::from_secs(1);
const RETRIES: usize = 6;
const MPSC_BUFFER_LEN: usize = 512;
#[derive(Debug, Hash, Eq, PartialEq)]
#[derive(Hash, Eq, PartialEq, Clone, Debug)]
pub struct AddrTuple {
local_addr: SocketAddrV4,
remote_addr: SocketAddrV4,
@ -38,12 +39,12 @@ impl AddrTuple {
}
}
#[derive(Debug)]
struct Shared {
tuples: RwLock<HashMap<AddrTuple, Arc<Sender<Bytes>>>>,
tuples: RwLock<HashMap<AddrTuple, Sender<Bytes>>>,
listening: RwLock<HashSet<u16>>,
outgoing: Sender<Bytes>,
tun: Vec<Arc<Tun>>,
ready: Sender<Socket>,
tuples_purge: broadcast::Sender<AddrTuple>,
}
pub struct Stack {
@ -52,7 +53,6 @@ pub struct Stack {
ready: Receiver<Socket>,
}
#[derive(Debug)]
pub enum State {
Idle,
SynSent,
@ -60,16 +60,9 @@ pub enum State {
Established,
}
#[derive(Debug)]
pub enum Mode {
Client,
Server,
}
#[derive(Debug)]
pub struct Socket {
mode: Mode,
shared: Arc<Shared>,
tun: Arc<Tun>,
incoming: AsyncMutex<Receiver<Bytes>>,
local_addr: SocketAddrV4,
remote_addr: SocketAddrV4,
@ -82,8 +75,8 @@ pub struct Socket {
impl Socket {
fn new(
mode: Mode,
shared: Arc<Shared>,
tun: Arc<Tun>,
local_addr: SocketAddrV4,
remote_addr: SocketAddrV4,
ack: Option<u32>,
@ -94,8 +87,8 @@ impl Socket {
(
Socket {
mode,
shared,
tun,
incoming: AsyncMutex::new(incoming_rx),
local_addr,
remote_addr,
@ -126,10 +119,10 @@ impl Socket {
match self.state {
State::Established => {
let buf = self.build_tcp_packet(tcp::TcpFlags::ACK, Some(payload));
self.seq.fetch_add(buf.len() as u32, Ordering::Relaxed);
self.seq.fetch_add(payload.len() as u32, Ordering::Relaxed);
tokio::select! {
res = self.shared.outgoing.send(buf) => {
res = self.tun.send(&buf) => {
res.unwrap();
Some(())
},
@ -186,7 +179,7 @@ impl Socket {
State::Idle => {
let buf = self.build_tcp_packet(tcp::TcpFlags::SYN | tcp::TcpFlags::ACK, None);
// ACK set by constructor
self.shared.outgoing.send(buf).await.unwrap();
self.tun.send(&buf).await.unwrap();
self.state = State::SynReceived;
info!("Sent SYN + ACK to client");
}
@ -210,7 +203,9 @@ impl Socket {
info!("Connection from {:?} established", self.remote_addr);
let ready = self.shared.ready.clone();
ready.send(self).await.unwrap();
if let Err(e) = ready.send(self).await {
error!("Unable to send accepted socket to ready queue: {}", e);
}
return;
}
} else {
@ -228,7 +223,7 @@ impl Socket {
match self.state {
State::Idle => {
let buf = self.build_tcp_packet(tcp::TcpFlags::SYN, None);
self.shared.outgoing.send(buf).await.unwrap();
self.tun.send(&buf).await.unwrap();
self.state = State::SynSent;
info!("Sent SYN to server");
}
@ -253,7 +248,7 @@ impl Socket {
// send ACK to finish handshake
let buf = self.build_tcp_packet(tcp::TcpFlags::ACK, None);
self.shared.outgoing.send(buf).await.unwrap();
self.tun.send(&buf).await.unwrap();
self.state = State::Established;
@ -277,17 +272,16 @@ impl Socket {
impl Drop for Socket {
fn drop(&mut self) {
let tuple = AddrTuple::new(self.local_addr, self.remote_addr);
// dissociates ourself from the dispatch map
assert!(self
.shared
.tuples
.write()
.unwrap()
.remove(&AddrTuple::new(self.local_addr, self.remote_addr))
.is_some());
assert!(self.shared.tuples.write().unwrap().remove(&tuple).is_some());
// purge cache
self.shared.tuples_purge.send(tuple).unwrap();
let buf = self.build_tcp_packet(tcp::TcpFlags::RST, None);
self.shared.outgoing.try_send(buf).unwrap();
if let Err(e) = self.tun.try_send(&buf) {
warn!("Unable to send RST to remote end: {}", e);
}
self.close();
info!("Fake TCP connection to {} closed", self);
}
@ -304,18 +298,27 @@ impl fmt::Display for Socket {
}
impl Stack {
pub fn new(tun: Tun) -> Stack {
let (outgoing_tx, outgoing_rx) = mpsc::channel(MPSC_BUFFER_LEN);
pub fn new(tun: Vec<Tun>) -> Stack {
let tun: Vec<Arc<Tun>> = tun.into_iter().map(Arc::new).collect();
let (ready_tx, ready_rx) = mpsc::channel(MPSC_BUFFER_LEN);
let (tuples_purge_tx, _tuples_purge_rx) = broadcast::channel(16);
let shared = Arc::new(Shared {
tuples: RwLock::new(HashMap::new()),
outgoing: outgoing_tx,
tun: tun.clone(),
listening: RwLock::new(HashSet::new()),
ready: ready_tx,
tuples_purge: tuples_purge_tx.clone(),
});
let local_ip = tun.destination().unwrap();
let local_ip = tun[0].destination().unwrap();
for t in tun {
tokio::spawn(Stack::reader_task(
t,
shared.clone(),
tuples_purge_tx.subscribe(),
));
}
tokio::spawn(Stack::dispatch(tun, outgoing_rx, shared.clone()));
Stack {
shared,
local_ip,
@ -337,8 +340,8 @@ impl Stack {
let local_addr = SocketAddrV4::new(self.local_ip, local_port);
let tuple = AddrTuple::new(local_addr, addr);
let (mut sock, incoming) = Socket::new(
Mode::Client,
self.shared.clone(),
self.shared.tun.choose(&mut rng).unwrap().clone(),
local_addr,
addr,
None,
@ -347,37 +350,46 @@ impl Stack {
{
let mut tuples = self.shared.tuples.write().unwrap();
assert!(tuples.insert(tuple, Arc::new(incoming.clone())).is_none());
assert!(tuples.insert(tuple, incoming.clone()).is_none());
}
sock.connect().await.map(|_| sock)
}
async fn dispatch(tun: Tun, mut outgoing: Receiver<Bytes>, shared: Arc<Shared>) {
let (mut tun_r, mut tun_w) = io::split(tun);
async fn reader_task(
tun: Arc<Tun>,
shared: Arc<Shared>,
mut tuples_purge: broadcast::Receiver<AddrTuple>,
) {
let mut tuples: HashMap<AddrTuple, Sender<Bytes>> = HashMap::new();
loop {
let mut buf = BytesMut::with_capacity(MAX_PACKET_LEN);
buf.resize(MAX_PACKET_LEN, 0);
tokio::select! {
buf = outgoing.recv() => {
let buf = buf.unwrap();
tun_w.write_all(&buf).await.unwrap();
},
s = tun_r.read_buf(&mut buf) => {
s.unwrap();
size = tun.recv(&mut buf) => {
let size = size.unwrap();
buf.truncate(size);
let buf = buf.freeze();
if buf[0] >> 4 != 4 {
// not an IPv4 packet
continue;
}
let (ip_packet, tcp_packet) = parse_ipv4_packet(&buf);
let local_addr = SocketAddrV4::new(ip_packet.get_destination(), tcp_packet.get_destination());
let local_addr =
SocketAddrV4::new(ip_packet.get_destination(), tcp_packet.get_destination());
let remote_addr = SocketAddrV4::new(ip_packet.get_source(), tcp_packet.get_source());
let tuple = AddrTuple::new(local_addr, remote_addr);
if let Some(c) = tuples.get(&tuple) {
c.send(buf).await.unwrap();
continue;
} else {
trace!("Cache miss, checking the shared tuples table for connection");
let sender;
{
let tuples = shared.tuples.read().unwrap();
@ -385,15 +397,36 @@ impl Stack {
}
if let Some(c) = sender {
trace!("Storing connection information into local tuples");
tuples.insert(tuple, c.clone());
c.send(buf).await.unwrap();
continue;
}
}
if tcp_packet.get_flags() == tcp::TcpFlags::SYN && shared.listening.read().unwrap().contains(&tcp_packet.get_destination()) {
if tcp_packet.get_flags() == tcp::TcpFlags::SYN
&& shared
.listening
.read()
.unwrap()
.contains(&tcp_packet.get_destination())
{
// SYN seen on listening socket
if tcp_packet.get_sequence() == 0 {
let (sock, incoming) = Socket::new(Mode::Server, shared.clone(), local_addr, remote_addr, Some(tcp_packet.get_sequence() + 1), State::Idle);
assert!(shared.tuples.write().unwrap().insert(tuple, Arc::new(incoming)).is_none());
let (sock, incoming) = Socket::new(
shared.clone(),
tun.clone(),
local_addr,
remote_addr,
Some(tcp_packet.get_sequence() + 1),
State::Idle,
);
assert!(shared
.tuples
.write()
.unwrap()
.insert(tuple, incoming)
.is_none());
tokio::spawn(sock.accept());
} else {
trace!("Bad TCP SYN packet from {}, sending RST", remote_addr);
@ -405,7 +438,7 @@ impl Stack {
tcp::TcpFlags::RST,
None,
);
shared.outgoing.try_send(buf).unwrap();
shared.tun[0].try_send(&buf).unwrap();
}
} else if (tcp_packet.get_flags() & tcp::TcpFlags::RST) == 0 {
info!("Unknown TCP packet from {}, sending RST", remote_addr);
@ -417,8 +450,13 @@ impl Stack {
tcp::TcpFlags::RST,
None,
);
shared.outgoing.try_send(buf).unwrap();
shared.tun[0].try_send(&buf).unwrap();
}
},
tuple = tuples_purge.recv() => {
let tuple = tuple.unwrap();
tuples.remove(&tuple);
trace!("Removed cached tuple");
}
}
}

View File

@ -18,8 +18,8 @@ pub fn build_tcp_packet(
payload: Option<&[u8]>,
) -> Bytes {
let wscale = (flags & tcp::TcpFlags::SYN) != 0;
let tcp_total_len = TCP_HEADER_LEN + if wscale {4} else {0} // nop + wscale
+ payload.map_or(0, |payload| payload.len());
let tcp_header_len = TCP_HEADER_LEN + if wscale { 4 } else { 0 }; // nop + wscale
let tcp_total_len = tcp_header_len + payload.map_or(0, |payload| payload.len());
let total_len = IPV4_HEADER_LEN + tcp_total_len;
let mut buf = BytesMut::with_capacity(total_len);
buf.resize(total_len, 0);
@ -62,9 +62,10 @@ pub fn build_tcp_packet(
cksm.add_bytes(&local_addr.ip().octets());
cksm.add_bytes(&remote_addr.ip().octets());
let ip::IpNextHeaderProtocol(tcp_protocol) = ip::IpNextHeaderProtocols::Tcp;
let pseudo = [0u8, tcp_protocol, 0, tcp_total_len as u8];
let mut pseudo = [0u8, tcp_protocol, 0, 0];
pseudo[2..].copy_from_slice(&(tcp_total_len as u16).to_be_bytes());
cksm.add_bytes(&pseudo);
cksm.add_bytes(v4.packet());
cksm.add_bytes(tcp.packet());
tcp.set_checksum(u16::from_be_bytes(cksm.checksum()));
v4_buf.unsplit(tcp_buf);

View File

@ -1,20 +1,21 @@
[package]
name = "phantun"
version = "0.1.0"
version = "0.1.1"
edition = "2018"
authors = ["Datong Sun <dndx@idndx.com>"]
license = "MIT OR Apache-2.0"
repository = "https://github.com/dndx/phantun"
readme = "README.md"
description = """
Turns transforms UDP stream into (fake) TCP streams that can go through
Layer 4 firewalls.
Transforms UDP stream into (fake) TCP streams that can go through
Layer 3 & Layer 4 (NAPT) firewalls/NATs.
"""
[dependencies]
clap = "2.33.3"
socket2 = { version = "0.4.2", features = ["all"] }
fake-tcp = "0.1.0"
tokio-tun = "0.3.15"
tokio = { version = "1.11.0", features = ["full"] }
fake-tcp = { path = "../fake-tcp" }
tokio = { version = "1.12.0", features = ["full"] }
log = "0.4"
pretty_env_logger = "0.4.0"
dndx-fork-tokio-tun = "0.3.16"
num_cpus = "1.13.0"

View File

@ -1,3 +1,5 @@
extern crate dndx_fork_tokio_tun as tokio_tun;
use clap::{App, Arg};
use fake_tcp::packet::MAX_PACKET_LEN;
use fake_tcp::{Socket, Stack};
@ -70,10 +72,10 @@ async fn main() {
.up() // or set it up manually using `sudo ip link set <tun-name> up`.
.address("192.168.200.1".parse().unwrap())
.destination("192.168.200.2".parse().unwrap())
.try_build()
.try_build_mq(num_cpus::get())
.unwrap();
info!("Created TUN device {}", tun.name());
info!("Created TUN device {}", tun[0].name());
let udp_sock = Arc::new(new_udp_reuseport(local_addr));
let connections = Arc::new(RwLock::new(HashMap::<SocketAddrV4, Arc<Socket>>::new()));

View File

@ -1,3 +1,5 @@
extern crate dndx_fork_tokio_tun as tokio_tun;
use clap::{App, Arg};
use fake_tcp::packet::MAX_PACKET_LEN;
use fake_tcp::Stack;
@ -53,9 +55,11 @@ async fn main() {
.up() // or set it up manually using `sudo ip link set <tun-name> up`.
.address("192.168.201.1".parse().unwrap())
.destination("192.168.201.2".parse().unwrap())
.try_build()
.try_build_mq(num_cpus::get())
.unwrap();
info!("Created TUN device {}", tun[0].name());
//thread::sleep(time::Duration::from_secs(5));
let mut stack = Stack::new(tun);
stack.listen(local_port);