perf(locks) use RwLocks in place of Mutex when possible

This commit is contained in:
Datong Sun 2021-09-17 06:26:25 -07:00 committed by Datong Sun
parent 5866cbe512
commit ae52531288
2 changed files with 17 additions and 17 deletions

View File

@ -7,7 +7,7 @@ use std::net::{SocketAddr, SocketAddrV4};
use std::sync::Arc;
use std::time::Duration;
use tokio::net::UdpSocket;
use tokio::sync::Mutex;
use tokio::sync::RwLock;
use tokio::time;
use tokio_tun::TunBuilder;
@ -64,7 +64,7 @@ async fn main() {
info!("Created TUN device {}", tun.name());
let udp_sock = Arc::new(UdpSocket::bind(local_addr).await.unwrap());
let connections = Arc::new(Mutex::new(
let connections = Arc::new(RwLock::new(
LruCache::<SocketAddrV4, Arc<Socket>>::with_expiry_duration(UDP_TTL),
));
@ -77,7 +77,7 @@ async fn main() {
loop {
tokio::select! {
Ok((size, SocketAddr::V4(addr))) = udp_sock.recv_from(&mut buf_r) => {
if let Some(sock) = connections.lock().await.get_mut(&addr) {
if let Some(sock) = connections.read().await.peek(&addr) {
sock.send(&buf_r[..size]).await;
continue;
}
@ -95,7 +95,7 @@ async fn main() {
continue;
}
assert!(connections.lock().await.insert(addr, sock.clone()).is_none());
assert!(connections.write().await.insert(addr, sock.clone()).is_none());
debug!("inserted fake TCP socket into LruCache");
let udp_sock = udp_sock.clone();
@ -108,7 +108,7 @@ async fn main() {
udp_sock.send_to(&buf_r[..size], addr).await.unwrap();
},
None => {
connections.lock().await.remove(&addr);
connections.write().await.remove(&addr);
debug!("removed fake TCP socket from LruCache");
return;
},
@ -119,7 +119,7 @@ async fn main() {
_ = cleanup_timer.tick() => {
let mut total = 0;
for c in connections.lock().await.notify_iter() {
for c in connections.write().await.notify_iter() {
if let TimedEntry::Expired(_addr, sock) = c {
sock.close();
total += 1;

View File

@ -9,7 +9,7 @@ use std::collections::{HashMap, HashSet};
use std::fmt;
use std::net::{Ipv4Addr, SocketAddrV4};
use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::{Arc, Mutex};
use std::sync::{Arc, RwLock};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::sync::mpsc::{self, Receiver, Sender};
use tokio::sync::watch;
@ -38,8 +38,8 @@ impl AddrTuple {
#[derive(Debug)]
struct Shared {
tuples: Mutex<HashMap<AddrTuple, Arc<Sender<Bytes>>>>,
listening: Mutex<HashSet<u16>>,
tuples: RwLock<HashMap<AddrTuple, Arc<Sender<Bytes>>>>,
listening: RwLock<HashSet<u16>>,
outgoing: Sender<Bytes>,
ready: Sender<Socket>,
}
@ -279,7 +279,7 @@ impl Drop for Socket {
assert!(self
.shared
.tuples
.lock()
.write()
.unwrap()
.remove(&AddrTuple::new(self.local_addr, self.remote_addr))
.is_some());
@ -306,9 +306,9 @@ impl Stack {
let (outgoing_tx, outgoing_rx) = mpsc::channel(MPSC_BUFFER_LEN);
let (ready_tx, ready_rx) = mpsc::channel(MPSC_BUFFER_LEN);
let shared = Arc::new(Shared {
tuples: Mutex::new(HashMap::new()),
tuples: RwLock::new(HashMap::new()),
outgoing: outgoing_tx,
listening: Mutex::new(HashSet::new()),
listening: RwLock::new(HashSet::new()),
ready: ready_tx,
});
let local_ip = tun.destination().unwrap();
@ -322,7 +322,7 @@ impl Stack {
}
pub fn listen(&mut self, port: u16) {
assert!(self.shared.listening.lock().unwrap().insert(port));
assert!(self.shared.listening.write().unwrap().insert(port));
}
pub async fn accept(&mut self) -> Socket {
@ -344,7 +344,7 @@ impl Stack {
);
{
let mut tuples = self.shared.tuples.lock().unwrap();
let mut tuples = self.shared.tuples.write().unwrap();
assert!(tuples.insert(tuple, Arc::new(incoming.clone())).is_none());
}
@ -378,7 +378,7 @@ impl Stack {
let sender;
{
let tuples = shared.tuples.lock().unwrap();
let tuples = shared.tuples.read().unwrap();
sender = tuples.get(&tuple).cloned();
}
@ -387,11 +387,11 @@ impl Stack {
continue;
}
if tcp_packet.get_flags() == tcp::TcpFlags::SYN && shared.listening.lock().unwrap().contains(&tcp_packet.get_destination()) {
if tcp_packet.get_flags() == tcp::TcpFlags::SYN && shared.listening.read().unwrap().contains(&tcp_packet.get_destination()) {
// SYN seen on listening socket
if tcp_packet.get_sequence() == 0 {
let (sock, incoming) = Socket::new(Mode::Server, shared.clone(), local_addr, remote_addr, Some(tcp_packet.get_sequence() + 1), State::Idle);
assert!(shared.tuples.lock().unwrap().insert(tuple, Arc::new(incoming)).is_none());
assert!(shared.tuples.write().unwrap().insert(tuple, Arc::new(incoming)).is_none());
tokio::spawn(sock.accept());
} else {
trace!("Bad TCP SYN packet from {}, sending RST", remote_addr);