From 7c06c5b08b611a28d946918242e039d2ac6bdb6c Mon Sep 17 00:00:00 2001 From: Datong Sun Date: Thu, 16 Sep 2021 13:27:52 -0700 Subject: [PATCH] feat(phantom) better error handling --- Cargo.toml | 2 + src/bin/client.rs | 48 +++++++++++++----- src/bin/server.rs | 26 ++++++---- src/fake_tcp/mod.rs | 110 ++++++++++++++++++++++++++++++----------- src/fake_tcp/packet.rs | 2 +- 5 files changed, 137 insertions(+), 51 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index a0e5ec7..da00fde 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,3 +11,5 @@ tokio = { version = "1.11.0", features = ["full"] } lru_time_cache = "0.11.11" rand = { version = "0.8.4", features = ["small_rng"] } clap = "2.33.3" +log = "0.4" +pretty_env_logger = "0.4.0" diff --git a/src/bin/client.rs b/src/bin/client.rs index 3d53ea1..630f25a 100644 --- a/src/bin/client.rs +++ b/src/bin/client.rs @@ -1,20 +1,22 @@ use clap::{App, Arg}; -use lru_time_cache::LruCache; +use log::{debug, error, info}; +use lru_time_cache::{LruCache, TimedEntry}; use phantom::fake_tcp::packet::MAX_PACKET_LEN; use phantom::fake_tcp::{Socket, Stack}; use std::net::{SocketAddr, SocketAddrV4}; use std::sync::Arc; -use std::thread; use std::time::Duration; use tokio::net::UdpSocket; use tokio::sync::Mutex; use tokio::time; use tokio_tun::TunBuilder; -const UDP_TTL: Duration = Duration::from_secs(300); +const UDP_TTL: Duration = Duration::from_secs(180); #[tokio::main] async fn main() { + pretty_env_logger::init(); + let matches = App::new("Phantom Client") .version("1.0") .author("Dndx") @@ -59,12 +61,13 @@ async fn main() { .try_build() .unwrap(); + info!("Created TUN device {}", tun.name()); + let udp_sock = Arc::new(UdpSocket::bind(local_addr).await.unwrap()); let connections = Mutex::new(LruCache::>::with_expiry_duration( UDP_TTL, )); - thread::sleep(Duration::from_secs(5)); let mut stack = Stack::new(tun); let main_loop = tokio::spawn(async move { @@ -79,28 +82,49 @@ async fn main() { continue; } - let mut sock = Arc::new(stack.connect(remote_addr).await); - sock.send(&buf_r[..size]).await; + info!("New UDP client from {}", addr); + let sock = stack.connect(remote_addr).await; + if sock.is_none() { + error!("Unable to connect to remote {}", remote_addr); + continue; + } + + let sock = Arc::new(sock.unwrap()); + let res = sock.send(&buf_r[..size]).await; + if res.is_none() { + continue; + } + assert!(connections.lock().await.insert(addr, sock.clone()).is_none()); let udp_sock = udp_sock.clone(); tokio::spawn(async move { loop { let mut buf_r = [0u8; MAX_PACKET_LEN]; - let size = sock.recv(&mut buf_r).await; - - if size > 0 { - udp_sock.send_to(&buf_r[..size], addr).await.unwrap(); + match sock.recv(&mut buf_r).await { + Some(size) => { + udp_sock.send_to(&buf_r[..size], addr).await.unwrap(); + }, + None => { return; }, } } }); }, _ = cleanup_timer.tick() => { - connections.lock().await.iter(); + let mut total = 0; + + for c in connections.lock().await.notify_iter() { + if let TimedEntry::Expired(_addr, sock) = c { + sock.close(); + total += 1; + } + } + + debug!("Cleaned {} stale connections", total); }, } } }); - tokio::join!(main_loop); + tokio::join!(main_loop).0.unwrap(); } diff --git a/src/bin/server.rs b/src/bin/server.rs index d150fbb..b2359bc 100644 --- a/src/bin/server.rs +++ b/src/bin/server.rs @@ -1,13 +1,14 @@ -use clap::{App, Arg, SubCommand}; +use clap::{App, Arg}; use phantom::fake_tcp::packet::MAX_PACKET_LEN; use phantom::fake_tcp::Stack; use std::net::SocketAddrV4; -use std::{thread, time}; use tokio::net::UdpSocket; use tokio_tun::TunBuilder; #[tokio::main] async fn main() { + pretty_env_logger::init(); + let matches = App::new("Phantom Server") .version("1.0") .author("Dndx") @@ -69,18 +70,25 @@ async fn main() { loop { tokio::select! { Ok(size) = udp_sock.recv(&mut buf_udp) => { - sock.send(&buf_udp[..size]).await; - }, - size = sock.recv(&mut buf_tcp) => { - if size > 0 { - udp_sock.send(&buf_tcp[..size]).await.unwrap(); + if let None = sock.send(&buf_udp[..size]).await { + return; } - } + }, + res = sock.recv(&mut buf_tcp) => { + match res { + Some(size) => { + if size > 0 { + udp_sock.send(&buf_tcp[..size]).await.unwrap(); + } + }, + None => { return; }, + } + }, }; } }); } }); - tokio::join!(main_loop); + tokio::join!(main_loop).0.unwrap(); } diff --git a/src/fake_tcp/mod.rs b/src/fake_tcp/mod.rs index 81fe6ad..c7489da 100644 --- a/src/fake_tcp/mod.rs +++ b/src/fake_tcp/mod.rs @@ -1,23 +1,24 @@ pub mod packet; use bytes::{Bytes, BytesMut}; +use log::info; use packet::*; use pnet::packet::{tcp, Packet}; use rand::prelude::*; -use std::cell::RefCell; -use std::cmp::max; use std::collections::{HashMap, HashSet}; -use std::io::{Error, Result}; +use std::fmt; use std::net::{Ipv4Addr, SocketAddrV4}; use std::sync::atomic::{AtomicU32, Ordering}; use std::sync::{Arc, Mutex}; use tokio::io::{AsyncReadExt, AsyncWriteExt}; -use tokio::sync::mpsc::{self, error::TrySendError, Receiver, Sender}; +use tokio::sync::mpsc::{self, Receiver, Sender}; +use tokio::sync::watch; use tokio::sync::Mutex as AsyncMutex; use tokio::{io, time}; use tokio_tun::Tun; -const TIMEOUT: time::Duration = time::Duration::from_secs(5); +const TIMEOUT: time::Duration = time::Duration::from_secs(1); +const RETRIES: usize = 6; const MPSC_BUFFER_LEN: usize = 128; #[derive(Debug, Hash, Eq, PartialEq)] @@ -73,6 +74,8 @@ pub struct Socket { seq: AtomicU32, ack: AtomicU32, state: State, + closing_tx: watch::Sender<()>, + closing_rx: watch::Receiver<()>, } impl Socket { @@ -85,6 +88,7 @@ impl Socket { state: State, ) -> (Socket, Sender) { let (incoming_tx, incoming_rx) = mpsc::channel(MPSC_BUFFER_LEN); + let (closing_tx, closing_rx) = watch::channel(()); ( Socket { @@ -96,6 +100,8 @@ impl Socket { seq: AtomicU32::new(0), ack: AtomicU32::new(ack.unwrap_or(0)), state, + closing_tx, + closing_rx, }, incoming_tx, ) @@ -112,43 +118,75 @@ impl Socket { ); } - pub async fn send(&self, payload: &[u8]) { + pub async fn send(&self, payload: &[u8]) -> Option<()> { + let mut closing = self.closing_rx.clone(); + match self.state { State::Established => { let buf = self.build_tcp_packet(tcp::TcpFlags::ACK, Some(payload)); self.seq.fetch_add(buf.len() as u32, Ordering::Relaxed); - self.shared.outgoing.send(buf).await.unwrap(); + + tokio::select! { + res = self.shared.outgoing.send(buf) => { + res.unwrap(); + Some(()) + }, + _ = closing.changed() => { + None + } + } } _ => unreachable!(), } } - pub async fn recv(&self, buf: &mut [u8]) -> usize { + pub async fn recv(&self, buf: &mut [u8]) -> Option { + let mut closing = self.closing_rx.clone(); + match self.state { State::Established => { - let raw_buf = self.incoming.lock().await.recv().await.unwrap(); - let (_v4_packet, tcp_packet) = parse_ipv4_packet(&raw_buf); - let payload = tcp_packet.payload(); + let mut incoming = self.incoming.lock().await; + tokio::select! { + Some(raw_buf) = incoming.recv() => { + let (_v4_packet, tcp_packet) = parse_ipv4_packet(&raw_buf); - self.ack - .fetch_max(tcp_packet.get_sequence() + 1, Ordering::Relaxed); + if (tcp_packet.get_flags() & tcp::TcpFlags::RST) != 0 { + info!("Connection {} reset by peer", self); + self.close(); + return None; + } - buf[..payload.len()].copy_from_slice(payload); + let payload = tcp_packet.payload(); - payload.len() + self.ack + .store(tcp_packet.get_sequence().wrapping_add(1), Ordering::Relaxed); + + buf[..payload.len()].copy_from_slice(payload); + + Some(payload.len()) + }, + _ = closing.changed() => { + None + } + } } _ => unreachable!(), } } + pub fn close(&self) { + self.closing_tx.send(()).unwrap(); + } + async fn accept(mut self) { - loop { + for _ in 0..RETRIES { match self.state { State::Idle => { let buf = self.build_tcp_packet(tcp::TcpFlags::SYN | tcp::TcpFlags::ACK, None); // ACK set by constructor self.shared.outgoing.send(buf).await.unwrap(); self.state = State::SynReceived; + info!("Sent SYN + ACK to client"); } State::SynReceived => { let res = time::timeout(TIMEOUT, self.incoming.lock().await.recv()).await; @@ -168,14 +206,14 @@ impl Socket { self.seq.fetch_add(1, Ordering::Relaxed); self.state = State::Established; - println!("Connection from {:?} established", self.remote_addr); + info!("Connection from {:?} established", self.remote_addr); let ready = self.shared.ready.clone(); ready.send(self).await.unwrap(); return; } } else { - println!("waiting for SYN + ACK timed out, dropping connection"); - return; + info!("Waiting for client ACK timed out"); + self.state = State::Idle; } } _ => unreachable!(), @@ -183,13 +221,14 @@ impl Socket { } } - async fn connect(&mut self) { - loop { + async fn connect(&mut self) -> Option<()> { + for _ in 0..RETRIES { match self.state { State::Idle => { let buf = self.build_tcp_packet(tcp::TcpFlags::SYN, None); self.shared.outgoing.send(buf).await.unwrap(); self.state = State::SynSent; + info!("Sent SYN to server"); } State::SynSent => { match time::timeout(TIMEOUT, self.incoming.lock().await.recv()).await { @@ -198,7 +237,7 @@ impl Socket { let (_v4_packet, tcp_packet) = parse_ipv4_packet(&buf); if (tcp_packet.get_flags() & tcp::TcpFlags::RST) != 0 { - return; + return None; } if tcp_packet.get_flags() == tcp::TcpFlags::SYN | tcp::TcpFlags::ACK @@ -216,12 +255,12 @@ impl Socket { self.state = State::Established; - println!("Connection to {:?} established", self.remote_addr); - return; + info!("Connection to {:?} established", self.remote_addr); + return Some(()); } } Err(_) => { - println!("waiting for SYN + ACK timed out, going back to Idle"); + info!("Waiting for SYN + ACK timed out"); self.state = State::Idle; } } @@ -229,6 +268,8 @@ impl Socket { _ => unreachable!(), } } + + None } } @@ -245,6 +286,18 @@ impl Drop for Socket { let buf = self.build_tcp_packet(tcp::TcpFlags::RST, None); self.shared.outgoing.try_send(buf).unwrap(); + self.close(); + info!("Fake TCP connection to {} closed", self); + } +} + +impl fmt::Display for Socket { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "(Fake TCP connection from {} to {})", + self.local_addr, self.remote_addr + ) } } @@ -276,7 +329,7 @@ impl Stack { self.ready.recv().await.unwrap() } - pub async fn connect(&mut self, addr: SocketAddrV4) -> Socket { + pub async fn connect(&mut self, addr: SocketAddrV4) -> Option { let mut rng = SmallRng::from_entropy(); let local_port: u16 = rng.gen_range(1024..65535); let local_addr = SocketAddrV4::new(self.local_ip, local_port); @@ -295,8 +348,7 @@ impl Stack { assert!(tuples.insert(tuple, Arc::new(incoming.clone())).is_none()); } - sock.connect().await; - sock + sock.connect().await.map(|_| sock) } async fn dispatch(tun: Tun, mut outgoing: Receiver, shared: Arc) { @@ -326,7 +378,7 @@ impl Stack { let sender; { - let mut tuples = shared.tuples.lock().unwrap(); + let tuples = shared.tuples.lock().unwrap(); sender = tuples.get(&tuple).map(|c| c.clone()); } diff --git a/src/fake_tcp/packet.rs b/src/fake_tcp/packet.rs index eef4128..283ab76 100644 --- a/src/fake_tcp/packet.rs +++ b/src/fake_tcp/packet.rs @@ -1,5 +1,5 @@ use bytes::{Bytes, BytesMut}; -use pnet::packet::{ip, ipv4, tcp, Packet, PacketSize}; +use pnet::packet::{ip, ipv4, tcp}; use std::convert::TryInto; use std::net::SocketAddrV4;