From 306817ae9a41aaf41fa11b654508cd9a4d3e1e72 Mon Sep 17 00:00:00 2001 From: "Sijie.Sun" Date: Thu, 9 Jan 2025 00:01:41 +0800 Subject: [PATCH] allow listener retry listen (#554) --- easytier/src/common/config.rs | 2 +- easytier/src/common/global_ctx.rs | 5 +- easytier/src/instance/listeners.rs | 223 ++++++++++++++++++++--------- 3 files changed, 157 insertions(+), 73 deletions(-) diff --git a/easytier/src/common/config.rs b/easytier/src/common/config.rs index 0e2eee13..d80d1965 100644 --- a/easytier/src/common/config.rs +++ b/easytier/src/common/config.rs @@ -27,7 +27,7 @@ pub fn gen_default_flags() -> Flags { relay_all_peer_rpc: false, disable_udp_hole_punching: false, ipv6_listener: "udp://[::]:0".to_string(), - multi_thread: false, + multi_thread: true, data_compress_algo: CompressionAlgoPb::None.into(), } } diff --git a/easytier/src/common/global_ctx.rs b/easytier/src/common/global_ctx.rs index 716aa78c..2a2e8c13 100644 --- a/easytier/src/common/global_ctx.rs +++ b/easytier/src/common/global_ctx.rs @@ -230,7 +230,10 @@ impl GlobalCtx { } pub fn add_running_listener(&self, url: url::Url) { - self.running_listeners.lock().unwrap().push(url); + let mut l = self.running_listeners.lock().unwrap(); + if !l.contains(&url) { + l.push(url); + } } pub fn get_vpn_portal_cidr(&self) -> Option { diff --git a/easytier/src/instance/listeners.rs b/easytier/src/instance/listeners.rs index 1bdfbd53..5750f2bb 100644 --- a/easytier/src/instance/listeners.rs +++ b/easytier/src/instance/listeners.rs @@ -1,8 +1,7 @@ use std::{fmt::Debug, sync::Arc}; -use anyhow::Context; use async_trait::async_trait; -use tokio::{sync::Mutex, task::JoinSet}; +use tokio::task::JoinSet; #[cfg(feature = "quic")] use crate::tunnel::quic::QUICTunnelListener; @@ -63,16 +62,20 @@ impl TunnelHandlerForListener for PeerManager { } } -#[derive(Debug, Clone)] -struct Listener { - inner: Arc>, +pub trait ListenerCreatorTrait: Fn() -> Box + Send + Sync {} +impl ListenerCreatorTrait for T where T: Fn() -> Box + Send {} +pub type ListenerCreator = Box; + +#[derive(Clone)] +struct ListenerFactory { + creator_fn: Arc, must_succ: bool, } pub struct ListenerManager { global_ctx: ArcGlobalCtx, net_ns: NetNS, - listeners: Vec, + listeners: Vec, peer_manager: Arc, tasks: JoinSet<()>, @@ -90,31 +93,39 @@ impl ListenerManage } pub async fn prepare_listeners(&mut self) -> Result<(), Error> { + let self_id = self.global_ctx.get_id(); self.add_listener( - RingTunnelListener::new( - format!("ring://{}", self.global_ctx.get_id()) - .parse() - .unwrap(), - ), + move || { + Box::new(RingTunnelListener::new( + format!("ring://{}", self_id).parse().unwrap(), + )) + }, true, ) .await?; for l in self.global_ctx.config.get_listener_uris().iter() { - let Ok(lis) = get_listener_by_url(l, self.global_ctx.clone()) else { + let l = l.clone(); + let Ok(_) = get_listener_by_url(&l, self.global_ctx.clone()) else { let msg = format!("failed to get listener by url: {}, maybe not supported", l); self.global_ctx .issue_event(GlobalCtxEvent::ListenerAddFailed(l.clone(), msg)); continue; }; - self.add_listener(lis, true).await?; + let ctx = self.global_ctx.clone(); + self.add_listener(move || get_listener_by_url(&l, ctx.clone()).unwrap(), true) + .await?; } if self.global_ctx.config.get_flags().enable_ipv6 { let ipv6_listener = self.global_ctx.config.get_flags().ipv6_listener.clone(); let _ = self .add_listener( - UdpTunnelListener::new(ipv6_listener.parse().unwrap()), + move || { + Box::new(UdpTunnelListener::new( + ipv6_listener.clone().parse().unwrap(), + )) + }, false, ) .await?; @@ -123,85 +134,91 @@ impl ListenerManage Ok(()) } - pub async fn add_listener(&mut self, listener: L, must_succ: bool) -> Result<(), Error> - where - L: TunnelListener + 'static, - { - let listener = Arc::new(Mutex::new(listener)); - self.listeners.push(Listener { - inner: listener, + pub async fn add_listener( + &mut self, + creator: C, + must_succ: bool, + ) -> Result<(), Error> { + self.listeners.push(ListenerFactory { + creator_fn: Arc::new(Box::new(creator)), must_succ, }); Ok(()) } - #[tracing::instrument] + #[tracing::instrument(skip(creator))] async fn run_listener( - listener: Arc>, + creator: Arc, peer_manager: Arc, global_ctx: ArcGlobalCtx, ) { - let mut l = listener.lock().await; - global_ctx.add_running_listener(l.local_url()); - global_ctx.issue_event(GlobalCtxEvent::ListenerAdded(l.local_url())); loop { - let ret = match l.accept().await { - Ok(ret) => ret, + let mut l = (creator)(); + let _g = global_ctx.net_ns.guard(); + match l.listen().await { + Ok(_) => { + global_ctx.add_running_listener(l.local_url()); + global_ctx.issue_event(GlobalCtxEvent::ListenerAdded(l.local_url())); + } Err(e) => { - global_ctx.issue_event(GlobalCtxEvent::ListenerAcceptFailed( + global_ctx.issue_event(GlobalCtxEvent::ListenerAddFailed( l.local_url(), e.to_string(), )); - tracing::error!(?e, ?l, "listener accept error"); + tracing::error!(?e, ?l, "listener listen error"); tokio::time::sleep(std::time::Duration::from_secs(1)).await; continue; } - }; + } + loop { + let ret = match l.accept().await { + Ok(ret) => ret, + Err(e) => { + global_ctx.issue_event(GlobalCtxEvent::ListenerAcceptFailed( + l.local_url(), + format!("error: {}, retry listen later...", e.to_string()), + )); + tracing::error!(?e, ?l, "listener accept error"); + tokio::time::sleep(std::time::Duration::from_secs(1)).await; + break; + } + }; - let tunnel_info = ret.info().unwrap(); - global_ctx.issue_event(GlobalCtxEvent::ConnectionAccepted( - tunnel_info - .local_addr - .clone() - .unwrap_or_default() - .to_string(), - tunnel_info - .remote_addr - .clone() - .unwrap_or_default() - .to_string(), - )); - tracing::info!(ret = ?ret, "conn accepted"); - let peer_manager = peer_manager.clone(); - let global_ctx = global_ctx.clone(); - tokio::spawn(async move { - let server_ret = peer_manager.handle_tunnel(ret).await; - if let Err(e) = &server_ret { - global_ctx.issue_event(GlobalCtxEvent::ConnectionError( - tunnel_info.local_addr.unwrap_or_default().to_string(), - tunnel_info.remote_addr.unwrap_or_default().to_string(), - e.to_string(), - )); - tracing::error!(error = ?e, "handle conn error"); - } - }); + let tunnel_info = ret.info().unwrap(); + global_ctx.issue_event(GlobalCtxEvent::ConnectionAccepted( + tunnel_info + .local_addr + .clone() + .unwrap_or_default() + .to_string(), + tunnel_info + .remote_addr + .clone() + .unwrap_or_default() + .to_string(), + )); + tracing::info!(ret = ?ret, "conn accepted"); + let peer_manager = peer_manager.clone(); + let global_ctx = global_ctx.clone(); + tokio::spawn(async move { + let server_ret = peer_manager.handle_tunnel(ret).await; + if let Err(e) = &server_ret { + global_ctx.issue_event(GlobalCtxEvent::ConnectionError( + tunnel_info.local_addr.unwrap_or_default().to_string(), + tunnel_info.remote_addr.unwrap_or_default().to_string(), + e.to_string(), + )); + tracing::error!(error = ?e, "handle conn error"); + } + }); + } } } pub async fn run(&mut self) -> Result<(), Error> { for listener in &self.listeners { - let _guard = self.net_ns.guard(); - let addr = listener.inner.lock().await.local_url(); - tracing::warn!("run listener: {:?}", listener); - listener - .inner - .lock() - .await - .listen() - .await - .with_context(|| format!("failed to add listener {}", addr))?; self.tasks.spawn(Self::run_listener( - listener.inner.clone(), + listener.creator_fn.clone(), self.peer_manager.clone(), self.global_ctx.clone(), )); @@ -213,12 +230,14 @@ impl ListenerManage #[cfg(test)] mod tests { + use std::sync::atomic::{AtomicI32, Ordering}; + use futures::{SinkExt, StreamExt}; use tokio::time::timeout; use crate::{ common::global_ctx::tests::get_mock_global_ctx, - tunnel::{packet_def::ZCPacket, ring::RingTunnelConnector, TunnelConnector}, + tunnel::{packet_def::ZCPacket, ring::RingTunnelConnector, TunnelConnector, TunnelError}, }; use super::*; @@ -245,12 +264,18 @@ mod tests { let ring_id = format!("ring://{}", uuid::Uuid::new_v4()); + let ring_id_clone = ring_id.clone(); listener_mgr - .add_listener(RingTunnelListener::new(ring_id.parse().unwrap()), true) + .add_listener( + move || Box::new(RingTunnelListener::new(ring_id_clone.parse().unwrap())), + true, + ) .await .unwrap(); listener_mgr.run().await.unwrap(); + tokio::time::sleep(std::time::Duration::from_secs(1)).await; + let connect_once = |ring_id| async move { let tunnel = RingTunnelConnector::new(ring_id).connect().await.unwrap(); let (mut recv, _send) = tunnel.split(); @@ -269,4 +294,60 @@ mod tests { .await .unwrap(); } + + #[tokio::test] + async fn retry_listen() { + let counter = Arc::new(AtomicI32::new(0)); + let drop_counter = Arc::new(AtomicI32::new(0)); + struct MockListener { + counter: Arc, + drop_counter: Arc, + } + + #[async_trait::async_trait] + impl TunnelListener for MockListener { + fn local_url(&self) -> url::Url { + "mock://".parse().unwrap() + } + + async fn listen(&mut self) -> Result<(), TunnelError> { + self.counter.fetch_add(1, Ordering::Relaxed); + Ok(()) + } + + async fn accept(&mut self) -> Result, TunnelError> { + tokio::time::sleep(std::time::Duration::from_secs(1)).await; + Err(TunnelError::BufferFull) + } + } + + impl Drop for MockListener { + fn drop(&mut self) { + self.drop_counter.fetch_add(1, Ordering::Relaxed); + } + } + + let handler = Arc::new(MockListenerHandler {}); + let mut listener_mgr = ListenerManager::new(get_mock_global_ctx(), handler.clone()); + let counter_clone = counter.clone(); + let drop_counter_clone = drop_counter.clone(); + listener_mgr + .add_listener( + move || { + Box::new(MockListener { + counter: counter_clone.clone(), + drop_counter: drop_counter_clone.clone(), + }) + }, + true, + ) + .await + .unwrap(); + listener_mgr.run().await.unwrap(); + + tokio::time::sleep(std::time::Duration::from_secs(3)).await; + + assert!(counter.load(Ordering::Relaxed) >= 2); + assert!(drop_counter.load(Ordering::Relaxed) >= 1); + } }