mirror of
https://github.com/dndx/phantun.git
synced 2025-01-31 20:29:31 +08:00
perf(locks) use RwLocks in place of Mutex when possible
This commit is contained in:
parent
5866cbe512
commit
ae52531288
@ -7,7 +7,7 @@ use std::net::{SocketAddr, SocketAddrV4};
|
|||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
use tokio::net::UdpSocket;
|
use tokio::net::UdpSocket;
|
||||||
use tokio::sync::Mutex;
|
use tokio::sync::RwLock;
|
||||||
use tokio::time;
|
use tokio::time;
|
||||||
use tokio_tun::TunBuilder;
|
use tokio_tun::TunBuilder;
|
||||||
|
|
||||||
@ -64,7 +64,7 @@ async fn main() {
|
|||||||
info!("Created TUN device {}", tun.name());
|
info!("Created TUN device {}", tun.name());
|
||||||
|
|
||||||
let udp_sock = Arc::new(UdpSocket::bind(local_addr).await.unwrap());
|
let udp_sock = Arc::new(UdpSocket::bind(local_addr).await.unwrap());
|
||||||
let connections = Arc::new(Mutex::new(
|
let connections = Arc::new(RwLock::new(
|
||||||
LruCache::<SocketAddrV4, Arc<Socket>>::with_expiry_duration(UDP_TTL),
|
LruCache::<SocketAddrV4, Arc<Socket>>::with_expiry_duration(UDP_TTL),
|
||||||
));
|
));
|
||||||
|
|
||||||
@ -77,7 +77,7 @@ async fn main() {
|
|||||||
loop {
|
loop {
|
||||||
tokio::select! {
|
tokio::select! {
|
||||||
Ok((size, SocketAddr::V4(addr))) = udp_sock.recv_from(&mut buf_r) => {
|
Ok((size, SocketAddr::V4(addr))) = udp_sock.recv_from(&mut buf_r) => {
|
||||||
if let Some(sock) = connections.lock().await.get_mut(&addr) {
|
if let Some(sock) = connections.read().await.peek(&addr) {
|
||||||
sock.send(&buf_r[..size]).await;
|
sock.send(&buf_r[..size]).await;
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
@ -95,7 +95,7 @@ async fn main() {
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
assert!(connections.lock().await.insert(addr, sock.clone()).is_none());
|
assert!(connections.write().await.insert(addr, sock.clone()).is_none());
|
||||||
debug!("inserted fake TCP socket into LruCache");
|
debug!("inserted fake TCP socket into LruCache");
|
||||||
let udp_sock = udp_sock.clone();
|
let udp_sock = udp_sock.clone();
|
||||||
|
|
||||||
@ -108,7 +108,7 @@ async fn main() {
|
|||||||
udp_sock.send_to(&buf_r[..size], addr).await.unwrap();
|
udp_sock.send_to(&buf_r[..size], addr).await.unwrap();
|
||||||
},
|
},
|
||||||
None => {
|
None => {
|
||||||
connections.lock().await.remove(&addr);
|
connections.write().await.remove(&addr);
|
||||||
debug!("removed fake TCP socket from LruCache");
|
debug!("removed fake TCP socket from LruCache");
|
||||||
return;
|
return;
|
||||||
},
|
},
|
||||||
@ -119,7 +119,7 @@ async fn main() {
|
|||||||
_ = cleanup_timer.tick() => {
|
_ = cleanup_timer.tick() => {
|
||||||
let mut total = 0;
|
let mut total = 0;
|
||||||
|
|
||||||
for c in connections.lock().await.notify_iter() {
|
for c in connections.write().await.notify_iter() {
|
||||||
if let TimedEntry::Expired(_addr, sock) = c {
|
if let TimedEntry::Expired(_addr, sock) = c {
|
||||||
sock.close();
|
sock.close();
|
||||||
total += 1;
|
total += 1;
|
||||||
|
@ -9,7 +9,7 @@ use std::collections::{HashMap, HashSet};
|
|||||||
use std::fmt;
|
use std::fmt;
|
||||||
use std::net::{Ipv4Addr, SocketAddrV4};
|
use std::net::{Ipv4Addr, SocketAddrV4};
|
||||||
use std::sync::atomic::{AtomicU32, Ordering};
|
use std::sync::atomic::{AtomicU32, Ordering};
|
||||||
use std::sync::{Arc, Mutex};
|
use std::sync::{Arc, RwLock};
|
||||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||||
use tokio::sync::mpsc::{self, Receiver, Sender};
|
use tokio::sync::mpsc::{self, Receiver, Sender};
|
||||||
use tokio::sync::watch;
|
use tokio::sync::watch;
|
||||||
@ -38,8 +38,8 @@ impl AddrTuple {
|
|||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
struct Shared {
|
struct Shared {
|
||||||
tuples: Mutex<HashMap<AddrTuple, Arc<Sender<Bytes>>>>,
|
tuples: RwLock<HashMap<AddrTuple, Arc<Sender<Bytes>>>>,
|
||||||
listening: Mutex<HashSet<u16>>,
|
listening: RwLock<HashSet<u16>>,
|
||||||
outgoing: Sender<Bytes>,
|
outgoing: Sender<Bytes>,
|
||||||
ready: Sender<Socket>,
|
ready: Sender<Socket>,
|
||||||
}
|
}
|
||||||
@ -279,7 +279,7 @@ impl Drop for Socket {
|
|||||||
assert!(self
|
assert!(self
|
||||||
.shared
|
.shared
|
||||||
.tuples
|
.tuples
|
||||||
.lock()
|
.write()
|
||||||
.unwrap()
|
.unwrap()
|
||||||
.remove(&AddrTuple::new(self.local_addr, self.remote_addr))
|
.remove(&AddrTuple::new(self.local_addr, self.remote_addr))
|
||||||
.is_some());
|
.is_some());
|
||||||
@ -306,9 +306,9 @@ impl Stack {
|
|||||||
let (outgoing_tx, outgoing_rx) = mpsc::channel(MPSC_BUFFER_LEN);
|
let (outgoing_tx, outgoing_rx) = mpsc::channel(MPSC_BUFFER_LEN);
|
||||||
let (ready_tx, ready_rx) = mpsc::channel(MPSC_BUFFER_LEN);
|
let (ready_tx, ready_rx) = mpsc::channel(MPSC_BUFFER_LEN);
|
||||||
let shared = Arc::new(Shared {
|
let shared = Arc::new(Shared {
|
||||||
tuples: Mutex::new(HashMap::new()),
|
tuples: RwLock::new(HashMap::new()),
|
||||||
outgoing: outgoing_tx,
|
outgoing: outgoing_tx,
|
||||||
listening: Mutex::new(HashSet::new()),
|
listening: RwLock::new(HashSet::new()),
|
||||||
ready: ready_tx,
|
ready: ready_tx,
|
||||||
});
|
});
|
||||||
let local_ip = tun.destination().unwrap();
|
let local_ip = tun.destination().unwrap();
|
||||||
@ -322,7 +322,7 @@ impl Stack {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn listen(&mut self, port: u16) {
|
pub fn listen(&mut self, port: u16) {
|
||||||
assert!(self.shared.listening.lock().unwrap().insert(port));
|
assert!(self.shared.listening.write().unwrap().insert(port));
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn accept(&mut self) -> Socket {
|
pub async fn accept(&mut self) -> Socket {
|
||||||
@ -344,7 +344,7 @@ impl Stack {
|
|||||||
);
|
);
|
||||||
|
|
||||||
{
|
{
|
||||||
let mut tuples = self.shared.tuples.lock().unwrap();
|
let mut tuples = self.shared.tuples.write().unwrap();
|
||||||
assert!(tuples.insert(tuple, Arc::new(incoming.clone())).is_none());
|
assert!(tuples.insert(tuple, Arc::new(incoming.clone())).is_none());
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -378,7 +378,7 @@ impl Stack {
|
|||||||
|
|
||||||
let sender;
|
let sender;
|
||||||
{
|
{
|
||||||
let tuples = shared.tuples.lock().unwrap();
|
let tuples = shared.tuples.read().unwrap();
|
||||||
sender = tuples.get(&tuple).cloned();
|
sender = tuples.get(&tuple).cloned();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -387,11 +387,11 @@ impl Stack {
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
if tcp_packet.get_flags() == tcp::TcpFlags::SYN && shared.listening.lock().unwrap().contains(&tcp_packet.get_destination()) {
|
if tcp_packet.get_flags() == tcp::TcpFlags::SYN && shared.listening.read().unwrap().contains(&tcp_packet.get_destination()) {
|
||||||
// SYN seen on listening socket
|
// SYN seen on listening socket
|
||||||
if tcp_packet.get_sequence() == 0 {
|
if tcp_packet.get_sequence() == 0 {
|
||||||
let (sock, incoming) = Socket::new(Mode::Server, shared.clone(), local_addr, remote_addr, Some(tcp_packet.get_sequence() + 1), State::Idle);
|
let (sock, incoming) = Socket::new(Mode::Server, shared.clone(), local_addr, remote_addr, Some(tcp_packet.get_sequence() + 1), State::Idle);
|
||||||
assert!(shared.tuples.lock().unwrap().insert(tuple, Arc::new(incoming)).is_none());
|
assert!(shared.tuples.write().unwrap().insert(tuple, Arc::new(incoming)).is_none());
|
||||||
tokio::spawn(sock.accept());
|
tokio::spawn(sock.accept());
|
||||||
} else {
|
} else {
|
||||||
trace!("Bad TCP SYN packet from {}, sending RST", remote_addr);
|
trace!("Bad TCP SYN packet from {}, sending RST", remote_addr);
|
||||||
|
Loading…
x
Reference in New Issue
Block a user