From ac077491342db570fd1c784662eae1c75f7b1e3b Mon Sep 17 00:00:00 2001 From: Saber Haj Rabiee Date: Sun, 25 Dec 2022 06:47:03 -0800 Subject: [PATCH] fixing handshake handle logic and stricting the memory orders --- phantun/src/bin/client.rs | 12 ++++++++++-- phantun/src/bin/server.rs | 29 +++++++++++++++++++++-------- 2 files changed, 31 insertions(+), 10 deletions(-) diff --git a/phantun/src/bin/client.rs b/phantun/src/bin/client.rs index f8c7d09..a64b38b 100644 --- a/phantun/src/bin/client.rs +++ b/phantun/src/bin/client.rs @@ -268,7 +268,9 @@ async fn main() -> io::Result<()> { } }; + let mut should_receive_handshake_packet = false; if let Some(ref p) = *handshake_packet { + should_receive_handshake_packet = true; if tcp_sock.send(p).await.is_none() { error!( "Failed to send handshake packet to remote, closing connection." @@ -316,7 +318,13 @@ async fn main() -> io::Result<()> { debug!("Received EOF from {addr}, closing connection {sock_index}"); break; } - let udp_sock_index = udp_sock_index.fetch_add(1, Ordering::Relaxed) % udp_socks_amount; + // discard handshake packet since it is not related to + // underlying logic + if should_receive_handshake_packet { + should_receive_handshake_packet = false; + continue; + } + let udp_sock_index = udp_sock_index.fetch_add(1, Ordering::SeqCst) % udp_socks_amount; let udp_sock = udp_socks[udp_sock_index].clone(); if let Some(ref enc) = *encryption { enc.decrypt(&mut buf_tcp[..size]); @@ -367,7 +375,7 @@ async fn main() -> io::Result<()> { debug!("Zero-sized data are not supported, discarding received data from {addr}"); continue; } - let tcp_sock_index = tcp_sock_index.fetch_add(1, Ordering::Relaxed) % tcp_socks_amount; + let tcp_sock_index = tcp_sock_index.fetch_add(1, Ordering::SeqCst) % tcp_socks_amount; let tcp_sock = tcp_socks[tcp_sock_index].clone(); if let Some(ref enc) = *encryption { enc.encrypt(&mut buf_udp[..size]); diff --git a/phantun/src/bin/server.rs b/phantun/src/bin/server.rs index ccc1d2e..cf7fec7 100644 --- a/phantun/src/bin/server.rs +++ b/phantun/src/bin/server.rs @@ -181,6 +181,7 @@ async fn main() -> io::Result<()> { .get_one::("handshake_packet") .map(fs::read) .transpose()?; + let handshake_packet = Arc::new(handshake_packet); let tun = TunBuilder::new() .name(tun_name) // if name is empty, then it is set by kernel. @@ -207,14 +208,6 @@ async fn main() -> io::Result<()> { 'main_loop: loop { let tcp_sock = Arc::new(stack.accept().await); info!("New connection: {}", tcp_sock); - if let Some(ref p) = handshake_packet { - if tcp_sock.send(p).await.is_none() { - error!("Failed to send handshake packet to remote, closing connection."); - continue; - } - - debug!("Sent handshake packet to: {}", tcp_sock); - } let udp_sock = UdpSocket::bind(if remote_addr.is_ipv4() { "0.0.0.0:0" @@ -301,6 +294,11 @@ async fn main() -> io::Result<()> { let encryption = encryption.clone(); let packet_received = packet_received.clone(); let cancellation = cancellation.clone(); + let handshake_packet = handshake_packet.clone(); + let mut should_receive_handshake_packet = false; + if handshake_packet.is_some() { + should_receive_handshake_packet = true; + } tokio::spawn(async move { let mut buf_tcp = [0u8; MAX_PACKET_LEN]; let mut udp_sock_index = 0; @@ -319,6 +317,21 @@ async fn main() -> io::Result<()> { debug!("Received EOF from {local_addr}, closing connection"); break; } + // discard handshake packet since it is not related to + // underlying logic + if should_receive_handshake_packet { + should_receive_handshake_packet = false; + if let Some(ref p) = *handshake_packet { + if tcp_sock.send(p).await.is_none() { + error!("Failed to send handshake packet to remote, closing connection."); + break; + } + + debug!("Sent handshake packet to: {}", tcp_sock); + continue; + } + + } udp_sock_index = (udp_sock_index + 1) % udp_socks_amount; let udp_sock = udp_socks[udp_sock_index].clone(); if let Some(ref enc) = *encryption {