From 13dfdaac980aeb8e083fee41947d576d9dc62c41 Mon Sep 17 00:00:00 2001 From: Saber Haj Rabiee Date: Mon, 12 Dec 2022 00:58:10 -0800 Subject: [PATCH] Removes unsafes and unwraps, sets default tcp connections to 1 --- fake-tcp/src/lib.rs | 86 ++++++++++++++++++++++++++++----------- phantun/Cargo.toml | 14 ++++++- phantun/src/bin/client.rs | 18 ++++---- phantun/src/bin/server.rs | 12 +++--- phantun/src/lib.rs | 11 +++-- 5 files changed, 95 insertions(+), 46 deletions(-) diff --git a/fake-tcp/src/lib.rs b/fake-tcp/src/lib.rs index 12672e5..0dea62e 100644 --- a/fake-tcp/src/lib.rs +++ b/fake-tcp/src/lib.rs @@ -192,7 +192,7 @@ impl Socket { match self.state { State::Established => { self.incoming.recv_async().await.ok().and_then(|raw_buf| { - let (_v4_packet, tcp_packet) = parse_ip_packet(&raw_buf).unwrap(); + let (_v4_packet, tcp_packet) = parse_ip_packet(&raw_buf)?; if (tcp_packet.get_flags() & tcp::TcpFlags::RST) != 0 { info!("Connection {} reset by peer", self); @@ -229,15 +229,27 @@ impl Socket { State::Idle => { let buf = self.build_tcp_packet(tcp::TcpFlags::SYN | tcp::TcpFlags::ACK, None); // ACK set by constructor - self.tun.send(&buf).await.unwrap(); + if let Err(err) = self.tun.send(&buf).await { + error!("Sent SYN + ACK error: {err}"); + return; + } self.state = State::SynReceived; info!("Sent SYN + ACK to client"); } State::SynReceived => { let res = time::timeout(TIMEOUT, self.incoming.recv_async()).await; if let Ok(buf) = res { - let buf = buf.unwrap(); - let (_v4_packet, tcp_packet) = parse_ip_packet(&buf).unwrap(); + let buf = match buf { + Ok(buf) => buf, + Err(err) => { + error!("incoming channel recv_async error: {err}"); + return; + } + }; + let (_v4_packet, tcp_packet) = match parse_ip_packet(&buf) { + Some(packet) => packet, + None => return, + }; if (tcp_packet.get_flags() & tcp::TcpFlags::RST) != 0 { return; @@ -273,15 +285,24 @@ impl Socket { match self.state { State::Idle => { let buf = self.build_tcp_packet(tcp::TcpFlags::SYN, None); - self.tun.send(&buf).await.unwrap(); + if let Err(err) = self.tun.send(&buf).await { + error!("Send SYN error: {err}"); + return None; + } self.state = State::SynSent; info!("Sent SYN to server"); } State::SynSent => { match time::timeout(TIMEOUT, self.incoming.recv_async()).await { Ok(buf) => { - let buf = buf.unwrap(); - let (_v4_packet, tcp_packet) = parse_ip_packet(&buf).unwrap(); + let buf = match buf { + Ok(buf) => buf, + Err(err) => { + error!("incoming channel error: {err}"); + return None; + } + }; + let (_v4_packet, tcp_packet) = parse_ip_packet(&buf)?; if (tcp_packet.get_flags() & tcp::TcpFlags::RST) != 0 { return None; @@ -298,7 +319,10 @@ impl Socket { // send ACK to finish handshake let buf = self.build_tcp_packet(tcp::TcpFlags::ACK, None); - self.tun.send(&buf).await.unwrap(); + if let Err(err) = self.tun.send(&buf).await { + error!("Send ACK error: {err}"); + return None; + } self.state = State::Established; @@ -327,7 +351,9 @@ impl Drop for Socket { // dissociates ourself from the dispatch map assert!(self.shared.tuples.remove(&tuple).is_some()); // purge cache - self.shared.tuples_purge.send(tuple).unwrap(); + if let Err(err) = self.shared.tuples_purge.send(tuple) { + error!("Send error in tuples_purge: {err}"); + } let buf = build_tcp_packet( self.local_addr, @@ -425,7 +451,7 @@ impl Stack { Entry::Vacant(v) => { let tun_index = self.shared.tun_index.fetch_add(1, Ordering::Relaxed) % self.shared.tun.len(); - let tun = unsafe { self.shared.tun.get_unchecked(tun_index).clone() }; + let tun = self.shared.tun[tun_index].clone(); let (sock, incoming) = Socket::new( self.shared.clone(), tun, @@ -434,7 +460,7 @@ impl Stack { None, State::Idle, ); - v.insert(incoming.clone()); + v.insert(incoming); sock } }; @@ -456,7 +482,13 @@ impl Stack { tokio::select! { size = tun.recv(&mut buf) => { - let size = size.unwrap(); + let size = match size { + Ok(size) => size, + Err(err) => { + error!("Couldn't read tun buf: {err}"); + continue; + } + }; buf.truncate(size); let buf = buf.freeze(); @@ -483,8 +515,10 @@ impl Stack { if let Some(c) = sender { trace!("Storing connection information into local tuples"); tuples.insert(tuple, c.clone()); - c.send_async(buf).await.unwrap(); - continue; + if let Err(err) = c.send_async(buf).await { + error!("Couldn't send to shared tuples channel: {err}"); + } + continue } } @@ -519,10 +553,10 @@ impl Stack { None, ); let tun_index = shared.tun_index.fetch_add(1, Ordering::Relaxed) % shared.tun.len(); - let tun = unsafe { - shared.tun.get_unchecked(tun_index) - }; - tun.try_send(&buf).unwrap(); + let tun = shared.tun[tun_index].clone(); + if let Err(err) = tun.try_send(&buf) { + error!("tun send error: {err}"); + } } } else if (tcp_packet.get_flags() & tcp::TcpFlags::RST) == 0 { info!("Unknown TCP packet from {}, sending RST", remote_addr); @@ -535,10 +569,10 @@ impl Stack { None, ); let tun_index = shared.tun_index.fetch_add(1, Ordering::Relaxed) % shared.tun.len(); - let tun = unsafe { - shared.tun.get_unchecked(tun_index) - }; - tun.try_send(&buf).unwrap(); + let tun = shared.tun[tun_index].clone(); + if let Err(err) = tun.try_send(&buf) { + error!("tun send error: {err}"); + } } } None => { @@ -547,7 +581,13 @@ impl Stack { } }, tuple = tuples_purge.recv() => { - let tuple = tuple.unwrap(); + let tuple = match tuple { + Ok(tuple) => tuple, + Err(err) => { + error!("tuples_purge recv error: {err}"); + continue; + } + }; tuples.remove(&tuple); trace!("Removed cached tuple: {:?}", tuple); } diff --git a/phantun/Cargo.toml b/phantun/Cargo.toml index 9ac2d8a..a9b8800 100644 --- a/phantun/Cargo.toml +++ b/phantun/Cargo.toml @@ -10,8 +10,18 @@ description = """ Transforms UDP stream into (fake) TCP streams that can go through Layer 3 & Layer 4 (NAPT) firewalls/NATs. """ + + +[[bin]] +name = "phantunc" +path = "src/bin/client.rs" + +[[bin]] +name = "phantuns" +path = "src/bin/server.rs" + [dependencies] -clap = { version = "4.0", features = ["cargo"] } +clap = { version = "4.0", features = ["cargo", "string"] } socket2 = { version = "0.4", features = ["all"] } fake-tcp = { path = "../fake-tcp", version = "0.5" } tokio = { version = "1.14", features = ["full"] } @@ -21,7 +31,7 @@ pretty_env_logger = "0.4" tokio-tun = "0.7" num_cpus = "1.13" neli = "0.6" -nix = "0.25" +nix = "0.26" [dev-dependencies] rand = "0.8.5" diff --git a/phantun/src/bin/client.rs b/phantun/src/bin/client.rs index 8b4378a..d74bbd0 100644 --- a/phantun/src/bin/client.rs +++ b/phantun/src/bin/client.rs @@ -20,6 +20,9 @@ use phantun::UDP_TTL; async fn main() -> io::Result<()> { pretty_env_logger::init(); + let num_cpus = num_cpus::get(); + info!("{} cores available", num_cpus); + let matches = Command::new("Phantun Client") .version(crate_version!()) .author("Datong Sun (github.com/dndx)") @@ -72,7 +75,7 @@ async fn main() -> io::Result<()> { .required(false) .help("Only use IPv4 address when connecting to remote") .action(ArgAction::SetTrue) - .conflicts_with_all(&["tun_local6", "tun_peer6"]), + .conflicts_with_all(["tun_local6", "tun_peer6"]), ) .arg( Arg::new("tun_local6") @@ -108,7 +111,7 @@ async fn main() -> io::Result<()> { .required(false) .value_name("number") .help("Number of TCP connections per each client.") - .default_value("8") + .default_value("1") ) .arg( Arg::new("udp_connections") @@ -116,8 +119,8 @@ async fn main() -> io::Result<()> { .required(false) .value_name("number") .help("Number of UDP connections per each client.") - .default_value("8") - ) + .default_value(num_cpus.to_string()) + ) .arg( Arg::new("encryption") .long("encryption") @@ -202,9 +205,6 @@ async fn main() -> io::Result<()> { .transpose()?, ); - let num_cpus = num_cpus::get(); - info!("{} cores available", num_cpus); - let tun = TunBuilder::new() .name(tun_name) // if name is empty, then it is set by kernel. .tap(false) // false (default): TUN, true: TAP. @@ -313,7 +313,7 @@ async fn main() -> io::Result<()> { match res { Some(size) => { let udp_sock_index = udp_sock_index.fetch_add(1, Ordering::Relaxed) % udp_socks_amount; - let udp_sock = unsafe { udp_socks.get_unchecked(udp_sock_index) }; + let udp_sock = udp_socks[udp_sock_index].clone(); if let Some(ref enc) = *encryption { enc.decrypt(&mut buf_tcp[..size]); } @@ -360,7 +360,7 @@ async fn main() -> io::Result<()> { match res { Ok(size) => { let tcp_sock_index = tcp_sock_index.fetch_add(1, Ordering::Relaxed) % tcp_socks_amount; - let tcp_sock = unsafe { tcp_socks.get_unchecked(tcp_sock_index) }; + let tcp_sock = tcp_socks[tcp_sock_index].clone(); if let Some(ref enc) = *encryption { enc.encrypt(&mut buf_udp[..size]); } diff --git a/phantun/src/bin/server.rs b/phantun/src/bin/server.rs index b8b53f1..92365e1 100644 --- a/phantun/src/bin/server.rs +++ b/phantun/src/bin/server.rs @@ -20,6 +20,9 @@ use phantun::UDP_TTL; async fn main() -> io::Result<()> { pretty_env_logger::init(); + let num_cpus = num_cpus::get(); + info!("{} cores available", num_cpus); + let matches = Command::new("Phantun Server") .version(crate_version!()) .author("Datong Sun (github.com/dndx)") @@ -72,7 +75,7 @@ async fn main() -> io::Result<()> { .required(false) .help("Do not assign IPv6 addresses to Tun interface") .action(ArgAction::SetTrue) - .conflicts_with_all(&["tun_local6", "tun_peer6"]), + .conflicts_with_all(["tun_local6", "tun_peer6"]), ) .arg( Arg::new("tun_local6") @@ -117,7 +120,7 @@ async fn main() -> io::Result<()> { .required(false) .value_name("number") .help("Number of UDP connections per each TCP connections.") - .default_value("8") + .default_value(num_cpus.to_string()) ) .get_matches(); @@ -179,9 +182,6 @@ async fn main() -> io::Result<()> { .map(fs::read) .transpose()?; - let num_cpus = num_cpus::get(); - info!("{} cores available", num_cpus); - let tun = TunBuilder::new() .name(tun_name) // if name is empty, then it is set by kernel. .tap(false) // false (default): TUN, true: TAP. @@ -312,7 +312,7 @@ async fn main() -> io::Result<()> { match res { Some(size) => { udp_sock_index = (udp_sock_index + 1) % udp_socks_amount; - let udp_sock = unsafe { udp_socks.get_unchecked(udp_sock_index) }; + let udp_sock = udp_socks[udp_sock_index].clone(); if let Some(ref enc) = *encryption { enc.decrypt(&mut buf_tcp[..size]); } diff --git a/phantun/src/lib.rs b/phantun/src/lib.rs index f9c2842..78e6702 100644 --- a/phantun/src/lib.rs +++ b/phantun/src/lib.rs @@ -1,6 +1,5 @@ use fake_tcp::packet::MAX_PACKET_LEN; use std::convert::From; -use std::iter; use std::time::Duration; pub mod utils; @@ -33,12 +32,12 @@ impl From<&str> for Encryption { if input.len() < 2 { panic!("xor key should be provided"); } else { - return Self::Xor( - iter::repeat(input[1]) - .take((MAX_PACKET_LEN as f32 / input[1].len() as f32).ceil() as usize) - .collect::()[..MAX_PACKET_LEN] + Self::Xor( + input[1] + .repeat((MAX_PACKET_LEN as f32 / input[1].len() as f32).ceil() as usize) + [..MAX_PACKET_LEN] .into(), - ); + ) } } _ => {