From 25a1cf9aab72de2c26d3468fbd0bfa700142a572 Mon Sep 17 00:00:00 2001 From: Andrea Date: Fri, 16 Dec 2022 22:11:46 +0800 Subject: [PATCH] feat(fake-tcp) add tcp keep-alive support to client --- fake-tcp/src/lib.rs | 47 +++++++++++++++-- phantun/src/bin/client.rs | 105 +++++++++++++++++++++++++++++++------- 2 files changed, 130 insertions(+), 22 deletions(-) diff --git a/fake-tcp/src/lib.rs b/fake-tcp/src/lib.rs index a5c5863..d70caf6 100644 --- a/fake-tcp/src/lib.rs +++ b/fake-tcp/src/lib.rs @@ -178,6 +178,34 @@ impl Socket { } } + /// Sends a keep-alive (zero length) datagram to the other end. + /// + /// This method takes `&self`, and it can be called safely by multiple threads + /// at the same time. + /// + /// A return of `None` means the Tun socket returned an error + /// and this socket must be closed. + pub async fn send_keepalive(&self) -> Option<()> { + match self.state { + State::Established => { + let ack = self.ack.load(Ordering::Relaxed); + self.last_ack.store(ack, Ordering::Relaxed); + + let buf = build_tcp_packet( + self.local_addr, + self.remote_addr, + // the current sequence number is one byte less than the next expected sequence number + self.seq.load(Ordering::Relaxed) - 1, + ack, + tcp::TcpFlags::ACK, + None, + ); + self.tun.send(&buf).await.ok().and(Some(())) + } + _ => unreachable!(), + } + } + /// Attempt to receive a datagram from the other end. /// /// This method takes `&self`, and it can be called safely by multiple threads @@ -198,11 +226,22 @@ impl Socket { let payload = tcp_packet.payload(); - let new_ack = tcp_packet.get_sequence().wrapping_add(payload.len() as u32); - let last_ask = self.last_ack.load(Ordering::Relaxed); - self.ack.store(new_ack, Ordering::Relaxed); + let need_ack = if payload.len() == 0 { + let current_head = self.ack.load(Ordering::Relaxed); + // If it's a keep alive packet + let kp = tcp_packet.get_sequence() == current_head - 1; + if kp { + trace!("{} Keepalive packet received!", self); + } + kp + } else { + let new_ack = tcp_packet.get_sequence().wrapping_add(payload.len() as u32); + let last_ask = self.last_ack.load(Ordering::Relaxed); + self.ack.store(new_ack, Ordering::Relaxed); + new_ack.overflowing_sub(last_ask).0 > MAX_UNACKED_LEN + }; - if new_ack.overflowing_sub(last_ask).0 > MAX_UNACKED_LEN { + if need_ack { let buf = self.build_tcp_packet(tcp::TcpFlags::ACK, None); if let Err(e) = self.tun.try_send(&buf) { // This should not really happen as we have not sent anything for diff --git a/phantun/src/bin/client.rs b/phantun/src/bin/client.rs index 360101e..aa70b0a 100644 --- a/phantun/src/bin/client.rs +++ b/phantun/src/bin/client.rs @@ -1,15 +1,16 @@ use clap::{crate_version, Arg, ArgAction, Command}; use fake_tcp::packet::MAX_PACKET_LEN; use fake_tcp::{Socket, Stack}; -use log::{debug, error, info}; +use log::{debug, error, info, trace, warn}; use phantun::utils::{assign_ipv6_address, new_udp_reuseport}; use std::collections::HashMap; use std::fs; use std::io; use std::net::{Ipv4Addr, SocketAddr}; use std::sync::Arc; +use std::time::Duration; use tokio::sync::{Notify, RwLock}; -use tokio::time; +use tokio::time::{self, Instant}; use tokio_tun::TunBuilder; use tokio_util::sync::CancellationToken; @@ -101,6 +102,34 @@ async fn main() -> io::Result<()> { Note: ensure this file's size does not exceed the MTU of the outgoing interface. \ The content is always sent out in a single packet and will not be further segmented") ) + .arg( + Arg::new("keepalive_time") + .long("keepalive-time") + .short('k') + .required(false) + .value_name("SECONDS") + .value_parser(clap::value_parser!(u64).range(1..86400)) + .help("Specify the interval between the last packet received and the first keepalive probe.\ + If not specified, no keepalive probe will be sent.") + ) + .arg( + Arg::new("keepalive_interval") + .long("keepalive-interval") + .required(false) + .default_value("5") + .value_name("SECONDS") + .value_parser(clap::value_parser!(u64).range(1..3600)) + .help("Specify the interval between keepalive probes.") + ) + .arg( + Arg::new("keepalive_retries") + .long("keepalive-retries") + .required(false) + .default_value("3") + .value_name("N") + .value_parser(clap::value_parser!(i32).range(1..20)) + .help("Specify the number of keepalive probe retries before a connection reset.") + ) .get_matches(); let local_addr: SocketAddr = matches @@ -148,6 +177,10 @@ async fn main() -> io::Result<()> { .map(fs::read) .transpose()?; + let tcp_keepalive_intvl = Duration::from_secs(*matches.get_one::("keepalive_interval").unwrap()); + let tcp_keepalive_time = matches.get_one::("keepalive_time").map(|s| Duration::from_secs(*s)); + let tcp_keepalive_retries = *matches.get_one::("keepalive_retries").unwrap(); + let num_cpus = num_cpus::get(); info!("{} cores available", num_cpus); @@ -218,13 +251,15 @@ async fn main() -> io::Result<()> { // spawn "fastpath" UDP socket and task, this will offload main task // from forwarding UDP packets - let packet_received = Arc::new(Notify::new()); + let data_received = Arc::new(Notify::new()); + let tcp_packet_received = Arc::new(Notify::new()); let quit = CancellationToken::new(); for i in 0..num_cpus { let sock = sock.clone(); let quit = quit.clone(); - let packet_received = packet_received.clone(); + let data_received = data_received.clone(); + let tcp_packet_received = tcp_packet_received.clone(); tokio::spawn(async move { let mut buf_udp = [0u8; MAX_PACKET_LEN]; @@ -241,7 +276,7 @@ async fn main() -> io::Result<()> { return; } - packet_received.notify_one(); + data_received.notify_one(); }, res = sock.recv(&mut buf_tcp) => { match res { @@ -252,6 +287,9 @@ async fn main() -> io::Result<()> { quit.cancel(); return; } + data_received.notify_one(); + } else { + trace!("Empty TCP packet received"); } }, None => { @@ -261,7 +299,7 @@ async fn main() -> io::Result<()> { }, } - packet_received.notify_one(); + tcp_packet_received.notify_one(); }, _ = quit.cancelled() => { debug!("worker {} terminated", i); @@ -274,27 +312,58 @@ async fn main() -> io::Result<()> { let connections = connections.clone(); tokio::spawn(async move { + let mut last_tcp_recv = Instant::now(); + let mut last_data_recv = Instant::now(); + let mut probe_sent = 0; + let mut last_tcp_keepalive_sent = Instant::now(); loop { - let read_timeout = time::sleep(UDP_TTL); - let packet_received_fut = packet_received.notified(); + let tm0 = time::sleep(Duration::from_secs(1)); + let data_received_fut = data_received.notified(); + let tcp_packet_received_fut = tcp_packet_received.notified(); tokio::select! { - _ = read_timeout => { - info!("No traffic seen in the last {:?}, closing connection", UDP_TTL); - connections.write().await.remove(&addr); - debug!("removed fake TCP socket from connections table"); + _ = tm0 => { + + if tcp_keepalive_time.is_some() { + let tcp_idle = last_tcp_recv.elapsed(); + if tcp_idle > tcp_keepalive_time.unwrap() { + if last_tcp_keepalive_sent.elapsed() >= tcp_keepalive_intvl { + if probe_sent >= tcp_keepalive_retries { + info!("Connection {} TCP keep-alive retry exceeded", sock); + break; + } + trace!("Connection {} sending keep-alive {} / {}", sock, probe_sent + 1, tcp_keepalive_retries); + let result = sock.send_keepalive().await; + if result.is_none() { + warn!("Failed to send keep-alive!"); + } + probe_sent += 1; + last_tcp_keepalive_sent = Instant::now(); + } + } + } - quit.cancel(); - return; + if last_data_recv.elapsed() > UDP_TTL { + // Execute every 1 sec, unless there's traffic + info!("Connection {} No traffic seen in the last {:?}, closing connection", sock, UDP_TTL); + break; + } }, _ = quit.cancelled() => { - connections.write().await.remove(&addr); - debug!("removed fake TCP socket from connections table"); - return; + break; + }, + _ = data_received_fut => { + last_data_recv = Instant::now(); + }, + _ = tcp_packet_received_fut => { + last_tcp_recv = Instant::now(); + probe_sent = 0; }, - _ = packet_received_fut => {}, } } + debug!("removed fake TCP socket {} from connections table", sock); + connections.write().await.remove(&addr); + quit.cancel(); }); } });