From 65e200b1f2f837c03cec372c5ccc6d7a723c8a98 Mon Sep 17 00:00:00 2001 From: Saber Haj Rabiee Date: Thu, 10 Nov 2022 15:23:09 -0800 Subject: [PATCH] Multi-stream TCP and UDP, encryption and performance --- Cargo.toml | 4 + README.md | 17 +- fake-tcp/Cargo.toml | 3 +- fake-tcp/src/lib.rs | 108 ++++++------ phantun/Cargo.toml | 3 + phantun/src/bin/client.rs | 350 ++++++++++++++++++++++++-------------- phantun/src/bin/server.rs | 187 +++++++++++++------- phantun/src/lib.rs | 139 ++++++++++++++- 8 files changed, 566 insertions(+), 245 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 35a885f..e56299d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,3 +4,7 @@ members = [ "fake-tcp", "phantun", ] + +[profile.release] +lto = true +codegen-units = 1 diff --git a/README.md b/README.md index ecab151..889881e 100644 --- a/README.md +++ b/README.md @@ -335,7 +335,6 @@ Writeup on some of the techniques used in Phantun to achieve this performance re # Future plans -* Load balancing a single UDP stream into multiple TCP streams * Integration tests * Auto insertion/removal of required firewall rules @@ -351,17 +350,19 @@ performance overall and less MTU overhead because lack of additional headers ins Here is a quick overview of comparison between those two to help you choose: -| | Phantun | udp2raw | -|--------------------------------------------------|:-------------:|:-----------------:| +| | Phantun | udp2raw | +|--------------------------------------------------|:-------------:|:-------------------:| | UDP over FakeTCP obfuscation | ✅ | ✅ | | UDP over ICMP obfuscation | ❌ | ✅ | | UDP over UDP obfuscation | ❌ | ✅ | +| Arbitrary TCP handshake content | ✅ | ❌ | | Multi-threaded | ✅ | ❌ | -| Throughput | Better | Good | -| Layer 3 mode | TUN interface | Raw sockets + BPF | -| Tunneling MTU overhead | 12 bytes | 44 bytes | -| Seprate TCP connections for each UDP connection | Client/Server | Server only | -| Anti-replay, encryption | ❌ | ✅ | +| Throughput | Better | Good | +| Layer 3 mode | TUN interface | Raw sockets + BPF | +| Tunneling MTU overhead | 12 bytes | 44 bytes | +| Seprate TCP connections for each UDP connection | Client/Server | Server only | +| Anti-replay | ❌ | ✅ | +| Encryption | ✅ | ✅ | | IPv6 | ✅ | ✅ | [Back to TOC](#table-of-contents) diff --git a/fake-tcp/Cargo.toml b/fake-tcp/Cargo.toml index 05337b6..5c3ef5b 100644 --- a/fake-tcp/Cargo.toml +++ b/fake-tcp/Cargo.toml @@ -18,8 +18,9 @@ benchmark = [] bytes = "1" pnet = "0.31" tokio = { version = "1.14", features = ["full"] } -rand = { version = "0.8", features = ["small_rng"] } log = "0.4" internet-checksum = "0.2" tokio-tun = "0.7" flume = "0.10" +fxhash = "0.2.1" +dashmap = "5.4.0" diff --git a/fake-tcp/src/lib.rs b/fake-tcp/src/lib.rs index a5c5863..8f531c3 100644 --- a/fake-tcp/src/lib.rs +++ b/fake-tcp/src/lib.rs @@ -43,22 +43,23 @@ pub mod packet; use bytes::{Bytes, BytesMut}; +use dashmap::{mapref::entry::Entry, DashMap, DashSet}; +use fxhash::FxBuildHasher; use log::{error, info, trace, warn}; use packet::*; use pnet::packet::{tcp, Packet}; -use rand::prelude::*; -use std::collections::{HashMap, HashSet}; +use std::collections::HashMap; use std::fmt; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; -use std::sync::atomic::{AtomicU32, Ordering}; -use std::sync::{Arc, RwLock}; +use std::sync::atomic::{AtomicU32, AtomicUsize, Ordering}; +use std::sync::Arc; use tokio::sync::broadcast; use tokio::sync::mpsc; use tokio::time; use tokio_tun::Tun; -const TIMEOUT: time::Duration = time::Duration::from_secs(1); -const RETRIES: usize = 6; +const TIMEOUT: time::Duration = time::Duration::from_secs(3); +const RETRIES: usize = 2; const MPMC_BUFFER_LEN: usize = 512; const MPSC_BUFFER_LEN: usize = 128; const MAX_UNACKED_LEN: u32 = 128 * 1024 * 1024; // 128MB @@ -79,9 +80,10 @@ impl AddrTuple { } struct Shared { - tuples: RwLock>>, - listening: RwLock>, + tuples: DashMap, FxBuildHasher>, + listening: DashSet, tun: Vec>, + tun_index: AtomicUsize, ready: mpsc::Sender, tuples_purge: broadcast::Sender, } @@ -322,7 +324,7 @@ impl Drop for Socket { fn drop(&mut self) { let tuple = AddrTuple::new(self.local_addr, self.remote_addr); // dissociates ourself from the dispatch map - assert!(self.shared.tuples.write().unwrap().remove(&tuple).is_some()); + assert!(self.shared.tuples.remove(&tuple).is_some()); // purge cache self.shared.tuples_purge.send(tuple).unwrap(); @@ -364,9 +366,10 @@ impl Stack { let (ready_tx, ready_rx) = mpsc::channel(MPSC_BUFFER_LEN); let (tuples_purge_tx, _tuples_purge_rx) = broadcast::channel(16); let shared = Arc::new(Shared { - tuples: RwLock::new(HashMap::new()), + tuples: DashMap::default(), tun: tun.clone(), - listening: RwLock::new(HashSet::new()), + tun_index: AtomicUsize::new(0), + listening: DashSet::default(), ready: ready_tx, tuples_purge: tuples_purge_tx.clone(), }); @@ -389,7 +392,7 @@ impl Stack { /// Listens for incoming connections on the given `port`. pub fn listen(&mut self, port: u16) { - assert!(self.shared.listening.write().unwrap().insert(port)); + assert!(self.shared.listening.insert(port)); } /// Accepts an incoming connection. @@ -399,33 +402,38 @@ impl Stack { /// Connects to the remote end. `None` returned means /// the connection attempt failed. - pub async fn connect(&mut self, addr: SocketAddr) -> Option { - let mut rng = SmallRng::from_entropy(); - let local_port: u16 = rng.gen_range(1024..65535); - let local_addr = SocketAddr::new( - if addr.is_ipv4() { - IpAddr::V4(self.local_ip) - } else { - IpAddr::V6(self.local_ip6.expect("IPv6 local address undefined")) - }, - local_port, - ); - let tuple = AddrTuple::new(local_addr, addr); - let (mut sock, incoming) = Socket::new( - self.shared.clone(), - self.shared.tun.choose(&mut rng).unwrap().clone(), - local_addr, - addr, - None, - State::Idle, - ); - - { - let mut tuples = self.shared.tuples.write().unwrap(); - assert!(tuples.insert(tuple, incoming.clone()).is_none()); + pub async fn connect(&self, addr: SocketAddr) -> Option { + for local_port in 1024..u16::MAX { + let local_addr = SocketAddr::new( + if addr.is_ipv4() { + IpAddr::V4(self.local_ip) + } else { + IpAddr::V6(self.local_ip6.expect("IPv6 local address undefined")) + }, + local_port, + ); + let tuple = AddrTuple::new(local_addr, addr); + let mut sock = match self.shared.tuples.entry(tuple) { + Entry::Occupied(_) => continue, + Entry::Vacant(v) => { + let tun_index = self.shared.tun_index.fetch_add(1, Ordering::Relaxed) + % self.shared.tun.len(); + let tun = unsafe { self.shared.tun.get_unchecked(tun_index).clone() }; + let (sock, incoming) = Socket::new( + self.shared.clone(), + tun, + local_addr, + addr, + None, + State::Idle, + ); + v.insert(incoming.clone()); + sock + } + }; + return sock.connect().await.map(|_| sock); } - - sock.connect().await.map(|_| sock) + None } async fn reader_task( @@ -433,7 +441,8 @@ impl Stack { shared: Arc, mut tuples_purge: broadcast::Receiver, ) { - let mut tuples: HashMap> = HashMap::new(); + let mut tuples: HashMap, FxBuildHasher> = + HashMap::default(); loop { let mut buf = BytesMut::zeroed(MAX_PACKET_LEN); @@ -462,10 +471,7 @@ impl Stack { // path below } else { trace!("Cache miss, checking the shared tuples table for connection"); - let sender = { - let tuples = shared.tuples.read().unwrap(); - tuples.get(&tuple).cloned() - }; + let sender = shared.tuples.get(&tuple); if let Some(c) = sender { trace!("Storing connection information into local tuples"); @@ -478,8 +484,6 @@ impl Stack { if tcp_packet.get_flags() == tcp::TcpFlags::SYN && shared .listening - .read() - .unwrap() .contains(&tcp_packet.get_destination()) { // SYN seen on listening socket @@ -494,8 +498,6 @@ impl Stack { ); assert!(shared .tuples - .write() - .unwrap() .insert(tuple, incoming) .is_none()); tokio::spawn(sock.accept()); @@ -509,7 +511,11 @@ impl Stack { tcp::TcpFlags::RST | tcp::TcpFlags::ACK, None, ); - shared.tun[0].try_send(&buf).unwrap(); + let tun_index = shared.tun_index.fetch_add(1, Ordering::Relaxed) % shared.tun.len(); + let tun = unsafe { + shared.tun.get_unchecked(tun_index) + }; + tun.try_send(&buf).unwrap(); } } else if (tcp_packet.get_flags() & tcp::TcpFlags::RST) == 0 { info!("Unknown TCP packet from {}, sending RST", remote_addr); @@ -521,7 +527,11 @@ impl Stack { tcp::TcpFlags::RST | tcp::TcpFlags::ACK, None, ); - shared.tun[0].try_send(&buf).unwrap(); + let tun_index = shared.tun_index.fetch_add(1, Ordering::Relaxed) % shared.tun.len(); + let tun = unsafe { + shared.tun.get_unchecked(tun_index) + }; + tun.try_send(&buf).unwrap(); } } None => { diff --git a/phantun/Cargo.toml b/phantun/Cargo.toml index 0864195..9ac2d8a 100644 --- a/phantun/Cargo.toml +++ b/phantun/Cargo.toml @@ -22,3 +22,6 @@ tokio-tun = "0.7" num_cpus = "1.13" neli = "0.6" nix = "0.25" + +[dev-dependencies] +rand = "0.8.5" diff --git a/phantun/src/bin/client.rs b/phantun/src/bin/client.rs index 360101e..8b4378a 100644 --- a/phantun/src/bin/client.rs +++ b/phantun/src/bin/client.rs @@ -1,14 +1,15 @@ use clap::{crate_version, Arg, ArgAction, Command}; use fake_tcp::packet::MAX_PACKET_LEN; -use fake_tcp::{Socket, Stack}; +use fake_tcp::Stack; use log::{debug, error, info}; use phantun::utils::{assign_ipv6_address, new_udp_reuseport}; -use std::collections::HashMap; +use phantun::Encryption; use std::fs; use std::io; use std::net::{Ipv4Addr, SocketAddr}; +use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; -use tokio::sync::{Notify, RwLock}; +use tokio::sync::Notify; use tokio::time; use tokio_tun::TunBuilder; use tokio_util::sync::CancellationToken; @@ -101,13 +102,40 @@ async fn main() -> io::Result<()> { Note: ensure this file's size does not exceed the MTU of the outgoing interface. \ The content is always sent out in a single packet and will not be further segmented") ) + .arg( + Arg::new("tcp_connections") + .long("tcp-connections") + .required(false) + .value_name("number") + .help("Number of TCP connections per each client.") + .default_value("8") + ) + .arg( + Arg::new("udp_connections") + .long("udp-connections") + .required(false) + .value_name("number") + .help("Number of UDP connections per each client.") + .default_value("8") + ) + .arg( + Arg::new("encryption") + .long("encryption") + .required(false) + .value_name("encryption") + .help("Specify an encryption algorithm for using in TCP connections. \n\ + Server and client should use the same encryption. \n\ + Currently XOR is only supported and the format should be 'xor:key'.") + ) .get_matches(); - let local_addr: SocketAddr = matches - .get_one::("local") - .unwrap() - .parse() - .expect("bad local address"); + let local_addr: Arc = Arc::new( + matches + .get_one::("local") + .unwrap() + .parse() + .expect("bad local address"), + ); let ipv4_only = matches.get_flag("ipv4_only"); @@ -129,7 +157,7 @@ async fn main() -> io::Result<()> { .parse() .expect("bad peer address for Tun interface"); - let (tun_local6, tun_peer6) = if matches.get_flag("ipv4_only") { + let (tun_local6, tun_peer6) = if ipv4_only { (None, None) } else { ( @@ -142,11 +170,37 @@ async fn main() -> io::Result<()> { ) }; + let tcp_socks_amount: usize = matches + .get_one::("tcp_connections") + .unwrap() + .parse() + .expect("Unspecified number of TCP connections per each client"); + if tcp_socks_amount == 0 { + panic!("TCP connections should be greater than or equal to 1"); + } + + let udp_socks_amount: usize = matches + .get_one::("udp_connections") + .unwrap() + .parse() + .expect("Unspecified number of UDP connections per each client"); + if udp_socks_amount == 0 { + panic!("UDP connections should be greater than or equal to 1"); + } + + let encryption = matches + .get_one::("encryption") + .map(Encryption::from); + debug!("Encryption in use: {:?}", encryption); + let encryption = Arc::new(encryption); + let tun_name = matches.get_one::("tun").unwrap(); - let handshake_packet: Option> = matches - .get_one::("handshake_packet") - .map(fs::read) - .transpose()?; + let handshake_packet: Arc>> = Arc::new( + matches + .get_one::("handshake_packet") + .map(fs::read) + .transpose()?, + ); let num_cpus = num_cpus::get(); info!("{} cores available", num_cpus); @@ -167,137 +221,175 @@ async fn main() -> io::Result<()> { info!("Created TUN device {}", tun[0].name()); - let udp_sock = Arc::new(new_udp_reuseport(local_addr)); - let connections = Arc::new(RwLock::new(HashMap::>::new())); - - let mut stack = Stack::new(tun, tun_peer, tun_peer6); + let stack = Arc::new(Stack::new(tun, tun_peer, tun_peer6)); + let local_addr = local_addr.clone(); let main_loop = tokio::spawn(async move { let mut buf_r = [0u8; MAX_PACKET_LEN]; + let udp_sock = new_udp_reuseport(*local_addr); - loop { - let (size, addr) = udp_sock.recv_from(&mut buf_r).await?; - // seen UDP packet to listening socket, this means: - // 1. It is a new UDP connection, or - // 2. It is some extra packets not filtered by more specific - // connected UDP socket yet - if let Some(sock) = connections.read().await.get(&addr) { - sock.send(&buf_r[..size]).await; - continue; - } + 'main_loop: loop { + let (size, addr) = udp_sock.recv_from(&mut buf_r).await.unwrap(); info!("New UDP client from {}", addr); - let sock = stack.connect(remote_addr).await; - if sock.is_none() { - error!("Unable to connect to remote {}", remote_addr); - continue; - } + let stack = stack.clone(); + let local_addr = local_addr.clone(); + let handshake_packet = handshake_packet.clone(); + let encryption = encryption.clone(); - let sock = Arc::new(sock.unwrap()); - if let Some(ref p) = handshake_packet { - if sock.send(p).await.is_none() { - error!("Failed to send handshake packet to remote, closing connection."); - continue; + let udp_socks: Vec<_> = { + let mut socks = Vec::with_capacity(udp_socks_amount); + for _ in 0..udp_socks_amount { + let udp_sock = new_udp_reuseport(*local_addr); + if let Err(err) = udp_sock.connect(addr).await { + error!("Unable to connect to {addr} over udp: {err}"); + continue 'main_loop; + } + socks.push(Arc::new(udp_sock)); + } + socks + }; + tokio::spawn(async move { + let udp_socks = Arc::new(udp_socks); + let cancellation = CancellationToken::new(); + let packet_received = Arc::new(Notify::new()); + let mut tcp_socks = Vec::with_capacity(tcp_socks_amount); + let udp_sock_index = Arc::new(AtomicUsize::new(0)); + let tcp_sock_index = Arc::new(AtomicUsize::new(0)); + + for sock_index in 0..tcp_socks_amount { + debug!("Creating tcp stream number {sock_index} for {addr} to {remote_addr}."); + let tcp_sock = match stack.connect(remote_addr).await { + Some(tcp_sock) => Arc::new(tcp_sock), + None => { + error!("Unable to connect to remote {}", remote_addr); + cancellation.cancel(); + return; + } + }; + + if let Some(ref p) = *handshake_packet { + if tcp_sock.send(p).await.is_none() { + error!( + "Failed to send handshake packet to remote, closing connection." + ); + cancellation.cancel(); + return; + } + + debug!("Sent handshake packet to: {}", tcp_sock); + } + + // send first packet + if sock_index == 0 { + if let Some(ref enc) = *encryption { + enc.encrypt(&mut buf_r[..size]); + } + if tcp_sock.send(&buf_r[..size]).await.is_none() { + cancellation.cancel(); + return; + } + } + + tcp_socks.push(tcp_sock.clone()); + + // spawn "fastpath" UDP socket and task, this will offload main task + // from forwarding UDP packets + let packet_received = packet_received.clone(); + let cancellation = cancellation.clone(); + let udp_socks = udp_socks.clone(); + let udp_sock_index = udp_sock_index.clone(); + let encryption = encryption.clone(); + tokio::spawn(async move { + let mut buf_tcp = [0u8; MAX_PACKET_LEN]; + loop { + tokio::select! { + biased; + _ = cancellation.cancelled() => { + debug!("Closing connection requested for {addr}, closing connection {sock_index}"); + break; + }, + res = tcp_sock.recv(&mut buf_tcp) => { + match res { + Some(size) => { + let udp_sock_index = udp_sock_index.fetch_add(1, Ordering::Relaxed) % udp_socks_amount; + let udp_sock = unsafe { udp_socks.get_unchecked(udp_sock_index) }; + if let Some(ref enc) = *encryption { + enc.decrypt(&mut buf_tcp[..size]); + } + if let Err(e) = udp_sock.send(&buf_tcp[..size]).await { + debug!("Unable to send UDP packet to {}: {}, closing connection {sock_index}", e, addr); + break; + } + }, + None => { + debug!("TCP connection closed on {addr}, closing connection {sock_index}"); + break; + }, + } + packet_received.notify_waiters(); + }, + }; + } + cancellation.cancel(); + }); + debug!( + "inserted fake TCP socket into connection table {remote_addr} {sock_index}" + ); } - debug!("Sent handshake packet to: {}", sock); - } - - // send first packet - if sock.send(&buf_r[..size]).await.is_none() { - continue; - } - - assert!(connections - .write() - .await - .insert(addr, sock.clone()) - .is_none()); - debug!("inserted fake TCP socket into connection table"); - - // spawn "fastpath" UDP socket and task, this will offload main task - // from forwarding UDP packets - - let packet_received = Arc::new(Notify::new()); - let quit = CancellationToken::new(); - - for i in 0..num_cpus { - let sock = sock.clone(); - let quit = quit.clone(); - let packet_received = packet_received.clone(); - - tokio::spawn(async move { - let mut buf_udp = [0u8; MAX_PACKET_LEN]; - let mut buf_tcp = [0u8; MAX_PACKET_LEN]; - let udp_sock = new_udp_reuseport(local_addr); - udp_sock.connect(addr).await.unwrap(); - - loop { - tokio::select! { - Ok(size) = udp_sock.recv(&mut buf_udp) => { - if sock.send(&buf_udp[..size]).await.is_none() { - debug!("removed fake TCP socket from connections table"); - quit.cancel(); - return; - } - - packet_received.notify_one(); - }, - res = sock.recv(&mut buf_tcp) => { - match res { - Some(size) => { - if size > 0 { - if let Err(e) = udp_sock.send(&buf_tcp[..size]).await { - error!("Unable to send UDP packet to {}: {}, closing connection", e, addr); - quit.cancel(); - return; + for (sock_index, udp_sock) in udp_socks.iter().enumerate() { + let udp_sock = udp_sock.clone(); + let packet_received = packet_received.clone(); + let cancellation = cancellation.clone(); + let tcp_socks = tcp_socks.clone(); + let tcp_sock_index = tcp_sock_index.clone(); + let encryption = encryption.clone(); + tokio::spawn(async move { + let mut buf_udp = [0u8; MAX_PACKET_LEN]; + loop { + let read_timeout = time::sleep(UDP_TTL); + tokio::select! { + biased; + _ = cancellation.cancelled() => { + debug!("Closing connection requested for {addr}, closing connection UDP {sock_index}"); + break; + }, + _ = packet_received.notified() => {}, + res = udp_sock.recv(&mut buf_udp) => { + match res { + Ok(size) => { + let tcp_sock_index = tcp_sock_index.fetch_add(1, Ordering::Relaxed) % tcp_socks_amount; + let tcp_sock = unsafe { tcp_socks.get_unchecked(tcp_sock_index) }; + if let Some(ref enc) = *encryption { + enc.encrypt(&mut buf_udp[..size]); } + if tcp_sock.send(&buf_udp[..size]).await.is_none() { + debug!("Unable to send TCP traffic to {addr}, closing connection {sock_index}"); + break; + } + }, + Err(e) => { + debug!("UDP connection closed on {addr}: {e}, closing connection {sock_index}"); + break; } - }, - None => { - debug!("removed fake TCP socket from connections table"); - quit.cancel(); - return; - }, - } + }; - packet_received.notify_one(); - }, - _ = quit.cancelled() => { - debug!("worker {} terminated", i); - return; - }, - }; - } - }); - } - - let connections = connections.clone(); - tokio::spawn(async move { - loop { - let read_timeout = time::sleep(UDP_TTL); - let packet_received_fut = packet_received.notified(); - - tokio::select! { - _ = read_timeout => { - info!("No traffic seen in the last {:?}, closing connection", UDP_TTL); - connections.write().await.remove(&addr); - debug!("removed fake TCP socket from connections table"); - - quit.cancel(); - return; - }, - _ = quit.cancelled() => { - connections.write().await.remove(&addr); - debug!("removed fake TCP socket from connections table"); - return; - }, - _ = packet_received_fut => {}, - } + }, + _ = read_timeout => { + debug!("No traffic seen in the last {:?} on {addr}, closing connection {sock_index}", UDP_TTL); + break; + }, + }; + } + cancellation.cancel(); + info!("Connention {addr} to {remote_addr} closed {sock_index}"); + }); } }); } }); - tokio::join!(main_loop).0.unwrap() + tokio::join!(main_loop).0.unwrap(); + Ok(()) } diff --git a/phantun/src/bin/server.rs b/phantun/src/bin/server.rs index c5dfeea..b8b53f1 100644 --- a/phantun/src/bin/server.rs +++ b/phantun/src/bin/server.rs @@ -3,6 +3,7 @@ use fake_tcp::packet::MAX_PACKET_LEN; use fake_tcp::Stack; use log::{debug, error, info}; use phantun::utils::{assign_ipv6_address, new_udp_reuseport}; +use phantun::Encryption; use std::fs; use std::io; use std::net::Ipv4Addr; @@ -101,6 +102,23 @@ async fn main() -> io::Result<()> { Note: ensure this file's size does not exceed the MTU of the outgoing interface. \ The content is always sent out in a single packet and will not be further segmented") ) + .arg( + Arg::new("encryption") + .long("encryption") + .required(false) + .value_name("encryption") + .help("Specify an encryption algorithm for using in TCP connections. \n\ + Server and client should use the same encryption. \n\ + Currently XOR is only supported and the format should be 'xor:key'.") + ) + .arg( + Arg::new("udp_connections") + .long("udp-connections") + .required(false) + .value_name("number") + .help("Number of UDP connections per each TCP connections.") + .default_value("8") + ) .get_matches(); let local_port: u16 = matches @@ -114,7 +132,6 @@ async fn main() -> io::Result<()> { .expect("bad remote address or host") .next() .expect("unable to resolve remote host name"); - info!("Remote address is: {}", remote_addr); let tun_local: Ipv4Addr = matches @@ -128,6 +145,21 @@ async fn main() -> io::Result<()> { .parse() .expect("bad peer address for Tun interface"); + let udp_socks_amount: usize = matches + .get_one::("udp_connections") + .unwrap() + .parse() + .expect("Unspecified number of UDP connections per each client"); + if udp_socks_amount == 0 { + panic!("UDP connections should be greater than or equal to 1"); + } + + let encryption = matches + .get_one::("encryption") + .map(Encryption::from); + debug!("Encryption in use: {:?}", encryption); + let encryption = Arc::new(encryption); + let (tun_local6, tun_peer6) = if matches.get_flag("ipv4_only") { (None, None) } else { @@ -172,97 +204,138 @@ async fn main() -> io::Result<()> { info!("Listening on {}", local_port); let main_loop = tokio::spawn(async move { - let mut buf_udp = [0u8; MAX_PACKET_LEN]; - let mut buf_tcp = [0u8; MAX_PACKET_LEN]; - - loop { - let sock = Arc::new(stack.accept().await); - info!("New connection: {}", sock); + 'main_loop: loop { + let tcp_sock = Arc::new(stack.accept().await); + info!("New connection: {}", tcp_sock); if let Some(ref p) = handshake_packet { - if sock.send(p).await.is_none() { + if tcp_sock.send(p).await.is_none() { error!("Failed to send handshake packet to remote, closing connection."); continue; } - debug!("Sent handshake packet to: {}", sock); + debug!("Sent handshake packet to: {}", tcp_sock); } - let packet_received = Arc::new(Notify::new()); - let quit = CancellationToken::new(); let udp_sock = UdpSocket::bind(if remote_addr.is_ipv4() { "0.0.0.0:0" } else { "[::]:0" }) - .await?; - let local_addr = udp_sock.local_addr()?; + .await; + + let udp_sock = match udp_sock { + Ok(udp_sock) => udp_sock, + Err(err) => { + error!("No more UDP address is available: {err}"); + continue; + } + }; + + let local_addr = udp_sock.local_addr().unwrap(); drop(udp_sock); - for i in 0..num_cpus { - let sock = sock.clone(); - let quit = quit.clone(); + let cancellation = CancellationToken::new(); + let packet_received = Arc::new(Notify::new()); + let udp_socks: Vec<_> = { + let mut socks = Vec::with_capacity(udp_socks_amount); + for _ in 0..udp_socks_amount { + let udp_sock = new_udp_reuseport(local_addr); + if let Err(err) = udp_sock.connect(remote_addr).await { + error!("UDP couldn't connect to {remote_addr}: {err}, closing connection"); + continue 'main_loop; + } + socks.push(Arc::new(udp_sock)); + } + socks + }; + + for udp_sock in &udp_socks { + let tcp_sock = tcp_sock.clone(); + let cancellation = cancellation.clone(); + let encryption = encryption.clone(); let packet_received = packet_received.clone(); - let udp_sock = new_udp_reuseport(local_addr); - + let udp_sock = udp_sock.clone(); tokio::spawn(async move { - udp_sock.connect(remote_addr).await.unwrap(); - + let mut buf_udp = [0u8; MAX_PACKET_LEN]; loop { + let read_timeout = time::sleep(UDP_TTL); tokio::select! { - Ok(size) = udp_sock.recv(&mut buf_udp) => { - if sock.send(&buf_udp[..size]).await.is_none() { - quit.cancel(); - return; - } - - packet_received.notify_one(); + biased; + _ = cancellation.cancelled() => { + debug!("Closing connection requested for {local_addr}, closing connection"); + break; }, - res = sock.recv(&mut buf_tcp) => { + _ = read_timeout => { + debug!("No traffic seen in the last {:?}, closing connection {local_addr}", UDP_TTL); + break; + }, + _ = packet_received.notified() => {}, + res = udp_sock.recv(&mut buf_udp) => { match res { - Some(size) => { - if size > 0 { - if let Err(e) = udp_sock.send(&buf_tcp[..size]).await { - error!("Unable to send UDP packet to {}: {}, closing connection", e, remote_addr); - quit.cancel(); - return; - } + Ok(size) => { + if let Some(ref enc) = *encryption { + enc.encrypt(&mut buf_udp[..size]); + } + if tcp_sock.send(&buf_udp[..size]).await.is_none() { + debug!("Unable to send TCP packet to {remote_addr}, closing connection"); + break; } }, - None => { - quit.cancel(); - return; - }, - } + Err(err) => { + debug!("UDP connection closed on {remote_addr}: {err}, closing connection"); + break; - packet_received.notify_one(); - }, - _ = quit.cancelled() => { - debug!("worker {} terminated", i); - return; + } + }; }, }; } + cancellation.cancel(); }); } - + let tcp_sock = tcp_sock.clone(); + let encryption = encryption.clone(); + let packet_received = packet_received.clone(); + let cancellation = cancellation.clone(); tokio::spawn(async move { + let mut buf_tcp = [0u8; MAX_PACKET_LEN]; + let mut udp_sock_index = 0; + loop { - let read_timeout = time::sleep(UDP_TTL); - let packet_received_fut = packet_received.notified(); - tokio::select! { - _ = read_timeout => { - info!("No traffic seen in the last {:?}, closing connection", UDP_TTL); - - quit.cancel(); - return; + biased; + _ = cancellation.cancelled() => { + debug!("Closing connection requested for {local_addr}, closing connection"); + break; }, - _ = packet_received_fut => {}, - } + res = tcp_sock.recv(&mut buf_tcp) => { + match res { + Some(size) => { + udp_sock_index = (udp_sock_index + 1) % udp_socks_amount; + let udp_sock = unsafe { udp_socks.get_unchecked(udp_sock_index) }; + if let Some(ref enc) = *encryption { + enc.decrypt(&mut buf_tcp[..size]); + } + if let Err(e) = udp_sock.send(&buf_tcp[..size]).await { + debug!("Unable to send UDP packet to {local_addr}: {e}, closing connection"); + break; + } + }, + None => { + debug!("TCP connection closed on {local_addr}"); + break; + }, + }; + packet_received.notify_waiters(); + }, + }; } + cancellation.cancel(); + info!("Connention {local_addr} closed"); }); } }); - tokio::join!(main_loop).0.unwrap() + tokio::join!(main_loop).0.unwrap(); + Ok(()) } diff --git a/phantun/src/lib.rs b/phantun/src/lib.rs index 2e24c3f..f9c2842 100644 --- a/phantun/src/lib.rs +++ b/phantun/src/lib.rs @@ -1,5 +1,142 @@ +use fake_tcp::packet::MAX_PACKET_LEN; +use std::convert::From; +use std::iter; use std::time::Duration; pub mod utils; -pub const UDP_TTL: Duration = Duration::from_secs(180); +pub const UDP_TTL: Duration = Duration::from_secs(60); + +#[derive(Debug)] +pub enum Encryption { + Xor(Vec), +} + +impl From for Encryption { + fn from(input: String) -> Self { + Self::from(input.as_str()) + } +} + +impl From<&String> for Encryption { + fn from(input: &String) -> Self { + Self::from(input.as_str()) + } +} + +impl From<&str> for Encryption { + fn from(input: &str) -> Self { + let input = input.to_lowercase(); + let input: Vec<&str> = input.splitn(2, ':').collect(); + match input[0] { + "xor" => { + if input.len() < 2 { + panic!("xor key should be provided"); + } else { + return Self::Xor( + iter::repeat(input[1]) + .take((MAX_PACKET_LEN as f32 / input[1].len() as f32).ceil() as usize) + .collect::()[..MAX_PACKET_LEN] + .into(), + ); + } + } + _ => { + panic!("input[0] encryption is not supported."); + } + } + } +} + +impl Encryption { + // in-place encryption + pub fn encrypt(&self, input: &mut [u8]) { + match self { + Self::Xor(ref key) => { + let len = input.len(); + let input = &mut input[..len]; + let key = &key[..len]; + for i in 0..len { + input[i] ^= key[i]; + } + } + } + } + + // in-place decryption + pub fn decrypt(&self, input: &mut [u8]) { + match self { + Self::Xor(ref key) => { + let len = input.len(); + let input = &mut input[..len]; + let key = &key[..len]; + for i in 0..len { + input[i] ^= key[i]; + } + } + } + } +} + +#[cfg(test)] +mod tests { + use super::Encryption; + use rand::Rng; + + fn xor_encryption_test(model: &str) { + let enc = Encryption::from(model); + let origin: Vec = rand::thread_rng() + .sample_iter(&rand::distributions::Standard) + .take(1500) + .collect(); + let mut test = origin.clone(); + enc.encrypt(&mut test); + let mut is_equal = true; + for (i, _) in origin.iter().enumerate() { + if origin[i] != test[i] { + is_equal = false; + } + } + assert!(!is_equal); + enc.decrypt(&mut test); + for (i, _) in origin.iter().enumerate() { + assert_eq!(origin[i], test[i]); + } + } + + #[test] + #[should_panic] + fn xor_encryption_with_no_key() { + xor_encryption_test("xor"); + } + + #[test] + fn xor_encryption_with_min_key() { + let key: String = rand::thread_rng() + .sample_iter(&rand::distributions::Alphanumeric) + .take(1) + .map(char::from) + .collect(); + xor_encryption_test(format!("xor:{key}").as_str()); + } + + #[test] + fn xor_encryption_with_max_key() { + let key: String = rand::thread_rng() + .sample_iter(&rand::distributions::Alphanumeric) + .take(1500) + .map(char::from) + .collect(); + xor_encryption_test(format!("xor:{key}").as_str()); + } + + #[test] + fn xor_encryption_with_too_long_key() { + let key: String = rand::thread_rng() + .sample_iter(&rand::distributions::Alphanumeric) + .take(1501) + .map(char::from) + .collect(); + xor_encryption_test(format!("xor:{key}").as_str()); + } +}