diff --git a/easytier/src/peers/peer_conn.rs b/easytier/src/peers/peer_conn.rs index 1110a725..e8eca284 100644 --- a/easytier/src/peers/peer_conn.rs +++ b/easytier/src/peers/peer_conn.rs @@ -42,13 +42,22 @@ pub type PeerConnId = uuid::Uuid; macro_rules! wait_response { ($stream: ident, $out_var:ident, $pattern:pat_param => $value:expr) => { - let rsp_vec = timeout(Duration::from_secs(1), $stream.next()).await; - if rsp_vec.is_err() { + let Ok(rsp_vec) = timeout(Duration::from_secs(1), $stream.next()).await else { return Err(TunnelError::WaitRespError( "wait handshake response timeout".to_owned(), )); - } - let rsp_vec = rsp_vec.unwrap().unwrap()?; + }; + let Some(rsp_vec) = rsp_vec else { + return Err(TunnelError::WaitRespError( + "wait handshake response get none".to_owned(), + )); + }; + let Ok(rsp_vec) = rsp_vec else { + return Err(TunnelError::WaitRespError(format!( + "wait handshake response get error {}", + rsp_vec.err().unwrap() + ))); + }; let $out_var; let rsp_bytes = Packet::decode(&rsp_vec); diff --git a/easytier/src/tunnels/ring_tunnel.rs b/easytier/src/tunnels/ring_tunnel.rs index 2d037bbb..83f85fc9 100644 --- a/easytier/src/tunnels/ring_tunnel.rs +++ b/easytier/src/tunnels/ring_tunnel.rs @@ -1,6 +1,9 @@ use std::{ collections::HashMap, - sync::{atomic::AtomicBool, Arc}, + sync::{ + atomic::{AtomicBool, AtomicU32}, + Arc, + }, task::Poll, }; @@ -22,28 +25,55 @@ use uuid::Uuid; use crate::tunnels::{SinkError, SinkItem}; use super::{ - build_url_from_socket_addr, check_scheme_and_get_socket_addr, DatagramSink, DatagramStream, - Tunnel, TunnelConnector, TunnelError, TunnelInfo, TunnelListener, + build_url_from_socket_addr, check_scheme_and_get_socket_addr, common::FramedTunnel, + DatagramSink, DatagramStream, Tunnel, TunnelConnector, TunnelError, TunnelInfo, TunnelListener, }; static RING_TUNNEL_CAP: usize = 1000; +struct Ring { + id: Uuid, + ring: ArrayQueue, + consume_notify: Notify, + produce_notify: Notify, + closed: AtomicBool, +} + +impl Ring { + fn new(cap: usize, id: uuid::Uuid) -> Self { + Self { + id, + ring: ArrayQueue::new(cap), + consume_notify: Notify::new(), + produce_notify: Notify::new(), + closed: AtomicBool::new(false), + } + } + + fn close(&self) { + self.closed + .store(true, std::sync::atomic::Ordering::Relaxed); + self.produce_notify.notify_one(); + } + + fn closed(&self) -> bool { + self.closed.load(std::sync::atomic::Ordering::Relaxed) + } +} + pub struct RingTunnel { id: Uuid, - ring: Arc>, - consume_notify: Arc, - produce_notify: Arc, - closed: Arc, + ring: Arc, + sender_counter: Arc, } impl RingTunnel { pub fn new(cap: usize) -> Self { + let id = Uuid::new_v4(); RingTunnel { - id: Uuid::new_v4(), - ring: Arc::new(ArrayQueue::new(cap)), - consume_notify: Arc::new(Notify::new()), - produce_notify: Arc::new(Notify::new()), - closed: Arc::new(AtomicBool::new(false)), + id: id.clone(), + ring: Arc::new(Ring::new(cap, id)), + sender_counter: Arc::new(AtomicU32::new(1)), } } @@ -55,27 +85,24 @@ impl RingTunnel { fn recv_stream(&self) -> impl DatagramStream { let ring = self.ring.clone(); - let produce_notify = self.produce_notify.clone(); - let consume_notify = self.consume_notify.clone(); - let closed = self.closed.clone(); let id = self.id; stream! { loop { - if closed.load(std::sync::atomic::Ordering::Relaxed) { - log::warn!("ring recv tunnel {:?} closed", id); - yield Err(TunnelError::CommonError("Closed".to_owned())); - } - match ring.pop() { + match ring.ring.pop() { Some(v) => { let mut out = BytesMut::new(); out.extend_from_slice(&v); - consume_notify.notify_one(); + ring.consume_notify.notify_one(); log::trace!("id: {}, recv buffer, len: {:?}, buf: {:?}", id, v.len(), &v); yield Ok(out); }, None => { + if ring.closed() { + log::warn!("ring recv tunnel {:?} closed", id); + yield Err(TunnelError::CommonError("ring closed".to_owned())); + } log::trace!("waiting recv buffer, id: {}", id); - produce_notify.notified().await; + ring.produce_notify.notified().await; } } } @@ -84,18 +111,13 @@ impl RingTunnel { fn send_sink(&self) -> impl DatagramSink { let ring = self.ring.clone(); - let produce_notify = self.produce_notify.clone(); - let consume_notify = self.consume_notify.clone(); - let closed = self.closed.clone(); - let id = self.id; - - // type T = RingTunnel; - + let sender_counter = self.sender_counter.clone(); use tokio::task::JoinHandle; struct T { - ring: RingTunnel, + ring: Arc, wait_consume_task: Option>, + sender_counter: Arc, } impl T { @@ -110,16 +132,15 @@ impl RingTunnel { } if self_mut.wait_consume_task.is_none() { let id = self_mut.ring.id; - let consume_notify = self_mut.ring.consume_notify.clone(); - let ring = self_mut.ring.ring.clone(); + let ring = self_mut.ring.clone(); let task = async move { log::trace!( "waiting ring consume done, expected_size: {}, id: {}", expected_size, id ); - while ring.len() > expected_size { - consume_notify.notified().await; + while ring.ring.len() > expected_size { + ring.consume_notify.notified().await; } log::trace!( "ring consume done, expected_size: {}, id: {}", @@ -147,6 +168,12 @@ impl RingTunnel { self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, ) -> std::task::Poll> { + if self.ring.closed() { + return Poll::Ready(Err(TunnelError::CommonError( + "ring closed during ready".to_owned(), + ) + .into())); + } let expected_size = self.ring.ring.capacity() - 1; match self.wait_ring_consume(cx, expected_size) { Poll::Ready(_) => Poll::Ready(Ok(())), @@ -158,6 +185,11 @@ impl RingTunnel { self: std::pin::Pin<&mut Self>, item: SinkItem, ) -> Result<(), Self::Error> { + if self.ring.closed() { + return Err( + TunnelError::CommonError("ring closed during send".to_owned()).into(), + ); + } log::trace!("id: {}, send buffer, buf: {:?}", self.ring.id, &item); self.ring.ring.push(item).unwrap(); self.ring.produce_notify.notify_one(); @@ -168,6 +200,12 @@ impl RingTunnel { self: std::pin::Pin<&mut Self>, _cx: &mut std::task::Context<'_>, ) -> std::task::Poll> { + if self.ring.closed() { + return Poll::Ready(Err(TunnelError::CommonError( + "ring closed during flush".to_owned(), + ) + .into())); + } Poll::Ready(Ok(())) } @@ -175,24 +213,38 @@ impl RingTunnel { self: std::pin::Pin<&mut Self>, _cx: &mut std::task::Context<'_>, ) -> std::task::Poll> { - self.ring - .closed - .store(true, std::sync::atomic::Ordering::Relaxed); - log::warn!("ring tunnel send {:?} closed", self.ring.id); - self.ring.produce_notify.notify_one(); + self.ring.close(); Poll::Ready(Ok(())) } } + impl Drop for T { + fn drop(&mut self) { + let rem = self + .sender_counter + .fetch_sub(1, std::sync::atomic::Ordering::Relaxed); + if rem == 1 { + self.ring.close() + } + } + } + + sender_counter.fetch_add(1, std::sync::atomic::Ordering::Relaxed); T { - ring: RingTunnel { - id, - ring, - consume_notify, - produce_notify, - closed, - }, + ring, wait_consume_task: None, + sender_counter, + } + } +} + +impl Drop for RingTunnel { + fn drop(&mut self) { + let rem = self + .sender_counter + .fetch_sub(1, std::sync::atomic::Ordering::Relaxed); + if rem == 1 { + self.ring.close() } } } @@ -213,11 +265,6 @@ impl Tunnel for RingTunnel { fn info(&self) -> Option { None - // Some(TunnelInfo { - // tunnel_type: "ring".to_owned(), - // local_addr: format!("ring://{}", self.id), - // remote_addr: format!("ring://{}", self.id), - // }) } } @@ -241,48 +288,29 @@ impl RingTunnelListener { } } } -struct ConnectionForServer { - conn: Arc, -} -impl Tunnel for ConnectionForServer { - fn stream(&self) -> Box { - Box::new(self.conn.server.recv_stream()) - } - - fn sink(&self) -> Box { - Box::new(self.conn.client.send_sink()) - } - - fn info(&self) -> Option { - Some(TunnelInfo { +fn get_tunnel_for_client(conn: Arc) -> Box { + FramedTunnel::new_tunnel_with_info( + Box::pin(conn.client.recv_stream()), + conn.server.send_sink(), + TunnelInfo { tunnel_type: "ring".to_owned(), - local_addr: build_url_from_socket_addr(&self.conn.server.id.into(), "ring").into(), - remote_addr: build_url_from_socket_addr(&self.conn.client.id.into(), "ring").into(), - }) - } + local_addr: build_url_from_socket_addr(&conn.client.id.into(), "ring").into(), + remote_addr: build_url_from_socket_addr(&conn.server.id.into(), "ring").into(), + }, + ) } -struct ConnectionForClient { - conn: Arc, -} - -impl Tunnel for ConnectionForClient { - fn stream(&self) -> Box { - Box::new(self.conn.client.recv_stream()) - } - - fn sink(&self) -> Box { - Box::new(self.conn.server.send_sink()) - } - - fn info(&self) -> Option { - Some(TunnelInfo { +fn get_tunnel_for_server(conn: Arc) -> Box { + FramedTunnel::new_tunnel_with_info( + Box::pin(conn.server.recv_stream()), + conn.client.send_sink(), + TunnelInfo { tunnel_type: "ring".to_owned(), - local_addr: build_url_from_socket_addr(&self.conn.client.id.into(), "ring").into(), - remote_addr: build_url_from_socket_addr(&self.conn.server.id.into(), "ring").into(), - }) - } + local_addr: build_url_from_socket_addr(&conn.server.id.into(), "ring").into(), + remote_addr: build_url_from_socket_addr(&conn.client.id.into(), "ring").into(), + }, + ) } impl RingTunnelListener { @@ -308,7 +336,7 @@ impl TunnelListener for RingTunnelListener { if let Some(conn) = self.conn_receiver.recv().await { if conn.server.id == my_addr { log::info!("accept new conn of key: {}", self.listerner_addr); - return Ok(Box::new(ConnectionForServer { conn })); + return Ok(get_tunnel_for_server(conn)); } else { tracing::error!(?conn.server.id, ?my_addr, "got new conn with wrong id"); return Err(TunnelError::CommonError( @@ -353,7 +381,7 @@ impl TunnelConnector for RingTunnelConnector { entry .send(conn.clone()) .map_err(|_| TunnelError::CommonError("send conn to listner failed".to_owned()))?; - Ok(Box::new(ConnectionForClient { conn })) + Ok(get_tunnel_for_client(conn)) } fn remote_url(&self) -> url::Url { @@ -367,13 +395,15 @@ pub fn create_ring_tunnel_pair() -> (Box, Box) { server: RingTunnel::new(RING_TUNNEL_CAP), }); ( - Box::new(ConnectionForServer { conn: conn.clone() }), - Box::new(ConnectionForClient { conn }), + Box::new(get_tunnel_for_server(conn.clone())), + Box::new(get_tunnel_for_client(conn)), ) } #[cfg(test)] mod tests { + use futures::StreamExt; + use crate::tunnels::common::tests::{_tunnel_bench, _tunnel_pingpong}; use super::*; @@ -393,4 +423,14 @@ mod tests { let connector = RingTunnelConnector::new(id); _tunnel_bench(listener, connector).await } + + #[tokio::test] + async fn ring_close() { + let (stunnel, ctunnel) = create_ring_tunnel_pair(); + drop(stunnel); + + let mut stream = ctunnel.pin_stream(); + let ret = stream.next().await; + assert!(ret.as_ref().unwrap().is_err(), "expect Err, got {:?}", ret); + } } diff --git a/easytier/src/tunnels/wireguard.rs b/easytier/src/tunnels/wireguard.rs index b09225c5..c6e79c6a 100644 --- a/easytier/src/tunnels/wireguard.rs +++ b/easytier/src/tunnels/wireguard.rs @@ -4,7 +4,7 @@ use std::{ hash::Hasher, net::SocketAddr, pin::Pin, - sync::Arc, + sync::{atomic::AtomicBool, Arc}, time::Duration, }; @@ -120,6 +120,7 @@ struct WgPeerData { sink: Arc>>>, stream: Arc>>>, wg_type: WgType, + stopped: Arc, } impl Debug for WgPeerData { @@ -366,6 +367,8 @@ impl WgPeer { tracing::error!("Failed to handle packet from me: {}", e); } } + data.stopped + .store(true, std::sync::atomic::Ordering::Relaxed); } async fn handle_packet_from_peer(&mut self, packet: &[u8]) { @@ -395,6 +398,7 @@ impl WgPeer { sink: Arc::new(Mutex::new(stunnel.pin_sink())), stream: Arc::new(Mutex::new(stunnel.pin_stream())), wg_type: self.config.wg_type.clone(), + stopped: Arc::new(AtomicBool::new(false)), }; self.data = Some(data.clone()); @@ -403,6 +407,14 @@ impl WgPeer { ctunnel } + + fn stopped(&self) -> bool { + self.data + .as_ref() + .unwrap() + .stopped + .load(std::sync::atomic::Ordering::Relaxed) + } } impl Drop for WgPeer { @@ -427,6 +439,8 @@ pub struct WgTunnelListener { conn_recv: ConnReceiver, conn_send: Option, + wg_peer_map: Arc>, + tasks: JoinSet<()>, } @@ -441,6 +455,8 @@ impl WgTunnelListener { conn_recv, conn_send: Some(conn_send), + wg_peer_map: Arc::new(DashMap::new()), + tasks: JoinSet::new(), } } @@ -453,15 +469,16 @@ impl WgTunnelListener { socket: Arc, config: WgConfig, conn_sender: ConnSender, + peer_map: Arc>, ) { let mut tasks = JoinSet::new(); - let peer_map: Arc> = Arc::new(DashMap::new()); let peer_map_clone = peer_map.clone(); tasks.spawn(async move { loop { - peer_map_clone.retain(|_, peer| peer.access_time.elapsed().as_secs() < 600); - tokio::time::sleep(Duration::from_secs(60)).await; + peer_map_clone + .retain(|_, peer| peer.access_time.elapsed().as_secs() < 61 && !peer.stopped()); + tokio::time::sleep(Duration::from_secs(1)).await; } }); @@ -524,6 +541,7 @@ impl TunnelListener for WgTunnelListener { self.get_udp_socket(), self.config.clone(), self.conn_send.take().unwrap(), + self.wg_peer_map.clone(), )); Ok(()) @@ -788,4 +806,36 @@ pub mod tests { connector.set_bind_addrs(vec!["10.0.0.1:0".parse().unwrap()]); _tunnel_pingpong(listener, connector).await } + + #[tokio::test] + async fn wg_server_erase_from_map_after_close() { + let (server_cfg, client_cfg) = create_wg_config(); + let mut listener = + WgTunnelListener::new("wg://127.0.0.1:5595".parse().unwrap(), server_cfg); + listener.listen().await.unwrap(); + + const CONN_COUNT: usize = 10; + + tokio::spawn(async move { + for _ in 0..CONN_COUNT { + let mut connector = WgTunnelConnector::new( + "wg://127.0.0.1:5595".parse().unwrap(), + client_cfg.clone(), + ); + let ret = connector.connect().await; + assert!(ret.is_ok()); + drop(ret); + } + }); + + for _ in 0..CONN_COUNT { + let conn = listener.accept().await; + assert!(conn.is_ok()); + drop(conn); + } + + tokio::time::sleep(tokio::time::Duration::from_secs(2)).await; + + assert_eq!(0, listener.wg_peer_map.len()); + } }