style(phantun) remove unnecessary tokio::select call

This commit is contained in:
Datong Sun 2022-04-15 07:58:16 -07:00
parent 2f4eaafccd
commit 74183071f1
2 changed files with 108 additions and 106 deletions

View File

@ -4,6 +4,7 @@ use fake_tcp::{Socket, Stack};
use log::{debug, error, info}; use log::{debug, error, info};
use phantun::utils::new_udp_reuseport; use phantun::utils::new_udp_reuseport;
use std::collections::HashMap; use std::collections::HashMap;
use std::io;
use std::net::{Ipv4Addr, SocketAddr}; use std::net::{Ipv4Addr, SocketAddr};
use std::sync::Arc; use std::sync::Arc;
use tokio::sync::{Notify, RwLock}; use tokio::sync::{Notify, RwLock};
@ -14,7 +15,7 @@ use tokio_util::sync::CancellationToken;
use phantun::UDP_TTL; use phantun::UDP_TTL;
#[tokio::main] #[tokio::main]
async fn main() { async fn main() -> io::Result<()> {
pretty_env_logger::init(); pretty_env_logger::init();
let matches = Command::new("Phantun Client") let matches = Command::new("Phantun Client")
@ -122,119 +123,120 @@ async fn main() {
let mut buf_r = [0u8; MAX_PACKET_LEN]; let mut buf_r = [0u8; MAX_PACKET_LEN];
loop { loop {
tokio::select! { let (size, addr) = udp_sock.recv_from(&mut buf_r).await?;
Ok((size, addr)) = udp_sock.recv_from(&mut buf_r) => { // seen UDP packet to listening socket, this means:
// seen UDP packet to listening socket, this means: // 1. It is a new UDP connection, or
// 1. It is a new UDP connection, or // 2. It is some extra packets not filtered by more specific
// 2. It is some extra packets not filtered by more specific // connected UDP socket yet
// connected UDP socket yet if let Some(sock) = connections.read().await.get(&addr) {
if let Some(sock) = connections.read().await.get(&addr) { sock.send(&buf_r[..size]).await;
sock.send(&buf_r[..size]).await; continue;
continue; }
}
info!("New UDP client from {}", addr); info!("New UDP client from {}", addr);
let sock = stack.connect(remote_addr).await; let sock = stack.connect(remote_addr).await;
if sock.is_none() { if sock.is_none() {
error!("Unable to connect to remote {}", remote_addr); error!("Unable to connect to remote {}", remote_addr);
continue; continue;
} }
let sock = Arc::new(sock.unwrap()); let sock = Arc::new(sock.unwrap());
// send first packet // send first packet
let res = sock.send(&buf_r[..size]).await; let res = sock.send(&buf_r[..size]).await;
if res.is_none() { if res.is_none() {
continue; continue;
} }
assert!(connections.write().await.insert(addr, sock.clone()).is_none()); assert!(connections
debug!("inserted fake TCP socket into connection table"); .write()
.await
.insert(addr, sock.clone())
.is_none());
debug!("inserted fake TCP socket into connection table");
// spawn "fastpath" UDP socket and task, this will offload main task // spawn "fastpath" UDP socket and task, this will offload main task
// from forwarding UDP packets // from forwarding UDP packets
let packet_received = Arc::new(Notify::new()); let packet_received = Arc::new(Notify::new());
let quit = CancellationToken::new(); let quit = CancellationToken::new();
for i in 0..num_cpus { for i in 0..num_cpus {
let sock = sock.clone(); let sock = sock.clone();
let quit = quit.clone(); let quit = quit.clone();
let packet_received = packet_received.clone(); let packet_received = packet_received.clone();
tokio::spawn(async move { tokio::spawn(async move {
let mut buf_udp = [0u8; MAX_PACKET_LEN]; let mut buf_udp = [0u8; MAX_PACKET_LEN];
let mut buf_tcp = [0u8; MAX_PACKET_LEN]; let mut buf_tcp = [0u8; MAX_PACKET_LEN];
let udp_sock = new_udp_reuseport(local_addr); let udp_sock = new_udp_reuseport(local_addr);
udp_sock.connect(addr).await.unwrap(); udp_sock.connect(addr).await.unwrap();
loop { loop {
tokio::select! { tokio::select! {
Ok(size) = udp_sock.recv(&mut buf_udp) => { Ok(size) = udp_sock.recv(&mut buf_udp) => {
if sock.send(&buf_udp[..size]).await.is_none() { if sock.send(&buf_udp[..size]).await.is_none() {
debug!("removed fake TCP socket from connections table");
quit.cancel();
return;
}
packet_received.notify_one();
},
res = sock.recv(&mut buf_tcp) => {
match res {
Some(size) => {
if size > 0 {
if let Err(e) = udp_sock.send(&buf_tcp[..size]).await {
error!("Unable to send UDP packet to {}: {}, closing connection", e, addr);
quit.cancel();
return;
}
}
},
None => {
debug!("removed fake TCP socket from connections table");
quit.cancel();
return;
},
}
packet_received.notify_one();
},
_ = quit.cancelled() => {
debug!("worker {} terminated", i);
return;
},
};
}
});
}
let connections = connections.clone();
tokio::spawn(async move {
loop {
let read_timeout = time::sleep(UDP_TTL);
let packet_received_fut = packet_received.notified();
tokio::select! {
_ = read_timeout => {
info!("No traffic seen in the last {:?}, closing connection", UDP_TTL);
connections.write().await.remove(&addr);
debug!("removed fake TCP socket from connections table"); debug!("removed fake TCP socket from connections table");
quit.cancel(); quit.cancel();
return; return;
}, }
_ = quit.cancelled() => {
connections.write().await.remove(&addr); packet_received.notify_one();
debug!("removed fake TCP socket from connections table"); },
return; res = sock.recv(&mut buf_tcp) => {
}, match res {
_ = packet_received_fut => {}, Some(size) => {
} if size > 0 {
} if let Err(e) = udp_sock.send(&buf_tcp[..size]).await {
}); error!("Unable to send UDP packet to {}: {}, closing connection", e, addr);
}, quit.cancel();
return;
}
}
},
None => {
debug!("removed fake TCP socket from connections table");
quit.cancel();
return;
},
}
packet_received.notify_one();
},
_ = quit.cancelled() => {
debug!("worker {} terminated", i);
return;
},
};
}
});
} }
let connections = connections.clone();
tokio::spawn(async move {
loop {
let read_timeout = time::sleep(UDP_TTL);
let packet_received_fut = packet_received.notified();
tokio::select! {
_ = read_timeout => {
info!("No traffic seen in the last {:?}, closing connection", UDP_TTL);
connections.write().await.remove(&addr);
debug!("removed fake TCP socket from connections table");
quit.cancel();
return;
},
_ = quit.cancelled() => {
connections.write().await.remove(&addr);
debug!("removed fake TCP socket from connections table");
return;
},
_ = packet_received_fut => {},
}
}
});
} }
}); });
tokio::join!(main_loop).0.unwrap(); tokio::join!(main_loop).0.unwrap()
} }

View File

@ -3,6 +3,7 @@ use fake_tcp::packet::MAX_PACKET_LEN;
use fake_tcp::Stack; use fake_tcp::Stack;
use log::{debug, error, info}; use log::{debug, error, info};
use phantun::utils::new_udp_reuseport; use phantun::utils::new_udp_reuseport;
use std::io;
use std::net::Ipv4Addr; use std::net::Ipv4Addr;
use std::sync::Arc; use std::sync::Arc;
use tokio::net::UdpSocket; use tokio::net::UdpSocket;
@ -14,7 +15,7 @@ use tokio_util::sync::CancellationToken;
use phantun::UDP_TTL; use phantun::UDP_TTL;
#[tokio::main] #[tokio::main]
async fn main() { async fn main() -> io::Result<()> {
pretty_env_logger::init(); pretty_env_logger::init();
let matches = Command::new("Phantun Server") let matches = Command::new("Phantun Server")
@ -128,9 +129,8 @@ async fn main() {
} else { } else {
"[::]:0" "[::]:0"
}) })
.await .await?;
.unwrap(); let local_addr = udp_sock.local_addr()?;
let local_addr = udp_sock.local_addr().unwrap();
drop(udp_sock); drop(udp_sock);
for i in 0..num_cpus { for i in 0..num_cpus {
@ -199,5 +199,5 @@ async fn main() {
} }
}); });
tokio::join!(main_loop).0.unwrap(); tokio::join!(main_loop).0.unwrap()
} }