Removes unsafes and unwraps, sets default tcp connections to 1

This commit is contained in:
Saber Haj Rabiee 2022-12-12 00:58:10 -08:00
parent ea9b6575fc
commit 13dfdaac98
5 changed files with 95 additions and 46 deletions

View File

@ -192,7 +192,7 @@ impl Socket {
match self.state { match self.state {
State::Established => { State::Established => {
self.incoming.recv_async().await.ok().and_then(|raw_buf| { self.incoming.recv_async().await.ok().and_then(|raw_buf| {
let (_v4_packet, tcp_packet) = parse_ip_packet(&raw_buf).unwrap(); let (_v4_packet, tcp_packet) = parse_ip_packet(&raw_buf)?;
if (tcp_packet.get_flags() & tcp::TcpFlags::RST) != 0 { if (tcp_packet.get_flags() & tcp::TcpFlags::RST) != 0 {
info!("Connection {} reset by peer", self); info!("Connection {} reset by peer", self);
@ -229,15 +229,27 @@ impl Socket {
State::Idle => { State::Idle => {
let buf = self.build_tcp_packet(tcp::TcpFlags::SYN | tcp::TcpFlags::ACK, None); let buf = self.build_tcp_packet(tcp::TcpFlags::SYN | tcp::TcpFlags::ACK, None);
// ACK set by constructor // ACK set by constructor
self.tun.send(&buf).await.unwrap(); if let Err(err) = self.tun.send(&buf).await {
error!("Sent SYN + ACK error: {err}");
return;
}
self.state = State::SynReceived; self.state = State::SynReceived;
info!("Sent SYN + ACK to client"); info!("Sent SYN + ACK to client");
} }
State::SynReceived => { State::SynReceived => {
let res = time::timeout(TIMEOUT, self.incoming.recv_async()).await; let res = time::timeout(TIMEOUT, self.incoming.recv_async()).await;
if let Ok(buf) = res { if let Ok(buf) = res {
let buf = buf.unwrap(); let buf = match buf {
let (_v4_packet, tcp_packet) = parse_ip_packet(&buf).unwrap(); Ok(buf) => buf,
Err(err) => {
error!("incoming channel recv_async error: {err}");
return;
}
};
let (_v4_packet, tcp_packet) = match parse_ip_packet(&buf) {
Some(packet) => packet,
None => return,
};
if (tcp_packet.get_flags() & tcp::TcpFlags::RST) != 0 { if (tcp_packet.get_flags() & tcp::TcpFlags::RST) != 0 {
return; return;
@ -273,15 +285,24 @@ impl Socket {
match self.state { match self.state {
State::Idle => { State::Idle => {
let buf = self.build_tcp_packet(tcp::TcpFlags::SYN, None); let buf = self.build_tcp_packet(tcp::TcpFlags::SYN, None);
self.tun.send(&buf).await.unwrap(); if let Err(err) = self.tun.send(&buf).await {
error!("Send SYN error: {err}");
return None;
}
self.state = State::SynSent; self.state = State::SynSent;
info!("Sent SYN to server"); info!("Sent SYN to server");
} }
State::SynSent => { State::SynSent => {
match time::timeout(TIMEOUT, self.incoming.recv_async()).await { match time::timeout(TIMEOUT, self.incoming.recv_async()).await {
Ok(buf) => { Ok(buf) => {
let buf = buf.unwrap(); let buf = match buf {
let (_v4_packet, tcp_packet) = parse_ip_packet(&buf).unwrap(); Ok(buf) => buf,
Err(err) => {
error!("incoming channel error: {err}");
return None;
}
};
let (_v4_packet, tcp_packet) = parse_ip_packet(&buf)?;
if (tcp_packet.get_flags() & tcp::TcpFlags::RST) != 0 { if (tcp_packet.get_flags() & tcp::TcpFlags::RST) != 0 {
return None; return None;
@ -298,7 +319,10 @@ impl Socket {
// send ACK to finish handshake // send ACK to finish handshake
let buf = self.build_tcp_packet(tcp::TcpFlags::ACK, None); let buf = self.build_tcp_packet(tcp::TcpFlags::ACK, None);
self.tun.send(&buf).await.unwrap(); if let Err(err) = self.tun.send(&buf).await {
error!("Send ACK error: {err}");
return None;
}
self.state = State::Established; self.state = State::Established;
@ -327,7 +351,9 @@ impl Drop for Socket {
// dissociates ourself from the dispatch map // dissociates ourself from the dispatch map
assert!(self.shared.tuples.remove(&tuple).is_some()); assert!(self.shared.tuples.remove(&tuple).is_some());
// purge cache // purge cache
self.shared.tuples_purge.send(tuple).unwrap(); if let Err(err) = self.shared.tuples_purge.send(tuple) {
error!("Send error in tuples_purge: {err}");
}
let buf = build_tcp_packet( let buf = build_tcp_packet(
self.local_addr, self.local_addr,
@ -425,7 +451,7 @@ impl Stack {
Entry::Vacant(v) => { Entry::Vacant(v) => {
let tun_index = self.shared.tun_index.fetch_add(1, Ordering::Relaxed) let tun_index = self.shared.tun_index.fetch_add(1, Ordering::Relaxed)
% self.shared.tun.len(); % self.shared.tun.len();
let tun = unsafe { self.shared.tun.get_unchecked(tun_index).clone() }; let tun = self.shared.tun[tun_index].clone();
let (sock, incoming) = Socket::new( let (sock, incoming) = Socket::new(
self.shared.clone(), self.shared.clone(),
tun, tun,
@ -434,7 +460,7 @@ impl Stack {
None, None,
State::Idle, State::Idle,
); );
v.insert(incoming.clone()); v.insert(incoming);
sock sock
} }
}; };
@ -456,7 +482,13 @@ impl Stack {
tokio::select! { tokio::select! {
size = tun.recv(&mut buf) => { size = tun.recv(&mut buf) => {
let size = size.unwrap(); let size = match size {
Ok(size) => size,
Err(err) => {
error!("Couldn't read tun buf: {err}");
continue;
}
};
buf.truncate(size); buf.truncate(size);
let buf = buf.freeze(); let buf = buf.freeze();
@ -483,8 +515,10 @@ impl Stack {
if let Some(c) = sender { if let Some(c) = sender {
trace!("Storing connection information into local tuples"); trace!("Storing connection information into local tuples");
tuples.insert(tuple, c.clone()); tuples.insert(tuple, c.clone());
c.send_async(buf).await.unwrap(); if let Err(err) = c.send_async(buf).await {
continue; error!("Couldn't send to shared tuples channel: {err}");
}
continue
} }
} }
@ -519,10 +553,10 @@ impl Stack {
None, None,
); );
let tun_index = shared.tun_index.fetch_add(1, Ordering::Relaxed) % shared.tun.len(); let tun_index = shared.tun_index.fetch_add(1, Ordering::Relaxed) % shared.tun.len();
let tun = unsafe { let tun = shared.tun[tun_index].clone();
shared.tun.get_unchecked(tun_index) if let Err(err) = tun.try_send(&buf) {
}; error!("tun send error: {err}");
tun.try_send(&buf).unwrap(); }
} }
} else if (tcp_packet.get_flags() & tcp::TcpFlags::RST) == 0 { } else if (tcp_packet.get_flags() & tcp::TcpFlags::RST) == 0 {
info!("Unknown TCP packet from {}, sending RST", remote_addr); info!("Unknown TCP packet from {}, sending RST", remote_addr);
@ -535,10 +569,10 @@ impl Stack {
None, None,
); );
let tun_index = shared.tun_index.fetch_add(1, Ordering::Relaxed) % shared.tun.len(); let tun_index = shared.tun_index.fetch_add(1, Ordering::Relaxed) % shared.tun.len();
let tun = unsafe { let tun = shared.tun[tun_index].clone();
shared.tun.get_unchecked(tun_index) if let Err(err) = tun.try_send(&buf) {
}; error!("tun send error: {err}");
tun.try_send(&buf).unwrap(); }
} }
} }
None => { None => {
@ -547,7 +581,13 @@ impl Stack {
} }
}, },
tuple = tuples_purge.recv() => { tuple = tuples_purge.recv() => {
let tuple = tuple.unwrap(); let tuple = match tuple {
Ok(tuple) => tuple,
Err(err) => {
error!("tuples_purge recv error: {err}");
continue;
}
};
tuples.remove(&tuple); tuples.remove(&tuple);
trace!("Removed cached tuple: {:?}", tuple); trace!("Removed cached tuple: {:?}", tuple);
} }

View File

@ -10,8 +10,18 @@ description = """
Transforms UDP stream into (fake) TCP streams that can go through Transforms UDP stream into (fake) TCP streams that can go through
Layer 3 & Layer 4 (NAPT) firewalls/NATs. Layer 3 & Layer 4 (NAPT) firewalls/NATs.
""" """
[[bin]]
name = "phantunc"
path = "src/bin/client.rs"
[[bin]]
name = "phantuns"
path = "src/bin/server.rs"
[dependencies] [dependencies]
clap = { version = "4.0", features = ["cargo"] } clap = { version = "4.0", features = ["cargo", "string"] }
socket2 = { version = "0.4", features = ["all"] } socket2 = { version = "0.4", features = ["all"] }
fake-tcp = { path = "../fake-tcp", version = "0.5" } fake-tcp = { path = "../fake-tcp", version = "0.5" }
tokio = { version = "1.14", features = ["full"] } tokio = { version = "1.14", features = ["full"] }
@ -21,7 +31,7 @@ pretty_env_logger = "0.4"
tokio-tun = "0.7" tokio-tun = "0.7"
num_cpus = "1.13" num_cpus = "1.13"
neli = "0.6" neli = "0.6"
nix = "0.25" nix = "0.26"
[dev-dependencies] [dev-dependencies]
rand = "0.8.5" rand = "0.8.5"

View File

@ -20,6 +20,9 @@ use phantun::UDP_TTL;
async fn main() -> io::Result<()> { async fn main() -> io::Result<()> {
pretty_env_logger::init(); pretty_env_logger::init();
let num_cpus = num_cpus::get();
info!("{} cores available", num_cpus);
let matches = Command::new("Phantun Client") let matches = Command::new("Phantun Client")
.version(crate_version!()) .version(crate_version!())
.author("Datong Sun (github.com/dndx)") .author("Datong Sun (github.com/dndx)")
@ -72,7 +75,7 @@ async fn main() -> io::Result<()> {
.required(false) .required(false)
.help("Only use IPv4 address when connecting to remote") .help("Only use IPv4 address when connecting to remote")
.action(ArgAction::SetTrue) .action(ArgAction::SetTrue)
.conflicts_with_all(&["tun_local6", "tun_peer6"]), .conflicts_with_all(["tun_local6", "tun_peer6"]),
) )
.arg( .arg(
Arg::new("tun_local6") Arg::new("tun_local6")
@ -108,7 +111,7 @@ async fn main() -> io::Result<()> {
.required(false) .required(false)
.value_name("number") .value_name("number")
.help("Number of TCP connections per each client.") .help("Number of TCP connections per each client.")
.default_value("8") .default_value("1")
) )
.arg( .arg(
Arg::new("udp_connections") Arg::new("udp_connections")
@ -116,7 +119,7 @@ async fn main() -> io::Result<()> {
.required(false) .required(false)
.value_name("number") .value_name("number")
.help("Number of UDP connections per each client.") .help("Number of UDP connections per each client.")
.default_value("8") .default_value(num_cpus.to_string())
) )
.arg( .arg(
Arg::new("encryption") Arg::new("encryption")
@ -202,9 +205,6 @@ async fn main() -> io::Result<()> {
.transpose()?, .transpose()?,
); );
let num_cpus = num_cpus::get();
info!("{} cores available", num_cpus);
let tun = TunBuilder::new() let tun = TunBuilder::new()
.name(tun_name) // if name is empty, then it is set by kernel. .name(tun_name) // if name is empty, then it is set by kernel.
.tap(false) // false (default): TUN, true: TAP. .tap(false) // false (default): TUN, true: TAP.
@ -313,7 +313,7 @@ async fn main() -> io::Result<()> {
match res { match res {
Some(size) => { Some(size) => {
let udp_sock_index = udp_sock_index.fetch_add(1, Ordering::Relaxed) % udp_socks_amount; let udp_sock_index = udp_sock_index.fetch_add(1, Ordering::Relaxed) % udp_socks_amount;
let udp_sock = unsafe { udp_socks.get_unchecked(udp_sock_index) }; let udp_sock = udp_socks[udp_sock_index].clone();
if let Some(ref enc) = *encryption { if let Some(ref enc) = *encryption {
enc.decrypt(&mut buf_tcp[..size]); enc.decrypt(&mut buf_tcp[..size]);
} }
@ -360,7 +360,7 @@ async fn main() -> io::Result<()> {
match res { match res {
Ok(size) => { Ok(size) => {
let tcp_sock_index = tcp_sock_index.fetch_add(1, Ordering::Relaxed) % tcp_socks_amount; let tcp_sock_index = tcp_sock_index.fetch_add(1, Ordering::Relaxed) % tcp_socks_amount;
let tcp_sock = unsafe { tcp_socks.get_unchecked(tcp_sock_index) }; let tcp_sock = tcp_socks[tcp_sock_index].clone();
if let Some(ref enc) = *encryption { if let Some(ref enc) = *encryption {
enc.encrypt(&mut buf_udp[..size]); enc.encrypt(&mut buf_udp[..size]);
} }

View File

@ -20,6 +20,9 @@ use phantun::UDP_TTL;
async fn main() -> io::Result<()> { async fn main() -> io::Result<()> {
pretty_env_logger::init(); pretty_env_logger::init();
let num_cpus = num_cpus::get();
info!("{} cores available", num_cpus);
let matches = Command::new("Phantun Server") let matches = Command::new("Phantun Server")
.version(crate_version!()) .version(crate_version!())
.author("Datong Sun (github.com/dndx)") .author("Datong Sun (github.com/dndx)")
@ -72,7 +75,7 @@ async fn main() -> io::Result<()> {
.required(false) .required(false)
.help("Do not assign IPv6 addresses to Tun interface") .help("Do not assign IPv6 addresses to Tun interface")
.action(ArgAction::SetTrue) .action(ArgAction::SetTrue)
.conflicts_with_all(&["tun_local6", "tun_peer6"]), .conflicts_with_all(["tun_local6", "tun_peer6"]),
) )
.arg( .arg(
Arg::new("tun_local6") Arg::new("tun_local6")
@ -117,7 +120,7 @@ async fn main() -> io::Result<()> {
.required(false) .required(false)
.value_name("number") .value_name("number")
.help("Number of UDP connections per each TCP connections.") .help("Number of UDP connections per each TCP connections.")
.default_value("8") .default_value(num_cpus.to_string())
) )
.get_matches(); .get_matches();
@ -179,9 +182,6 @@ async fn main() -> io::Result<()> {
.map(fs::read) .map(fs::read)
.transpose()?; .transpose()?;
let num_cpus = num_cpus::get();
info!("{} cores available", num_cpus);
let tun = TunBuilder::new() let tun = TunBuilder::new()
.name(tun_name) // if name is empty, then it is set by kernel. .name(tun_name) // if name is empty, then it is set by kernel.
.tap(false) // false (default): TUN, true: TAP. .tap(false) // false (default): TUN, true: TAP.
@ -312,7 +312,7 @@ async fn main() -> io::Result<()> {
match res { match res {
Some(size) => { Some(size) => {
udp_sock_index = (udp_sock_index + 1) % udp_socks_amount; udp_sock_index = (udp_sock_index + 1) % udp_socks_amount;
let udp_sock = unsafe { udp_socks.get_unchecked(udp_sock_index) }; let udp_sock = udp_socks[udp_sock_index].clone();
if let Some(ref enc) = *encryption { if let Some(ref enc) = *encryption {
enc.decrypt(&mut buf_tcp[..size]); enc.decrypt(&mut buf_tcp[..size]);
} }

View File

@ -1,6 +1,5 @@
use fake_tcp::packet::MAX_PACKET_LEN; use fake_tcp::packet::MAX_PACKET_LEN;
use std::convert::From; use std::convert::From;
use std::iter;
use std::time::Duration; use std::time::Duration;
pub mod utils; pub mod utils;
@ -33,12 +32,12 @@ impl From<&str> for Encryption {
if input.len() < 2 { if input.len() < 2 {
panic!("xor key should be provided"); panic!("xor key should be provided");
} else { } else {
return Self::Xor( Self::Xor(
iter::repeat(input[1]) input[1]
.take((MAX_PACKET_LEN as f32 / input[1].len() as f32).ceil() as usize) .repeat((MAX_PACKET_LEN as f32 / input[1].len() as f32).ceil() as usize)
.collect::<String>()[..MAX_PACKET_LEN] [..MAX_PACKET_LEN]
.into(), .into(),
); )
} }
} }
_ => { _ => {