From 74183071f1ec2a8076a2ba60c16be570a3a880b8 Mon Sep 17 00:00:00 2001 From: Datong Sun Date: Fri, 15 Apr 2022 07:58:16 -0700 Subject: [PATCH] style(phantun) remove unnecessary `tokio::select` call --- phantun/src/bin/client.rs | 204 +++++++++++++++++++------------------- phantun/src/bin/server.rs | 10 +- 2 files changed, 108 insertions(+), 106 deletions(-) diff --git a/phantun/src/bin/client.rs b/phantun/src/bin/client.rs index ef4bf22..816e6da 100644 --- a/phantun/src/bin/client.rs +++ b/phantun/src/bin/client.rs @@ -4,6 +4,7 @@ use fake_tcp::{Socket, Stack}; use log::{debug, error, info}; use phantun::utils::new_udp_reuseport; use std::collections::HashMap; +use std::io; use std::net::{Ipv4Addr, SocketAddr}; use std::sync::Arc; use tokio::sync::{Notify, RwLock}; @@ -14,7 +15,7 @@ use tokio_util::sync::CancellationToken; use phantun::UDP_TTL; #[tokio::main] -async fn main() { +async fn main() -> io::Result<()> { pretty_env_logger::init(); let matches = Command::new("Phantun Client") @@ -122,119 +123,120 @@ async fn main() { let mut buf_r = [0u8; MAX_PACKET_LEN]; loop { - tokio::select! { - Ok((size, addr)) = udp_sock.recv_from(&mut buf_r) => { - // seen UDP packet to listening socket, this means: - // 1. It is a new UDP connection, or - // 2. It is some extra packets not filtered by more specific - // connected UDP socket yet - if let Some(sock) = connections.read().await.get(&addr) { - sock.send(&buf_r[..size]).await; - continue; - } + let (size, addr) = udp_sock.recv_from(&mut buf_r).await?; + // seen UDP packet to listening socket, this means: + // 1. It is a new UDP connection, or + // 2. It is some extra packets not filtered by more specific + // connected UDP socket yet + if let Some(sock) = connections.read().await.get(&addr) { + sock.send(&buf_r[..size]).await; + continue; + } - info!("New UDP client from {}", addr); - let sock = stack.connect(remote_addr).await; - if sock.is_none() { - error!("Unable to connect to remote {}", remote_addr); - continue; - } + info!("New UDP client from {}", addr); + let sock = stack.connect(remote_addr).await; + if sock.is_none() { + error!("Unable to connect to remote {}", remote_addr); + continue; + } - let sock = Arc::new(sock.unwrap()); - // send first packet - let res = sock.send(&buf_r[..size]).await; - if res.is_none() { - continue; - } + let sock = Arc::new(sock.unwrap()); + // send first packet + let res = sock.send(&buf_r[..size]).await; + if res.is_none() { + continue; + } - assert!(connections.write().await.insert(addr, sock.clone()).is_none()); - debug!("inserted fake TCP socket into connection table"); + assert!(connections + .write() + .await + .insert(addr, sock.clone()) + .is_none()); + debug!("inserted fake TCP socket into connection table"); - // spawn "fastpath" UDP socket and task, this will offload main task - // from forwarding UDP packets + // spawn "fastpath" UDP socket and task, this will offload main task + // from forwarding UDP packets - let packet_received = Arc::new(Notify::new()); - let quit = CancellationToken::new(); + let packet_received = Arc::new(Notify::new()); + let quit = CancellationToken::new(); - for i in 0..num_cpus { - let sock = sock.clone(); - let quit = quit.clone(); - let packet_received = packet_received.clone(); + for i in 0..num_cpus { + let sock = sock.clone(); + let quit = quit.clone(); + let packet_received = packet_received.clone(); - tokio::spawn(async move { - let mut buf_udp = [0u8; MAX_PACKET_LEN]; - let mut buf_tcp = [0u8; MAX_PACKET_LEN]; - let udp_sock = new_udp_reuseport(local_addr); - udp_sock.connect(addr).await.unwrap(); + tokio::spawn(async move { + let mut buf_udp = [0u8; MAX_PACKET_LEN]; + let mut buf_tcp = [0u8; MAX_PACKET_LEN]; + let udp_sock = new_udp_reuseport(local_addr); + udp_sock.connect(addr).await.unwrap(); - loop { - tokio::select! { - Ok(size) = udp_sock.recv(&mut buf_udp) => { - if sock.send(&buf_udp[..size]).await.is_none() { - debug!("removed fake TCP socket from connections table"); - quit.cancel(); - return; - } - - packet_received.notify_one(); - }, - res = sock.recv(&mut buf_tcp) => { - match res { - Some(size) => { - if size > 0 { - if let Err(e) = udp_sock.send(&buf_tcp[..size]).await { - error!("Unable to send UDP packet to {}: {}, closing connection", e, addr); - quit.cancel(); - return; - } - } - }, - None => { - debug!("removed fake TCP socket from connections table"); - quit.cancel(); - return; - }, - } - - packet_received.notify_one(); - }, - _ = quit.cancelled() => { - debug!("worker {} terminated", i); - return; - }, - }; - } - }); - } - - let connections = connections.clone(); - tokio::spawn(async move { - loop { - let read_timeout = time::sleep(UDP_TTL); - let packet_received_fut = packet_received.notified(); - - tokio::select! { - _ = read_timeout => { - info!("No traffic seen in the last {:?}, closing connection", UDP_TTL); - connections.write().await.remove(&addr); + loop { + tokio::select! { + Ok(size) = udp_sock.recv(&mut buf_udp) => { + if sock.send(&buf_udp[..size]).await.is_none() { debug!("removed fake TCP socket from connections table"); - quit.cancel(); return; - }, - _ = quit.cancelled() => { - connections.write().await.remove(&addr); - debug!("removed fake TCP socket from connections table"); - return; - }, - _ = packet_received_fut => {}, - } - } - }); - }, + } + + packet_received.notify_one(); + }, + res = sock.recv(&mut buf_tcp) => { + match res { + Some(size) => { + if size > 0 { + if let Err(e) = udp_sock.send(&buf_tcp[..size]).await { + error!("Unable to send UDP packet to {}: {}, closing connection", e, addr); + quit.cancel(); + return; + } + } + }, + None => { + debug!("removed fake TCP socket from connections table"); + quit.cancel(); + return; + }, + } + + packet_received.notify_one(); + }, + _ = quit.cancelled() => { + debug!("worker {} terminated", i); + return; + }, + }; + } + }); } + + let connections = connections.clone(); + tokio::spawn(async move { + loop { + let read_timeout = time::sleep(UDP_TTL); + let packet_received_fut = packet_received.notified(); + + tokio::select! { + _ = read_timeout => { + info!("No traffic seen in the last {:?}, closing connection", UDP_TTL); + connections.write().await.remove(&addr); + debug!("removed fake TCP socket from connections table"); + + quit.cancel(); + return; + }, + _ = quit.cancelled() => { + connections.write().await.remove(&addr); + debug!("removed fake TCP socket from connections table"); + return; + }, + _ = packet_received_fut => {}, + } + } + }); } }); - tokio::join!(main_loop).0.unwrap(); + tokio::join!(main_loop).0.unwrap() } diff --git a/phantun/src/bin/server.rs b/phantun/src/bin/server.rs index 4a32a24..5b8f4c6 100644 --- a/phantun/src/bin/server.rs +++ b/phantun/src/bin/server.rs @@ -3,6 +3,7 @@ use fake_tcp::packet::MAX_PACKET_LEN; use fake_tcp::Stack; use log::{debug, error, info}; use phantun::utils::new_udp_reuseport; +use std::io; use std::net::Ipv4Addr; use std::sync::Arc; use tokio::net::UdpSocket; @@ -14,7 +15,7 @@ use tokio_util::sync::CancellationToken; use phantun::UDP_TTL; #[tokio::main] -async fn main() { +async fn main() -> io::Result<()> { pretty_env_logger::init(); let matches = Command::new("Phantun Server") @@ -128,9 +129,8 @@ async fn main() { } else { "[::]:0" }) - .await - .unwrap(); - let local_addr = udp_sock.local_addr().unwrap(); + .await?; + let local_addr = udp_sock.local_addr()?; drop(udp_sock); for i in 0..num_cpus { @@ -199,5 +199,5 @@ async fn main() { } }); - tokio::join!(main_loop).0.unwrap(); + tokio::join!(main_loop).0.unwrap() }