From 25a1cf9aab72de2c26d3468fbd0bfa700142a572 Mon Sep 17 00:00:00 2001
From: Andrea <andreadaoud6@gmail.com>
Date: Fri, 16 Dec 2022 22:11:46 +0800
Subject: [PATCH] feat(fake-tcp) add tcp keep-alive support to client

---
 fake-tcp/src/lib.rs       |  47 +++++++++++++++--
 phantun/src/bin/client.rs | 105 +++++++++++++++++++++++++++++++-------
 2 files changed, 130 insertions(+), 22 deletions(-)

diff --git a/fake-tcp/src/lib.rs b/fake-tcp/src/lib.rs
index a5c5863..d70caf6 100644
--- a/fake-tcp/src/lib.rs
+++ b/fake-tcp/src/lib.rs
@@ -178,6 +178,34 @@ impl Socket {
         }
     }
 
+    /// Sends a keep-alive (zero length) datagram to the other end.
+    ///
+    /// This method takes `&self`, and it can be called safely by multiple threads
+    /// at the same time.
+    ///
+    /// A return of `None` means the Tun socket returned an error
+    /// and this socket must be closed.
+    pub async fn send_keepalive(&self) -> Option<()> {
+        match self.state {
+            State::Established => {
+                let ack = self.ack.load(Ordering::Relaxed);
+                self.last_ack.store(ack, Ordering::Relaxed);
+
+                let buf = build_tcp_packet(
+                    self.local_addr,
+                    self.remote_addr,
+                    // the current sequence number is one byte less than the next expected sequence number
+                    self.seq.load(Ordering::Relaxed) - 1,
+                    ack,
+                    tcp::TcpFlags::ACK,
+                    None,
+                );
+                self.tun.send(&buf).await.ok().and(Some(()))
+            }
+            _ => unreachable!(),
+        }
+    }
+
     /// Attempt to receive a datagram from the other end.
     ///
     /// This method takes `&self`, and it can be called safely by multiple threads
@@ -198,11 +226,22 @@ impl Socket {
 
                     let payload = tcp_packet.payload();
 
-                    let new_ack = tcp_packet.get_sequence().wrapping_add(payload.len() as u32);
-                    let last_ask = self.last_ack.load(Ordering::Relaxed);
-                    self.ack.store(new_ack, Ordering::Relaxed);
+                    let need_ack = if payload.len() == 0 {
+                        let current_head = self.ack.load(Ordering::Relaxed);
+                        // If it's a keep alive packet
+                        let kp = tcp_packet.get_sequence() == current_head - 1;
+                        if kp {
+                            trace!("{} Keepalive packet received!", self);
+                        }
+                        kp
+                    } else {
+                        let new_ack = tcp_packet.get_sequence().wrapping_add(payload.len() as u32);
+                        let last_ask = self.last_ack.load(Ordering::Relaxed);
+                        self.ack.store(new_ack, Ordering::Relaxed);
+                        new_ack.overflowing_sub(last_ask).0 > MAX_UNACKED_LEN
+                    };
 
-                    if new_ack.overflowing_sub(last_ask).0 > MAX_UNACKED_LEN {
+                    if need_ack {
                         let buf = self.build_tcp_packet(tcp::TcpFlags::ACK, None);
                         if let Err(e) = self.tun.try_send(&buf) {
                             // This should not really happen as we have not sent anything for
diff --git a/phantun/src/bin/client.rs b/phantun/src/bin/client.rs
index 360101e..aa70b0a 100644
--- a/phantun/src/bin/client.rs
+++ b/phantun/src/bin/client.rs
@@ -1,15 +1,16 @@
 use clap::{crate_version, Arg, ArgAction, Command};
 use fake_tcp::packet::MAX_PACKET_LEN;
 use fake_tcp::{Socket, Stack};
-use log::{debug, error, info};
+use log::{debug, error, info, trace, warn};
 use phantun::utils::{assign_ipv6_address, new_udp_reuseport};
 use std::collections::HashMap;
 use std::fs;
 use std::io;
 use std::net::{Ipv4Addr, SocketAddr};
 use std::sync::Arc;
+use std::time::Duration;
 use tokio::sync::{Notify, RwLock};
-use tokio::time;
+use tokio::time::{self, Instant};
 use tokio_tun::TunBuilder;
 use tokio_util::sync::CancellationToken;
 
@@ -101,6 +102,34 @@ async fn main() -> io::Result<()> {
                       Note: ensure this file's size does not exceed the MTU of the outgoing interface. \
                       The content is always sent out in a single packet and will not be further segmented")
         )
+        .arg(
+            Arg::new("keepalive_time")
+                .long("keepalive-time")
+                .short('k')
+                .required(false)
+                .value_name("SECONDS")
+                .value_parser(clap::value_parser!(u64).range(1..86400))
+                .help("Specify the interval between the last packet received and the first keepalive probe.\
+                If not specified, no keepalive probe will be sent.")
+        )
+        .arg(
+            Arg::new("keepalive_interval")
+                .long("keepalive-interval")
+                .required(false)
+                .default_value("5")
+                .value_name("SECONDS")
+                .value_parser(clap::value_parser!(u64).range(1..3600))
+                .help("Specify the interval between keepalive probes.")
+        )
+        .arg(
+            Arg::new("keepalive_retries")
+                .long("keepalive-retries")
+                .required(false)
+                .default_value("3")
+                .value_name("N")
+                .value_parser(clap::value_parser!(i32).range(1..20))
+                .help("Specify the number of keepalive probe retries before a connection reset.")
+        )
         .get_matches();
 
     let local_addr: SocketAddr = matches
@@ -148,6 +177,10 @@ async fn main() -> io::Result<()> {
         .map(fs::read)
         .transpose()?;
 
+    let tcp_keepalive_intvl = Duration::from_secs(*matches.get_one::<u64>("keepalive_interval").unwrap());
+    let tcp_keepalive_time = matches.get_one::<u64>("keepalive_time").map(|s| Duration::from_secs(*s));
+    let tcp_keepalive_retries = *matches.get_one::<i32>("keepalive_retries").unwrap();
+
     let num_cpus = num_cpus::get();
     info!("{} cores available", num_cpus);
 
@@ -218,13 +251,15 @@ async fn main() -> io::Result<()> {
             // spawn "fastpath" UDP socket and task, this will offload main task
             // from forwarding UDP packets
 
-            let packet_received = Arc::new(Notify::new());
+            let data_received = Arc::new(Notify::new());
+            let tcp_packet_received = Arc::new(Notify::new());
             let quit = CancellationToken::new();
 
             for i in 0..num_cpus {
                 let sock = sock.clone();
                 let quit = quit.clone();
-                let packet_received = packet_received.clone();
+                let data_received = data_received.clone();
+                let tcp_packet_received = tcp_packet_received.clone();
 
                 tokio::spawn(async move {
                     let mut buf_udp = [0u8; MAX_PACKET_LEN];
@@ -241,7 +276,7 @@ async fn main() -> io::Result<()> {
                                     return;
                                 }
 
-                                packet_received.notify_one();
+                                data_received.notify_one();
                             },
                             res = sock.recv(&mut buf_tcp) => {
                                 match res {
@@ -252,6 +287,9 @@ async fn main() -> io::Result<()> {
                                                 quit.cancel();
                                                 return;
                                             }
+                                            data_received.notify_one();
+                                        } else {
+                                            trace!("Empty TCP packet received");
                                         }
                                     },
                                     None => {
@@ -261,7 +299,7 @@ async fn main() -> io::Result<()> {
                                     },
                                 }
 
-                                packet_received.notify_one();
+                                tcp_packet_received.notify_one();
                             },
                             _ = quit.cancelled() => {
                                 debug!("worker {} terminated", i);
@@ -274,27 +312,58 @@ async fn main() -> io::Result<()> {
 
             let connections = connections.clone();
             tokio::spawn(async move {
+                let mut last_tcp_recv = Instant::now();
+                let mut last_data_recv = Instant::now();
+                let mut probe_sent = 0;
+                let mut last_tcp_keepalive_sent = Instant::now();
                 loop {
-                    let read_timeout = time::sleep(UDP_TTL);
-                    let packet_received_fut = packet_received.notified();
+                    let tm0 = time::sleep(Duration::from_secs(1));
+                    let data_received_fut = data_received.notified();
+                    let tcp_packet_received_fut = tcp_packet_received.notified();
 
                     tokio::select! {
-                        _ = read_timeout => {
-                            info!("No traffic seen in the last {:?}, closing connection", UDP_TTL);
-                            connections.write().await.remove(&addr);
-                            debug!("removed fake TCP socket from connections table");
+                        _ = tm0 => {
+                            
+                            if tcp_keepalive_time.is_some() {
+                                let tcp_idle = last_tcp_recv.elapsed();
+                                if tcp_idle > tcp_keepalive_time.unwrap() {
+                                    if last_tcp_keepalive_sent.elapsed() >= tcp_keepalive_intvl {
+                                        if probe_sent >= tcp_keepalive_retries {
+                                            info!("Connection {} TCP keep-alive retry exceeded", sock);
+                                            break;
+                                        }
+                                        trace!("Connection {} sending keep-alive {} / {}", sock, probe_sent + 1, tcp_keepalive_retries);
+                                        let result = sock.send_keepalive().await;
+                                        if result.is_none() {
+                                            warn!("Failed to send keep-alive!");
+                                        }
+                                        probe_sent += 1;
+                                        last_tcp_keepalive_sent = Instant::now();
+                                    }
+                                }
+                            }
 
-                            quit.cancel();
-                            return;
+                            if last_data_recv.elapsed() > UDP_TTL {
+                                // Execute every 1 sec, unless there's traffic
+                                info!("Connection {} No traffic seen in the last {:?}, closing connection", sock, UDP_TTL);
+                                break;
+                            }
                         },
                         _ = quit.cancelled() => {
-                            connections.write().await.remove(&addr);
-                            debug!("removed fake TCP socket from connections table");
-                            return;
+                            break;
+                        },
+                        _ = data_received_fut => {
+                            last_data_recv = Instant::now();
+                        },
+                        _ = tcp_packet_received_fut => {
+                            last_tcp_recv = Instant::now();
+                            probe_sent = 0;
                         },
-                        _ = packet_received_fut => {},
                     }
                 }
+                debug!("removed fake TCP socket {} from connections table", sock);
+                connections.write().await.remove(&addr);
+                quit.cancel();
             });
         }
     });