465 lines
16 KiB
Rust
Raw Normal View History

#![cfg_attr(feature = "benchmark", feature(test))]
2021-09-16 11:25:28 -07:00
pub mod packet;
extern crate dndx_fork_tokio_tun as tokio_tun;
2021-09-16 11:25:28 -07:00
use bytes::{Bytes, BytesMut};
use log::{error, info, trace, warn};
2021-09-16 11:25:28 -07:00
use packet::*;
use pnet::packet::{tcp, Packet};
use rand::prelude::*;
use std::collections::{HashMap, HashSet};
2021-09-16 13:27:52 -07:00
use std::fmt;
2021-09-16 11:25:28 -07:00
use std::net::{Ipv4Addr, SocketAddrV4};
use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::{Arc, RwLock};
use tokio::sync::broadcast;
2021-09-16 13:27:52 -07:00
use tokio::sync::mpsc::{self, Receiver, Sender};
use tokio::sync::watch;
2021-09-16 11:25:28 -07:00
use tokio::sync::Mutex as AsyncMutex;
use tokio::time;
2021-09-16 11:25:28 -07:00
use tokio_tun::Tun;
2021-09-16 13:27:52 -07:00
/// How long a single handshake step waits for the peer before retrying.
const TIMEOUT: time::Duration = time::Duration::from_millis(1_000);
/// Number of iterations of the handshake state machine before giving up.
const RETRIES: usize = 6;
/// Capacity of each per-connection packet queue and of the accepted-socket queue.
const MPSC_BUFFER_LEN: usize = 512;
2021-09-16 11:25:28 -07:00
/// Identifies one connection by its (local, remote) IPv4 socket address pair.
///
/// Used as the key of the connection dispatch tables.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct AddrTuple {
    local_addr: SocketAddrV4,
    remote_addr: SocketAddrV4,
}

impl AddrTuple {
    /// Pairs `local_addr` (our side) with `remote_addr` (the peer).
    fn new(local_addr: SocketAddrV4, remote_addr: SocketAddrV4) -> AddrTuple {
        Self {
            remote_addr,
            local_addr,
        }
    }
}
struct Shared {
tuples: RwLock<HashMap<AddrTuple, Sender<Bytes>>>,
listening: RwLock<HashSet<u16>>,
tun: Vec<Arc<Tun>>,
2021-09-16 11:25:28 -07:00
ready: Sender<Socket>,
tuples_purge: broadcast::Sender<AddrTuple>,
2021-09-16 11:25:28 -07:00
}
/// User-facing handle to the fake TCP stack.
pub struct Stack {
    /// State shared with the reader tasks and with every socket.
    shared: Arc<Shared>,
    /// Local IPv4 address used as the source of outgoing connections.
    local_ip: Ipv4Addr,
    /// Receiving end of the queue of sockets accepted on listening ports.
    ready: Receiver<Socket>,
}
/// Handshake state of a [`Socket`].
///
/// Public types should be `Debug`; the variants carry no data, so the enum is
/// also cheap to copy and compare.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum State {
    /// Nothing sent yet, or the last handshake step timed out and will be retried.
    Idle,
    /// Active open: SYN sent, waiting for SYN + ACK.
    SynSent,
    /// Passive open: SYN + ACK sent, waiting for the final ACK.
    SynReceived,
    /// Handshake complete; `send`/`recv` may be used.
    Established,
}
pub struct Socket {
shared: Arc<Shared>,
tun: Arc<Tun>,
2021-09-16 11:25:28 -07:00
incoming: AsyncMutex<Receiver<Bytes>>,
local_addr: SocketAddrV4,
remote_addr: SocketAddrV4,
seq: AtomicU32,
ack: AtomicU32,
state: State,
2021-09-16 13:27:52 -07:00
closing_tx: watch::Sender<()>,
closing_rx: watch::Receiver<()>,
2021-09-16 11:25:28 -07:00
}
impl Socket {
fn new(
shared: Arc<Shared>,
tun: Arc<Tun>,
2021-09-16 11:25:28 -07:00
local_addr: SocketAddrV4,
remote_addr: SocketAddrV4,
ack: Option<u32>,
state: State,
) -> (Socket, Sender<Bytes>) {
let (incoming_tx, incoming_rx) = mpsc::channel(MPSC_BUFFER_LEN);
2021-09-16 13:27:52 -07:00
let (closing_tx, closing_rx) = watch::channel(());
2021-09-16 11:25:28 -07:00
(
Socket {
shared,
tun,
2021-09-16 11:25:28 -07:00
incoming: AsyncMutex::new(incoming_rx),
local_addr,
remote_addr,
seq: AtomicU32::new(0),
ack: AtomicU32::new(ack.unwrap_or(0)),
state,
2021-09-16 13:27:52 -07:00
closing_tx,
closing_rx,
2021-09-16 11:25:28 -07:00
},
incoming_tx,
)
}
fn build_tcp_packet(&self, flags: u16, payload: Option<&[u8]>) -> Bytes {
2021-09-16 23:57:32 -07:00
build_tcp_packet(
2021-09-16 11:25:28 -07:00
self.local_addr,
self.remote_addr,
self.seq.load(Ordering::Relaxed),
self.ack.load(Ordering::Relaxed),
flags,
payload,
2021-09-16 23:57:32 -07:00
)
2021-09-16 11:25:28 -07:00
}
2021-09-16 13:27:52 -07:00
pub async fn send(&self, payload: &[u8]) -> Option<()> {
let mut closing = self.closing_rx.clone();
2021-09-16 11:25:28 -07:00
match self.state {
State::Established => {
let buf = self.build_tcp_packet(tcp::TcpFlags::ACK, Some(payload));
self.seq.fetch_add(payload.len() as u32, Ordering::Relaxed);
2021-09-16 13:27:52 -07:00
tokio::select! {
res = self.tun.send(&buf) => {
2021-09-16 13:27:52 -07:00
res.unwrap();
Some(())
},
_ = closing.changed() => {
None
}
}
2021-09-16 11:25:28 -07:00
}
_ => unreachable!(),
}
}
2021-09-16 13:27:52 -07:00
pub async fn recv(&self, buf: &mut [u8]) -> Option<usize> {
let mut closing = self.closing_rx.clone();
2021-09-16 11:25:28 -07:00
match self.state {
State::Established => {
2021-09-16 13:27:52 -07:00
let mut incoming = self.incoming.lock().await;
tokio::select! {
Some(raw_buf) = incoming.recv() => {
let (_v4_packet, tcp_packet) = parse_ipv4_packet(&raw_buf);
2021-09-16 11:25:28 -07:00
2021-09-16 13:27:52 -07:00
if (tcp_packet.get_flags() & tcp::TcpFlags::RST) != 0 {
info!("Connection {} reset by peer", self);
self.close();
return None;
}
let payload = tcp_packet.payload();
2021-09-16 11:25:28 -07:00
2021-09-16 13:27:52 -07:00
self.ack
.store(tcp_packet.get_sequence().wrapping_add(1), Ordering::Relaxed);
2021-09-16 11:25:28 -07:00
2021-09-16 13:27:52 -07:00
buf[..payload.len()].copy_from_slice(payload);
Some(payload.len())
},
_ = closing.changed() => {
None
}
}
2021-09-16 11:25:28 -07:00
}
_ => unreachable!(),
}
}
2021-09-16 13:27:52 -07:00
pub fn close(&self) {
self.closing_tx.send(()).unwrap();
}
2021-09-16 11:25:28 -07:00
async fn accept(mut self) {
2021-09-16 13:27:52 -07:00
for _ in 0..RETRIES {
2021-09-16 11:25:28 -07:00
match self.state {
State::Idle => {
let buf = self.build_tcp_packet(tcp::TcpFlags::SYN | tcp::TcpFlags::ACK, None);
// ACK set by constructor
self.tun.send(&buf).await.unwrap();
2021-09-16 11:25:28 -07:00
self.state = State::SynReceived;
2021-09-16 13:27:52 -07:00
info!("Sent SYN + ACK to client");
2021-09-16 11:25:28 -07:00
}
State::SynReceived => {
let res = time::timeout(TIMEOUT, self.incoming.lock().await.recv()).await;
if let Ok(buf) = res {
let buf = buf.unwrap();
let (_v4_packet, tcp_packet) = parse_ipv4_packet(&buf);
if (tcp_packet.get_flags() & tcp::TcpFlags::RST) != 0 {
return;
}
if tcp_packet.get_flags() == tcp::TcpFlags::ACK
&& tcp_packet.get_acknowledgement()
== self.seq.load(Ordering::Relaxed) + 1
{
// found our ACK
self.seq.fetch_add(1, Ordering::Relaxed);
self.state = State::Established;
2021-09-16 13:27:52 -07:00
info!("Connection from {:?} established", self.remote_addr);
2021-09-16 11:25:28 -07:00
let ready = self.shared.ready.clone();
if let Err(e) = ready.send(self).await {
error!("Unable to send accepted socket to ready queue: {}", e);
}
2021-09-16 11:25:28 -07:00
return;
}
} else {
2021-09-16 13:27:52 -07:00
info!("Waiting for client ACK timed out");
self.state = State::Idle;
2021-09-16 11:25:28 -07:00
}
}
_ => unreachable!(),
}
}
}
2021-09-16 13:27:52 -07:00
async fn connect(&mut self) -> Option<()> {
for _ in 0..RETRIES {
2021-09-16 11:25:28 -07:00
match self.state {
State::Idle => {
let buf = self.build_tcp_packet(tcp::TcpFlags::SYN, None);
self.tun.send(&buf).await.unwrap();
2021-09-16 11:25:28 -07:00
self.state = State::SynSent;
2021-09-16 13:27:52 -07:00
info!("Sent SYN to server");
2021-09-16 11:25:28 -07:00
}
State::SynSent => {
match time::timeout(TIMEOUT, self.incoming.lock().await.recv()).await {
Ok(buf) => {
let buf = buf.unwrap();
let (_v4_packet, tcp_packet) = parse_ipv4_packet(&buf);
if (tcp_packet.get_flags() & tcp::TcpFlags::RST) != 0 {
2021-09-16 13:27:52 -07:00
return None;
2021-09-16 11:25:28 -07:00
}
if tcp_packet.get_flags() == tcp::TcpFlags::SYN | tcp::TcpFlags::ACK
&& tcp_packet.get_acknowledgement()
== self.seq.load(Ordering::Relaxed) + 1
{
// found our SYN + ACK
self.seq.fetch_add(1, Ordering::Relaxed);
self.ack
.store(tcp_packet.get_sequence() + 1, Ordering::Relaxed);
// send ACK to finish handshake
let buf = self.build_tcp_packet(tcp::TcpFlags::ACK, None);
self.tun.send(&buf).await.unwrap();
2021-09-16 11:25:28 -07:00
self.state = State::Established;
2021-09-16 13:27:52 -07:00
info!("Connection to {:?} established", self.remote_addr);
return Some(());
2021-09-16 11:25:28 -07:00
}
}
Err(_) => {
2021-09-16 13:27:52 -07:00
info!("Waiting for SYN + ACK timed out");
2021-09-16 11:25:28 -07:00
self.state = State::Idle;
}
}
}
_ => unreachable!(),
}
}
2021-09-16 13:27:52 -07:00
None
2021-09-16 11:25:28 -07:00
}
}
impl Drop for Socket {
fn drop(&mut self) {
let tuple = AddrTuple::new(self.local_addr, self.remote_addr);
2021-09-16 11:25:28 -07:00
// dissociates ourself from the dispatch map
assert!(self.shared.tuples.write().unwrap().remove(&tuple).is_some());
// purge cache
self.shared.tuples_purge.send(tuple).unwrap();
2021-09-16 11:25:28 -07:00
let buf = self.build_tcp_packet(tcp::TcpFlags::RST, None);
if let Err(e) = self.tun.try_send(&buf) {
warn!("Unable to send RST to remote end: {}", e);
}
2021-09-16 13:27:52 -07:00
self.close();
info!("Fake TCP connection to {} closed", self);
}
}
impl fmt::Display for Socket {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(
f,
"(Fake TCP connection from {} to {})",
self.local_addr, self.remote_addr
)
2021-09-16 11:25:28 -07:00
}
}
impl Stack {
pub fn new(tun: Vec<Tun>) -> Stack {
let tun: Vec<Arc<Tun>> = tun.into_iter().map(Arc::new).collect();
2021-09-16 11:25:28 -07:00
let (ready_tx, ready_rx) = mpsc::channel(MPSC_BUFFER_LEN);
let (tuples_purge_tx, _tuples_purge_rx) = broadcast::channel(16);
2021-09-16 11:25:28 -07:00
let shared = Arc::new(Shared {
tuples: RwLock::new(HashMap::new()),
tun: tun.clone(),
listening: RwLock::new(HashSet::new()),
2021-09-16 11:25:28 -07:00
ready: ready_tx,
tuples_purge: tuples_purge_tx.clone(),
2021-09-16 11:25:28 -07:00
});
let local_ip = tun[0].destination().unwrap();
for t in tun {
tokio::spawn(Stack::reader_task(
t,
shared.clone(),
tuples_purge_tx.subscribe(),
));
}
2021-09-16 11:25:28 -07:00
Stack {
shared,
local_ip,
ready: ready_rx,
}
}
pub fn listen(&mut self, port: u16) {
assert!(self.shared.listening.write().unwrap().insert(port));
2021-09-16 11:25:28 -07:00
}
pub async fn accept(&mut self) -> Socket {
self.ready.recv().await.unwrap()
}
2021-09-16 13:27:52 -07:00
pub async fn connect(&mut self, addr: SocketAddrV4) -> Option<Socket> {
2021-09-16 11:25:28 -07:00
let mut rng = SmallRng::from_entropy();
let local_port: u16 = rng.gen_range(1024..65535);
let local_addr = SocketAddrV4::new(self.local_ip, local_port);
let tuple = AddrTuple::new(local_addr, addr);
let (mut sock, incoming) = Socket::new(
self.shared.clone(),
self.shared.tun.choose(&mut rng).unwrap().clone(),
2021-09-16 11:25:28 -07:00
local_addr,
addr,
None,
State::Idle,
);
{
let mut tuples = self.shared.tuples.write().unwrap();
assert!(tuples.insert(tuple, incoming.clone()).is_none());
2021-09-16 11:25:28 -07:00
}
2021-09-16 13:27:52 -07:00
sock.connect().await.map(|_| sock)
2021-09-16 11:25:28 -07:00
}
async fn reader_task(
tun: Arc<Tun>,
shared: Arc<Shared>,
mut tuples_purge: broadcast::Receiver<AddrTuple>,
) {
let mut tuples: HashMap<AddrTuple, Sender<Bytes>> = HashMap::new();
2021-09-16 11:25:28 -07:00
loop {
let mut buf = BytesMut::with_capacity(MAX_PACKET_LEN);
buf.resize(MAX_PACKET_LEN, 0);
2021-09-16 11:25:28 -07:00
tokio::select! {
size = tun.recv(&mut buf) => {
let size = size.unwrap();
buf.truncate(size);
2021-09-16 11:25:28 -07:00
let buf = buf.freeze();
2021-09-16 11:25:28 -07:00
if buf[0] >> 4 != 4 {
// not an IPv4 packet
continue;
}
let (ip_packet, tcp_packet) = parse_ipv4_packet(&buf);
let local_addr =
SocketAddrV4::new(ip_packet.get_destination(), tcp_packet.get_destination());
2021-09-16 11:25:28 -07:00
let remote_addr = SocketAddrV4::new(ip_packet.get_source(), tcp_packet.get_source());
let tuple = AddrTuple::new(local_addr, remote_addr);
if let Some(c) = tuples.get(&tuple) {
2021-09-16 11:25:28 -07:00
c.send(buf).await.unwrap();
continue;
} else {
trace!("Cache miss, checking the shared tuples table for connection");
let sender;
{
let tuples = shared.tuples.read().unwrap();
sender = tuples.get(&tuple).cloned();
}
if let Some(c) = sender {
trace!("Storing connection information into local tuples");
tuples.insert(tuple, c.clone());
c.send(buf).await.unwrap();
continue;
}
2021-09-16 11:25:28 -07:00
}
if tcp_packet.get_flags() == tcp::TcpFlags::SYN
&& shared
.listening
.read()
.unwrap()
.contains(&tcp_packet.get_destination())
{
2021-09-16 11:25:28 -07:00
// SYN seen on listening socket
if tcp_packet.get_sequence() == 0 {
let (sock, incoming) = Socket::new(
shared.clone(),
tun.clone(),
local_addr,
remote_addr,
Some(tcp_packet.get_sequence() + 1),
State::Idle,
);
assert!(shared
.tuples
.write()
.unwrap()
.insert(tuple, incoming)
.is_none());
2021-09-16 11:25:28 -07:00
tokio::spawn(sock.accept());
} else {
trace!("Bad TCP SYN packet from {}, sending RST", remote_addr);
2021-09-16 11:25:28 -07:00
let buf = build_tcp_packet(
local_addr,
remote_addr,
0,
tcp_packet.get_sequence() + 1,
2021-09-16 11:25:28 -07:00
tcp::TcpFlags::RST,
None,
);
shared.tun[0].try_send(&buf).unwrap();
2021-09-16 11:25:28 -07:00
}
} else if (tcp_packet.get_flags() & tcp::TcpFlags::RST) == 0 {
info!("Unknown TCP packet from {}, sending RST", remote_addr);
let buf = build_tcp_packet(
local_addr,
remote_addr,
tcp_packet.get_acknowledgement(),
0,
tcp::TcpFlags::RST,
None,
);
shared.tun[0].try_send(&buf).unwrap();
2021-09-16 11:25:28 -07:00
}
},
tuple = tuples_purge.recv() => {
let tuple = tuple.unwrap();
tuples.remove(&tuple);
trace!("Removed cached tuple");
2021-09-16 11:25:28 -07:00
}
}
}
}
}