diff --git a/fake-tcp/Cargo.toml b/fake-tcp/Cargo.toml index 70d3017..7c1c96c 100644 --- a/fake-tcp/Cargo.toml +++ b/fake-tcp/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "fake-tcp" -version = "0.1.0" +version = "0.1.1" edition = "2018" authors = ["Datong Sun "] license = "MIT OR Apache-2.0" @@ -17,8 +17,8 @@ benchmark = [] [dependencies] bytes = "1" pnet = "0.28.0" -tokio-tun = "0.3.15" -tokio = { version = "1.11.0", features = ["full"] } +tokio = { version = "1.12.0", features = ["full"] } rand = { version = "0.8.4", features = ["small_rng"] } log = "0.4" internet-checksum = "0.2.0" +dndx-fork-tokio-tun = "0.3.16" diff --git a/fake-tcp/src/lib.rs b/fake-tcp/src/lib.rs index 32c1270..eb7a5ba 100644 --- a/fake-tcp/src/lib.rs +++ b/fake-tcp/src/lib.rs @@ -1,9 +1,10 @@ #![cfg_attr(feature = "benchmark", feature(test))] pub mod packet; +extern crate dndx_fork_tokio_tun as tokio_tun; use bytes::{Bytes, BytesMut}; -use log::{info, trace}; +use log::{error, info, trace, warn}; use packet::*; use pnet::packet::{tcp, Packet}; use rand::prelude::*; @@ -12,18 +13,18 @@ use std::fmt; use std::net::{Ipv4Addr, SocketAddrV4}; use std::sync::atomic::{AtomicU32, Ordering}; use std::sync::{Arc, RwLock}; -use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::sync::broadcast; use tokio::sync::mpsc::{self, Receiver, Sender}; use tokio::sync::watch; use tokio::sync::Mutex as AsyncMutex; -use tokio::{io, time}; +use tokio::time; use tokio_tun::Tun; const TIMEOUT: time::Duration = time::Duration::from_secs(1); const RETRIES: usize = 6; const MPSC_BUFFER_LEN: usize = 512; -#[derive(Debug, Hash, Eq, PartialEq)] +#[derive(Hash, Eq, PartialEq, Clone, Debug)] pub struct AddrTuple { local_addr: SocketAddrV4, remote_addr: SocketAddrV4, @@ -38,12 +39,12 @@ impl AddrTuple { } } -#[derive(Debug)] struct Shared { - tuples: RwLock>>>, + tuples: RwLock>>, listening: RwLock>, - outgoing: Sender, + tun: Vec>, ready: Sender, + tuples_purge: broadcast::Sender, } pub struct Stack { @@ -52,7 +53,6 @@ pub struct Stack { ready: Receiver, } -#[derive(Debug)] pub enum State { Idle, SynSent, @@ -60,16 +60,9 @@ pub enum State { Established, } -#[derive(Debug)] -pub enum Mode { - Client, - Server, -} - -#[derive(Debug)] pub struct Socket { - mode: Mode, shared: Arc, + tun: Arc, incoming: AsyncMutex>, local_addr: SocketAddrV4, remote_addr: SocketAddrV4, @@ -82,8 +75,8 @@ pub struct Socket { impl Socket { fn new( - mode: Mode, shared: Arc, + tun: Arc, local_addr: SocketAddrV4, remote_addr: SocketAddrV4, ack: Option, @@ -94,8 +87,8 @@ impl Socket { ( Socket { - mode, shared, + tun, incoming: AsyncMutex::new(incoming_rx), local_addr, remote_addr, @@ -126,10 +119,10 @@ impl Socket { match self.state { State::Established => { let buf = self.build_tcp_packet(tcp::TcpFlags::ACK, Some(payload)); - self.seq.fetch_add(buf.len() as u32, Ordering::Relaxed); + self.seq.fetch_add(payload.len() as u32, Ordering::Relaxed); tokio::select! { - res = self.shared.outgoing.send(buf) => { + res = self.tun.send(&buf) => { res.unwrap(); Some(()) }, @@ -186,7 +179,7 @@ impl Socket { State::Idle => { let buf = self.build_tcp_packet(tcp::TcpFlags::SYN | tcp::TcpFlags::ACK, None); // ACK set by constructor - self.shared.outgoing.send(buf).await.unwrap(); + self.tun.send(&buf).await.unwrap(); self.state = State::SynReceived; info!("Sent SYN + ACK to client"); } @@ -210,7 +203,9 @@ impl Socket { info!("Connection from {:?} established", self.remote_addr); let ready = self.shared.ready.clone(); - ready.send(self).await.unwrap(); + if let Err(e) = ready.send(self).await { + error!("Unable to send accepted socket to ready queue: {}", e); + } return; } } else { @@ -228,7 +223,7 @@ impl Socket { match self.state { State::Idle => { let buf = self.build_tcp_packet(tcp::TcpFlags::SYN, None); - self.shared.outgoing.send(buf).await.unwrap(); + self.tun.send(&buf).await.unwrap(); self.state = State::SynSent; info!("Sent SYN to server"); } @@ -253,7 +248,7 @@ impl Socket { // send ACK to finish handshake let buf = self.build_tcp_packet(tcp::TcpFlags::ACK, None); - self.shared.outgoing.send(buf).await.unwrap(); + self.tun.send(&buf).await.unwrap(); self.state = State::Established; @@ -277,17 +272,16 @@ impl Socket { impl Drop for Socket { fn drop(&mut self) { + let tuple = AddrTuple::new(self.local_addr, self.remote_addr); // dissociates ourself from the dispatch map - assert!(self - .shared - .tuples - .write() - .unwrap() - .remove(&AddrTuple::new(self.local_addr, self.remote_addr)) - .is_some()); + assert!(self.shared.tuples.write().unwrap().remove(&tuple).is_some()); + // purge cache + self.shared.tuples_purge.send(tuple).unwrap(); let buf = self.build_tcp_packet(tcp::TcpFlags::RST, None); - self.shared.outgoing.try_send(buf).unwrap(); + if let Err(e) = self.tun.try_send(&buf) { + warn!("Unable to send RST to remote end: {}", e); + } self.close(); info!("Fake TCP connection to {} closed", self); } @@ -304,18 +298,27 @@ impl fmt::Display for Socket { } impl Stack { - pub fn new(tun: Tun) -> Stack { - let (outgoing_tx, outgoing_rx) = mpsc::channel(MPSC_BUFFER_LEN); + pub fn new(tun: Vec) -> Stack { + let tun: Vec> = tun.into_iter().map(Arc::new).collect(); let (ready_tx, ready_rx) = mpsc::channel(MPSC_BUFFER_LEN); + let (tuples_purge_tx, _tuples_purge_rx) = broadcast::channel(16); let shared = Arc::new(Shared { tuples: RwLock::new(HashMap::new()), - outgoing: outgoing_tx, + tun: tun.clone(), listening: RwLock::new(HashSet::new()), ready: ready_tx, + tuples_purge: tuples_purge_tx.clone(), }); - let local_ip = tun.destination().unwrap(); + let local_ip = tun[0].destination().unwrap(); + + for t in tun { + tokio::spawn(Stack::reader_task( + t, + shared.clone(), + tuples_purge_tx.subscribe(), + )); + } - tokio::spawn(Stack::dispatch(tun, outgoing_rx, shared.clone())); Stack { shared, local_ip, @@ -337,8 +340,8 @@ impl Stack { let local_addr = SocketAddrV4::new(self.local_ip, local_port); let tuple = AddrTuple::new(local_addr, addr); let (mut sock, incoming) = Socket::new( - Mode::Client, self.shared.clone(), + self.shared.tun.choose(&mut rng).unwrap().clone(), local_addr, addr, None, @@ -347,53 +350,83 @@ impl Stack { { let mut tuples = self.shared.tuples.write().unwrap(); - assert!(tuples.insert(tuple, Arc::new(incoming.clone())).is_none()); + assert!(tuples.insert(tuple, incoming.clone()).is_none()); } sock.connect().await.map(|_| sock) } - async fn dispatch(tun: Tun, mut outgoing: Receiver, shared: Arc) { - let (mut tun_r, mut tun_w) = io::split(tun); + async fn reader_task( + tun: Arc, + shared: Arc, + mut tuples_purge: broadcast::Receiver, + ) { + let mut tuples: HashMap> = HashMap::new(); loop { let mut buf = BytesMut::with_capacity(MAX_PACKET_LEN); + buf.resize(MAX_PACKET_LEN, 0); tokio::select! { - buf = outgoing.recv() => { - let buf = buf.unwrap(); - tun_w.write_all(&buf).await.unwrap(); - }, - s = tun_r.read_buf(&mut buf) => { - s.unwrap(); + size = tun.recv(&mut buf) => { + let size = size.unwrap(); + buf.truncate(size); let buf = buf.freeze(); + if buf[0] >> 4 != 4 { // not an IPv4 packet continue; } let (ip_packet, tcp_packet) = parse_ipv4_packet(&buf); - let local_addr = SocketAddrV4::new(ip_packet.get_destination(), tcp_packet.get_destination()); + let local_addr = + SocketAddrV4::new(ip_packet.get_destination(), tcp_packet.get_destination()); let remote_addr = SocketAddrV4::new(ip_packet.get_source(), tcp_packet.get_source()); let tuple = AddrTuple::new(local_addr, remote_addr); - - let sender; - { - let tuples = shared.tuples.read().unwrap(); - sender = tuples.get(&tuple).cloned(); - } - - if let Some(c) = sender { + if let Some(c) = tuples.get(&tuple) { c.send(buf).await.unwrap(); continue; + + } else { + trace!("Cache miss, checking the shared tuples table for connection"); + let sender; + { + let tuples = shared.tuples.read().unwrap(); + sender = tuples.get(&tuple).cloned(); + } + + if let Some(c) = sender { + trace!("Storing connection information into local tuples"); + tuples.insert(tuple, c.clone()); + c.send(buf).await.unwrap(); + continue; + } } - if tcp_packet.get_flags() == tcp::TcpFlags::SYN && shared.listening.read().unwrap().contains(&tcp_packet.get_destination()) { + if tcp_packet.get_flags() == tcp::TcpFlags::SYN + && shared + .listening + .read() + .unwrap() + .contains(&tcp_packet.get_destination()) + { // SYN seen on listening socket if tcp_packet.get_sequence() == 0 { - let (sock, incoming) = Socket::new(Mode::Server, shared.clone(), local_addr, remote_addr, Some(tcp_packet.get_sequence() + 1), State::Idle); - assert!(shared.tuples.write().unwrap().insert(tuple, Arc::new(incoming)).is_none()); + let (sock, incoming) = Socket::new( + shared.clone(), + tun.clone(), + local_addr, + remote_addr, + Some(tcp_packet.get_sequence() + 1), + State::Idle, + ); + assert!(shared + .tuples + .write() + .unwrap() + .insert(tuple, incoming) + .is_none()); tokio::spawn(sock.accept()); } else { trace!("Bad TCP SYN packet from {}, sending RST", remote_addr); @@ -405,7 +438,7 @@ impl Stack { tcp::TcpFlags::RST, None, ); - shared.outgoing.try_send(buf).unwrap(); + shared.tun[0].try_send(&buf).unwrap(); } } else if (tcp_packet.get_flags() & tcp::TcpFlags::RST) == 0 { info!("Unknown TCP packet from {}, sending RST", remote_addr); @@ -417,8 +450,13 @@ impl Stack { tcp::TcpFlags::RST, None, ); - shared.outgoing.try_send(buf).unwrap(); + shared.tun[0].try_send(&buf).unwrap(); } + }, + tuple = tuples_purge.recv() => { + let tuple = tuple.unwrap(); + tuples.remove(&tuple); + trace!("Removed cached tuple"); } } } diff --git a/fake-tcp/src/packet.rs b/fake-tcp/src/packet.rs index 546fb73..c73d3ad 100644 --- a/fake-tcp/src/packet.rs +++ b/fake-tcp/src/packet.rs @@ -18,8 +18,8 @@ pub fn build_tcp_packet( payload: Option<&[u8]>, ) -> Bytes { let wscale = (flags & tcp::TcpFlags::SYN) != 0; - let tcp_total_len = TCP_HEADER_LEN + if wscale {4} else {0} // nop + wscale - + payload.map_or(0, |payload| payload.len()); + let tcp_header_len = TCP_HEADER_LEN + if wscale { 4 } else { 0 }; // nop + wscale + let tcp_total_len = tcp_header_len + payload.map_or(0, |payload| payload.len()); let total_len = IPV4_HEADER_LEN + tcp_total_len; let mut buf = BytesMut::with_capacity(total_len); buf.resize(total_len, 0); @@ -62,9 +62,10 @@ pub fn build_tcp_packet( cksm.add_bytes(&local_addr.ip().octets()); cksm.add_bytes(&remote_addr.ip().octets()); let ip::IpNextHeaderProtocol(tcp_protocol) = ip::IpNextHeaderProtocols::Tcp; - let pseudo = [0u8, tcp_protocol, 0, tcp_total_len as u8]; + let mut pseudo = [0u8, tcp_protocol, 0, 0]; + pseudo[2..].copy_from_slice(&(tcp_total_len as u16).to_be_bytes()); cksm.add_bytes(&pseudo); - cksm.add_bytes(v4.packet()); + cksm.add_bytes(tcp.packet()); tcp.set_checksum(u16::from_be_bytes(cksm.checksum())); v4_buf.unsplit(tcp_buf); diff --git a/phantun/Cargo.toml b/phantun/Cargo.toml index db0938a..ac1a3e9 100644 --- a/phantun/Cargo.toml +++ b/phantun/Cargo.toml @@ -1,20 +1,21 @@ [package] name = "phantun" -version = "0.1.0" +version = "0.1.1" edition = "2018" authors = ["Datong Sun "] license = "MIT OR Apache-2.0" repository = "https://github.com/dndx/phantun" readme = "README.md" description = """ -Turns transforms UDP stream into (fake) TCP streams that can go through -Layer 4 firewalls. +Transforms UDP stream into (fake) TCP streams that can go through +Layer 3 & Layer 4 (NAPT) firewalls/NATs. """ [dependencies] clap = "2.33.3" socket2 = { version = "0.4.2", features = ["all"] } -fake-tcp = "0.1.0" -tokio-tun = "0.3.15" -tokio = { version = "1.11.0", features = ["full"] } +fake-tcp = { path = "../fake-tcp" } +tokio = { version = "1.12.0", features = ["full"] } log = "0.4" pretty_env_logger = "0.4.0" +dndx-fork-tokio-tun = "0.3.16" +num_cpus = "1.13.0" diff --git a/phantun/src/bin/client.rs b/phantun/src/bin/client.rs index ffd92e8..ff8129b 100644 --- a/phantun/src/bin/client.rs +++ b/phantun/src/bin/client.rs @@ -1,3 +1,5 @@ +extern crate dndx_fork_tokio_tun as tokio_tun; + use clap::{App, Arg}; use fake_tcp::packet::MAX_PACKET_LEN; use fake_tcp::{Socket, Stack}; @@ -70,10 +72,10 @@ async fn main() { .up() // or set it up manually using `sudo ip link set up`. .address("192.168.200.1".parse().unwrap()) .destination("192.168.200.2".parse().unwrap()) - .try_build() + .try_build_mq(num_cpus::get()) .unwrap(); - info!("Created TUN device {}", tun.name()); + info!("Created TUN device {}", tun[0].name()); let udp_sock = Arc::new(new_udp_reuseport(local_addr)); let connections = Arc::new(RwLock::new(HashMap::>::new())); diff --git a/phantun/src/bin/server.rs b/phantun/src/bin/server.rs index bdd6d85..3914b1f 100644 --- a/phantun/src/bin/server.rs +++ b/phantun/src/bin/server.rs @@ -1,3 +1,5 @@ +extern crate dndx_fork_tokio_tun as tokio_tun; + use clap::{App, Arg}; use fake_tcp::packet::MAX_PACKET_LEN; use fake_tcp::Stack; @@ -53,9 +55,11 @@ async fn main() { .up() // or set it up manually using `sudo ip link set up`. .address("192.168.201.1".parse().unwrap()) .destination("192.168.201.2".parse().unwrap()) - .try_build() + .try_build_mq(num_cpus::get()) .unwrap(); + info!("Created TUN device {}", tun[0].name()); + //thread::sleep(time::Duration::from_secs(5)); let mut stack = Stack::new(tun); stack.listen(local_port);