From 1040b3dec968686da23fa687319dc2e2ba3fe9c5 Mon Sep 17 00:00:00 2001 From: Saber Haj Rabiee Date: Tue, 20 Dec 2022 06:53:26 -0800 Subject: [PATCH] Preventing resource exhaustion and reconnection --- fake-tcp/src/lib.rs | 48 +++++++++++++++++++++++++-------------- phantun/src/bin/client.rs | 8 +++++++ phantun/src/bin/server.rs | 8 +++++++ 3 files changed, 47 insertions(+), 17 deletions(-) diff --git a/fake-tcp/src/lib.rs b/fake-tcp/src/lib.rs index 0dea62e..c3497fc 100644 --- a/fake-tcp/src/lib.rs +++ b/fake-tcp/src/lib.rs @@ -150,13 +150,13 @@ impl Socket { } fn build_tcp_packet(&self, flags: u16, payload: Option<&[u8]>) -> Bytes { - let ack = self.ack.load(Ordering::Relaxed); - self.last_ack.store(ack, Ordering::Relaxed); + let ack = self.ack.load(Ordering::SeqCst); + self.last_ack.store(ack, Ordering::SeqCst); build_tcp_packet( self.local_addr, self.remote_addr, - self.seq.load(Ordering::Relaxed), + self.seq.load(Ordering::SeqCst), ack, flags, payload, @@ -174,7 +174,7 @@ impl Socket { match self.state { State::Established => { let buf = self.build_tcp_packet(tcp::TcpFlags::ACK, Some(payload)); - self.seq.fetch_add(payload.len() as u32, Ordering::Relaxed); + self.seq.fetch_add(payload.len() as u32, Ordering::SeqCst); self.tun.send(&buf).await.ok().and(Some(())) } _ => unreachable!(), @@ -202,8 +202,8 @@ impl Socket { let payload = tcp_packet.payload(); let new_ack = tcp_packet.get_sequence().wrapping_add(payload.len() as u32); - let last_ask = self.last_ack.load(Ordering::Relaxed); - self.ack.store(new_ack, Ordering::Relaxed); + let last_ask = self.last_ack.load(Ordering::SeqCst); + self.ack.store(new_ack, Ordering::SeqCst); if new_ack.overflowing_sub(last_ask).0 > MAX_UNACKED_LEN { let buf = self.build_tcp_packet(tcp::TcpFlags::ACK, None); @@ -255,12 +255,19 @@ impl Socket { return; } + let packet_ack = tcp_packet.get_acknowledgement(); if tcp_packet.get_flags() == tcp::TcpFlags::ACK - && tcp_packet.get_acknowledgement() - == self.seq.load(Ordering::Relaxed) + 1 + && self + .seq + .compare_exchange( + packet_ack - 1, + packet_ack, + Ordering::SeqCst, + Ordering::Acquire, + ) + .is_ok() { // found our ACK - self.seq.fetch_add(1, Ordering::Relaxed); self.state = State::Established; info!("Connection from {:?} established", self.remote_addr); @@ -308,14 +315,21 @@ impl Socket { return None; } + let packet_ack = tcp_packet.get_acknowledgement(); if tcp_packet.get_flags() == tcp::TcpFlags::SYN | tcp::TcpFlags::ACK - && tcp_packet.get_acknowledgement() - == self.seq.load(Ordering::Relaxed) + 1 + && self + .seq + .compare_exchange( + packet_ack - 1, + packet_ack, + Ordering::SeqCst, + Ordering::Acquire, + ) + .is_ok() { // found our SYN + ACK - self.seq.fetch_add(1, Ordering::Relaxed); self.ack - .store(tcp_packet.get_sequence() + 1, Ordering::Relaxed); + .store(tcp_packet.get_sequence() + 1, Ordering::SeqCst); // send ACK to finish handshake let buf = self.build_tcp_packet(tcp::TcpFlags::ACK, None); @@ -358,7 +372,7 @@ impl Drop for Socket { let buf = build_tcp_packet( self.local_addr, self.remote_addr, - self.seq.load(Ordering::Relaxed), + self.seq.load(Ordering::SeqCst), 0, tcp::TcpFlags::RST, None, @@ -449,7 +463,7 @@ impl Stack { let mut sock = match self.shared.tuples.entry(tuple) { Entry::Occupied(_) => continue, Entry::Vacant(v) => { - let tun_index = self.shared.tun_index.fetch_add(1, Ordering::Relaxed) + let tun_index = self.shared.tun_index.fetch_add(1, Ordering::SeqCst) % self.shared.tun.len(); let tun = self.shared.tun[tun_index].clone(); let (sock, incoming) = Socket::new( @@ -552,7 +566,7 @@ impl Stack { tcp::TcpFlags::RST | tcp::TcpFlags::ACK, None, ); - let tun_index = shared.tun_index.fetch_add(1, Ordering::Relaxed) % shared.tun.len(); + let tun_index = shared.tun_index.fetch_add(1, Ordering::SeqCst) % shared.tun.len(); let tun = shared.tun[tun_index].clone(); if let Err(err) = tun.try_send(&buf) { error!("tun send error: {err}"); @@ -568,7 +582,7 @@ impl Stack { tcp::TcpFlags::RST | tcp::TcpFlags::ACK, None, ); - let tun_index = shared.tun_index.fetch_add(1, Ordering::Relaxed) % shared.tun.len(); + let tun_index = shared.tun_index.fetch_add(1, Ordering::SeqCst) % shared.tun.len(); let tun = shared.tun[tun_index].clone(); if let Err(err) = tun.try_send(&buf) { error!("tun send error: {err}"); diff --git a/phantun/src/bin/client.rs b/phantun/src/bin/client.rs index d74bbd0..f8c7d09 100644 --- a/phantun/src/bin/client.rs +++ b/phantun/src/bin/client.rs @@ -312,6 +312,10 @@ async fn main() -> io::Result<()> { res = tcp_sock.recv(&mut buf_tcp) => { match res { Some(size) => { + if size == 0 { + debug!("Received EOF from {addr}, closing connection {sock_index}"); + break; + } let udp_sock_index = udp_sock_index.fetch_add(1, Ordering::Relaxed) % udp_socks_amount; let udp_sock = udp_socks[udp_sock_index].clone(); if let Some(ref enc) = *encryption { @@ -359,6 +363,10 @@ async fn main() -> io::Result<()> { res = udp_sock.recv(&mut buf_udp) => { match res { Ok(size) => { + if size == 0 { + debug!("Zero-sized data are not supported, discarding received data from {addr}"); + continue; + } let tcp_sock_index = tcp_sock_index.fetch_add(1, Ordering::Relaxed) % tcp_socks_amount; let tcp_sock = tcp_socks[tcp_sock_index].clone(); if let Some(ref enc) = *encryption { diff --git a/phantun/src/bin/server.rs b/phantun/src/bin/server.rs index 92365e1..ccc1d2e 100644 --- a/phantun/src/bin/server.rs +++ b/phantun/src/bin/server.rs @@ -273,6 +273,10 @@ async fn main() -> io::Result<()> { res = udp_sock.recv(&mut buf_udp) => { match res { Ok(size) => { + if size == 0 { + debug!("Zero-sized data are not supported, discarding received data from {local_addr}"); + continue; + } if let Some(ref enc) = *encryption { enc.encrypt(&mut buf_udp[..size]); } @@ -311,6 +315,10 @@ async fn main() -> io::Result<()> { res = tcp_sock.recv(&mut buf_tcp) => { match res { Some(size) => { + if size == 0 { + debug!("Received EOF from {local_addr}, closing connection"); + break; + } udp_sock_index = (udp_sock_index + 1) % udp_socks_amount; let udp_sock = udp_socks[udp_sock_index].clone(); if let Some(ref enc) = *encryption {