From 577cef131b92a4aada4374ad53ad52a344c03d71 Mon Sep 17 00:00:00 2001 From: "sijie.sun" Date: Sun, 28 Apr 2024 22:08:11 +0800 Subject: [PATCH] fix wireguard deadlock --- easytier/src/tunnel/wireguard.rs | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/easytier/src/tunnel/wireguard.rs b/easytier/src/tunnel/wireguard.rs index 5278d28f..c4dd2fce 100644 --- a/easytier/src/tunnel/wireguard.rs +++ b/easytier/src/tunnel/wireguard.rs @@ -14,6 +14,7 @@ use boringtun::{ x25519::{PublicKey, StaticSecret}, }; use bytes::BytesMut; +use crossbeam::atomic::AtomicCell; use dashmap::DashMap; use futures::{stream::FuturesUnordered, SinkExt, StreamExt}; use rand::RngCore; @@ -343,7 +344,7 @@ struct WgPeer { data: Option, tasks: JoinSet<()>, - access_time: std::time::Instant, + access_time: AtomicCell, } impl WgPeer { @@ -358,7 +359,7 @@ impl WgPeer { data: None, tasks: JoinSet::new(), - access_time: std::time::Instant::now(), + access_time: AtomicCell::new(std::time::Instant::now()), } } @@ -373,8 +374,8 @@ impl WgPeer { .store(true, std::sync::atomic::Ordering::Relaxed); } - async fn handle_packet_from_peer(&mut self, packet: &[u8]) { - self.access_time = std::time::Instant::now(); + async fn handle_packet_from_peer(&self, packet: &[u8]) { + self.access_time.store(std::time::Instant::now()); tracing::trace!("Received {} bytes from peer", packet.len()); let data = self.data.as_ref().unwrap(); // TODO: improve this @@ -436,7 +437,7 @@ pub struct WgTunnelListener { conn_recv: ConnReceiver, conn_send: Option, - wg_peer_map: Arc>, + wg_peer_map: Arc>>, tasks: JoinSet<()>, } @@ -466,15 +467,16 @@ impl WgTunnelListener { socket: Arc, config: WgConfig, conn_sender: ConnSender, - peer_map: Arc>, + peer_map: Arc>>, ) { let mut tasks = JoinSet::new(); let peer_map_clone = peer_map.clone(); tasks.spawn(async move { loop { - peer_map_clone - .retain(|_, peer| peer.access_time.elapsed().as_secs() < 61 && !peer.stopped()); + peer_map_clone.retain(|_, peer| { + peer.access_time.load().elapsed().as_secs() < 61 && !peer.stopped() + }); tokio::time::sleep(Duration::from_secs(1)).await; } }); @@ -509,10 +511,10 @@ impl WgTunnelListener { if let Err(e) = conn_sender.send(tunnel) { tracing::error!("Failed to send tunnel to conn_sender: {}", e); } - peer_map.insert(addr, wg); + peer_map.insert(addr, Arc::new(wg)); } - let mut peer = peer_map.get_mut(&addr).unwrap(); + let peer = peer_map.get(&addr).unwrap().clone(); peer.handle_packet_from_peer(data).await; } }