Preventing resource exhaustion and reconnection

This commit is contained in:
Saber Haj Rabiee 2022-12-20 06:53:26 -08:00
parent f68f153eb9
commit 1040b3dec9
3 changed files with 47 additions and 17 deletions

View File

@ -150,13 +150,13 @@ impl Socket {
} }
fn build_tcp_packet(&self, flags: u16, payload: Option<&[u8]>) -> Bytes { fn build_tcp_packet(&self, flags: u16, payload: Option<&[u8]>) -> Bytes {
let ack = self.ack.load(Ordering::Relaxed); let ack = self.ack.load(Ordering::SeqCst);
self.last_ack.store(ack, Ordering::Relaxed); self.last_ack.store(ack, Ordering::SeqCst);
build_tcp_packet( build_tcp_packet(
self.local_addr, self.local_addr,
self.remote_addr, self.remote_addr,
self.seq.load(Ordering::Relaxed), self.seq.load(Ordering::SeqCst),
ack, ack,
flags, flags,
payload, payload,
@ -174,7 +174,7 @@ impl Socket {
match self.state { match self.state {
State::Established => { State::Established => {
let buf = self.build_tcp_packet(tcp::TcpFlags::ACK, Some(payload)); let buf = self.build_tcp_packet(tcp::TcpFlags::ACK, Some(payload));
self.seq.fetch_add(payload.len() as u32, Ordering::Relaxed); self.seq.fetch_add(payload.len() as u32, Ordering::SeqCst);
self.tun.send(&buf).await.ok().and(Some(())) self.tun.send(&buf).await.ok().and(Some(()))
} }
_ => unreachable!(), _ => unreachable!(),
@ -202,8 +202,8 @@ impl Socket {
let payload = tcp_packet.payload(); let payload = tcp_packet.payload();
let new_ack = tcp_packet.get_sequence().wrapping_add(payload.len() as u32); let new_ack = tcp_packet.get_sequence().wrapping_add(payload.len() as u32);
let last_ask = self.last_ack.load(Ordering::Relaxed); let last_ask = self.last_ack.load(Ordering::SeqCst);
self.ack.store(new_ack, Ordering::Relaxed); self.ack.store(new_ack, Ordering::SeqCst);
if new_ack.overflowing_sub(last_ask).0 > MAX_UNACKED_LEN { if new_ack.overflowing_sub(last_ask).0 > MAX_UNACKED_LEN {
let buf = self.build_tcp_packet(tcp::TcpFlags::ACK, None); let buf = self.build_tcp_packet(tcp::TcpFlags::ACK, None);
@ -255,12 +255,19 @@ impl Socket {
return; return;
} }
let packet_ack = tcp_packet.get_acknowledgement();
if tcp_packet.get_flags() == tcp::TcpFlags::ACK if tcp_packet.get_flags() == tcp::TcpFlags::ACK
&& tcp_packet.get_acknowledgement() && self
== self.seq.load(Ordering::Relaxed) + 1 .seq
.compare_exchange(
packet_ack - 1,
packet_ack,
Ordering::SeqCst,
Ordering::Acquire,
)
.is_ok()
{ {
// found our ACK // found our ACK
self.seq.fetch_add(1, Ordering::Relaxed);
self.state = State::Established; self.state = State::Established;
info!("Connection from {:?} established", self.remote_addr); info!("Connection from {:?} established", self.remote_addr);
@ -308,14 +315,21 @@ impl Socket {
return None; return None;
} }
let packet_ack = tcp_packet.get_acknowledgement();
if tcp_packet.get_flags() == tcp::TcpFlags::SYN | tcp::TcpFlags::ACK if tcp_packet.get_flags() == tcp::TcpFlags::SYN | tcp::TcpFlags::ACK
&& tcp_packet.get_acknowledgement() && self
== self.seq.load(Ordering::Relaxed) + 1 .seq
.compare_exchange(
packet_ack - 1,
packet_ack,
Ordering::SeqCst,
Ordering::Acquire,
)
.is_ok()
{ {
// found our SYN + ACK // found our SYN + ACK
self.seq.fetch_add(1, Ordering::Relaxed);
self.ack self.ack
.store(tcp_packet.get_sequence() + 1, Ordering::Relaxed); .store(tcp_packet.get_sequence() + 1, Ordering::SeqCst);
// send ACK to finish handshake // send ACK to finish handshake
let buf = self.build_tcp_packet(tcp::TcpFlags::ACK, None); let buf = self.build_tcp_packet(tcp::TcpFlags::ACK, None);
@ -358,7 +372,7 @@ impl Drop for Socket {
let buf = build_tcp_packet( let buf = build_tcp_packet(
self.local_addr, self.local_addr,
self.remote_addr, self.remote_addr,
self.seq.load(Ordering::Relaxed), self.seq.load(Ordering::SeqCst),
0, 0,
tcp::TcpFlags::RST, tcp::TcpFlags::RST,
None, None,
@ -449,7 +463,7 @@ impl Stack {
let mut sock = match self.shared.tuples.entry(tuple) { let mut sock = match self.shared.tuples.entry(tuple) {
Entry::Occupied(_) => continue, Entry::Occupied(_) => continue,
Entry::Vacant(v) => { Entry::Vacant(v) => {
let tun_index = self.shared.tun_index.fetch_add(1, Ordering::Relaxed) let tun_index = self.shared.tun_index.fetch_add(1, Ordering::SeqCst)
% self.shared.tun.len(); % self.shared.tun.len();
let tun = self.shared.tun[tun_index].clone(); let tun = self.shared.tun[tun_index].clone();
let (sock, incoming) = Socket::new( let (sock, incoming) = Socket::new(
@ -552,7 +566,7 @@ impl Stack {
tcp::TcpFlags::RST | tcp::TcpFlags::ACK, tcp::TcpFlags::RST | tcp::TcpFlags::ACK,
None, None,
); );
let tun_index = shared.tun_index.fetch_add(1, Ordering::Relaxed) % shared.tun.len(); let tun_index = shared.tun_index.fetch_add(1, Ordering::SeqCst) % shared.tun.len();
let tun = shared.tun[tun_index].clone(); let tun = shared.tun[tun_index].clone();
if let Err(err) = tun.try_send(&buf) { if let Err(err) = tun.try_send(&buf) {
error!("tun send error: {err}"); error!("tun send error: {err}");
@ -568,7 +582,7 @@ impl Stack {
tcp::TcpFlags::RST | tcp::TcpFlags::ACK, tcp::TcpFlags::RST | tcp::TcpFlags::ACK,
None, None,
); );
let tun_index = shared.tun_index.fetch_add(1, Ordering::Relaxed) % shared.tun.len(); let tun_index = shared.tun_index.fetch_add(1, Ordering::SeqCst) % shared.tun.len();
let tun = shared.tun[tun_index].clone(); let tun = shared.tun[tun_index].clone();
if let Err(err) = tun.try_send(&buf) { if let Err(err) = tun.try_send(&buf) {
error!("tun send error: {err}"); error!("tun send error: {err}");

View File

@ -312,6 +312,10 @@ async fn main() -> io::Result<()> {
res = tcp_sock.recv(&mut buf_tcp) => { res = tcp_sock.recv(&mut buf_tcp) => {
match res { match res {
Some(size) => { Some(size) => {
if size == 0 {
debug!("Received EOF from {addr}, closing connection {sock_index}");
break;
}
let udp_sock_index = udp_sock_index.fetch_add(1, Ordering::Relaxed) % udp_socks_amount; let udp_sock_index = udp_sock_index.fetch_add(1, Ordering::Relaxed) % udp_socks_amount;
let udp_sock = udp_socks[udp_sock_index].clone(); let udp_sock = udp_socks[udp_sock_index].clone();
if let Some(ref enc) = *encryption { if let Some(ref enc) = *encryption {
@ -359,6 +363,10 @@ async fn main() -> io::Result<()> {
res = udp_sock.recv(&mut buf_udp) => { res = udp_sock.recv(&mut buf_udp) => {
match res { match res {
Ok(size) => { Ok(size) => {
if size == 0 {
debug!("Zero-sized data are not supported, discarding received data from {addr}");
continue;
}
let tcp_sock_index = tcp_sock_index.fetch_add(1, Ordering::Relaxed) % tcp_socks_amount; let tcp_sock_index = tcp_sock_index.fetch_add(1, Ordering::Relaxed) % tcp_socks_amount;
let tcp_sock = tcp_socks[tcp_sock_index].clone(); let tcp_sock = tcp_socks[tcp_sock_index].clone();
if let Some(ref enc) = *encryption { if let Some(ref enc) = *encryption {

View File

@ -273,6 +273,10 @@ async fn main() -> io::Result<()> {
res = udp_sock.recv(&mut buf_udp) => { res = udp_sock.recv(&mut buf_udp) => {
match res { match res {
Ok(size) => { Ok(size) => {
if size == 0 {
debug!("Zero-sized data are not supported, discarding received data from {local_addr}");
continue;
}
if let Some(ref enc) = *encryption { if let Some(ref enc) = *encryption {
enc.encrypt(&mut buf_udp[..size]); enc.encrypt(&mut buf_udp[..size]);
} }
@ -311,6 +315,10 @@ async fn main() -> io::Result<()> {
res = tcp_sock.recv(&mut buf_tcp) => { res = tcp_sock.recv(&mut buf_tcp) => {
match res { match res {
Some(size) => { Some(size) => {
if size == 0 {
debug!("Received EOF from {local_addr}, closing connection");
break;
}
udp_sock_index = (udp_sock_index + 1) % udp_socks_amount; udp_sock_index = (udp_sock_index + 1) % udp_socks_amount;
let udp_sock = udp_socks[udp_sock_index].clone(); let udp_sock = udp_socks[udp_sock_index].clone();
if let Some(ref enc) = *encryption { if let Some(ref enc) = *encryption {