diff --git a/src/bin/client.rs b/src/bin/client.rs index 5142042..9bd31e7 100644 --- a/src/bin/client.rs +++ b/src/bin/client.rs @@ -7,7 +7,7 @@ use std::net::{SocketAddr, SocketAddrV4}; use std::sync::Arc; use std::time::Duration; use tokio::net::UdpSocket; -use tokio::sync::Mutex; +use tokio::sync::RwLock; use tokio::time; use tokio_tun::TunBuilder; @@ -64,7 +64,7 @@ async fn main() { info!("Created TUN device {}", tun.name()); let udp_sock = Arc::new(UdpSocket::bind(local_addr).await.unwrap()); - let connections = Arc::new(Mutex::new( + let connections = Arc::new(RwLock::new( LruCache::>::with_expiry_duration(UDP_TTL), )); @@ -77,7 +77,7 @@ async fn main() { loop { tokio::select! { Ok((size, SocketAddr::V4(addr))) = udp_sock.recv_from(&mut buf_r) => { - if let Some(sock) = connections.lock().await.get_mut(&addr) { + if let Some(sock) = connections.read().await.peek(&addr) { sock.send(&buf_r[..size]).await; continue; } @@ -95,7 +95,7 @@ async fn main() { continue; } - assert!(connections.lock().await.insert(addr, sock.clone()).is_none()); + assert!(connections.write().await.insert(addr, sock.clone()).is_none()); debug!("inserted fake TCP socket into LruCache"); let udp_sock = udp_sock.clone(); @@ -108,7 +108,7 @@ async fn main() { udp_sock.send_to(&buf_r[..size], addr).await.unwrap(); }, None => { - connections.lock().await.remove(&addr); + connections.write().await.remove(&addr); debug!("removed fake TCP socket from LruCache"); return; }, @@ -119,7 +119,7 @@ async fn main() { _ = cleanup_timer.tick() => { let mut total = 0; - for c in connections.lock().await.notify_iter() { + for c in connections.write().await.notify_iter() { if let TimedEntry::Expired(_addr, sock) = c { sock.close(); total += 1; diff --git a/src/fake_tcp/mod.rs b/src/fake_tcp/mod.rs index 7d5ee93..609bdae 100644 --- a/src/fake_tcp/mod.rs +++ b/src/fake_tcp/mod.rs @@ -9,7 +9,7 @@ use std::collections::{HashMap, HashSet}; use std::fmt; use std::net::{Ipv4Addr, SocketAddrV4}; use std::sync::atomic::{AtomicU32, Ordering}; -use std::sync::{Arc, Mutex}; +use std::sync::{Arc, RwLock}; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::sync::mpsc::{self, Receiver, Sender}; use tokio::sync::watch; @@ -38,8 +38,8 @@ impl AddrTuple { #[derive(Debug)] struct Shared { - tuples: Mutex>>>, - listening: Mutex>, + tuples: RwLock>>>, + listening: RwLock>, outgoing: Sender, ready: Sender, } @@ -279,7 +279,7 @@ impl Drop for Socket { assert!(self .shared .tuples - .lock() + .write() .unwrap() .remove(&AddrTuple::new(self.local_addr, self.remote_addr)) .is_some()); @@ -306,9 +306,9 @@ impl Stack { let (outgoing_tx, outgoing_rx) = mpsc::channel(MPSC_BUFFER_LEN); let (ready_tx, ready_rx) = mpsc::channel(MPSC_BUFFER_LEN); let shared = Arc::new(Shared { - tuples: Mutex::new(HashMap::new()), + tuples: RwLock::new(HashMap::new()), outgoing: outgoing_tx, - listening: Mutex::new(HashSet::new()), + listening: RwLock::new(HashSet::new()), ready: ready_tx, }); let local_ip = tun.destination().unwrap(); @@ -322,7 +322,7 @@ impl Stack { } pub fn listen(&mut self, port: u16) { - assert!(self.shared.listening.lock().unwrap().insert(port)); + assert!(self.shared.listening.write().unwrap().insert(port)); } pub async fn accept(&mut self) -> Socket { @@ -344,7 +344,7 @@ impl Stack { ); { - let mut tuples = self.shared.tuples.lock().unwrap(); + let mut tuples = self.shared.tuples.write().unwrap(); assert!(tuples.insert(tuple, Arc::new(incoming.clone())).is_none()); } @@ -378,7 +378,7 @@ impl Stack { let sender; { - let tuples = shared.tuples.lock().unwrap(); + let tuples = shared.tuples.read().unwrap(); sender = tuples.get(&tuple).cloned(); } @@ -387,11 +387,11 @@ impl Stack { continue; } - if tcp_packet.get_flags() == tcp::TcpFlags::SYN && shared.listening.lock().unwrap().contains(&tcp_packet.get_destination()) { + if tcp_packet.get_flags() == tcp::TcpFlags::SYN && shared.listening.read().unwrap().contains(&tcp_packet.get_destination()) { // SYN seen on listening socket if tcp_packet.get_sequence() == 0 { let (sock, incoming) = Socket::new(Mode::Server, shared.clone(), local_addr, remote_addr, Some(tcp_packet.get_sequence() + 1), State::Idle); - assert!(shared.tuples.lock().unwrap().insert(tuple, Arc::new(incoming)).is_none()); + assert!(shared.tuples.write().unwrap().insert(tuple, Arc::new(incoming)).is_none()); tokio::spawn(sock.accept()); } else { trace!("Bad TCP SYN packet from {}, sending RST", remote_addr);