diff --git a/easytier-core/Cargo.toml b/easytier-core/Cargo.toml index b70a7882..4eedcb71 100644 --- a/easytier-core/Cargo.toml +++ b/easytier-core/Cargo.toml @@ -64,7 +64,6 @@ tonic = "0.10" prost = "0.12" anyhow = "1.0" tarpc = { version = "0.32", features = ["tokio1", "serde1"] } -bincode = "1.3" url = "2.5.0" diff --git a/easytier-core/src/connector/direct.rs b/easytier-core/src/connector/direct.rs index 92cc625a..8b7a7500 100644 --- a/easytier-core/src/connector/direct.rs +++ b/easytier-core/src/connector/direct.rs @@ -27,7 +27,7 @@ pub trait DirectConnectorRpc { #[async_trait::async_trait] pub trait PeerManagerForDirectConnector { async fn list_peers(&self) -> Vec; - async fn list_peer_conns(&self, peer_id: &PeerId) -> Option>; + async fn list_peer_conns(&self, peer_id: PeerId) -> Option>; fn get_peer_rpc_mgr(&self) -> Arc; } @@ -44,7 +44,7 @@ impl PeerManagerForDirectConnector for PeerManager { ret } - async fn list_peer_conns(&self, peer_id: &PeerId) -> Option> { + async fn list_peer_conns(&self, peer_id: PeerId) -> Option> { self.get_peer_map().list_peer_conns(peer_id).await } @@ -221,7 +221,7 @@ impl DirectConnectorManager { ) -> Result<(), Error> { let peer_manager = data.peer_manager.clone(); // check if we have direct connection with dst_peer_id - if let Some(c) = peer_manager.list_peer_conns(&dst_peer_id).await { + if let Some(c) = peer_manager.list_peer_conns(dst_peer_id).await { // currently if we have any type of direct connection (udp or tcp), we will not try to connect if !c.is_empty() { return Ok(()); diff --git a/easytier-core/src/connector/udp_hole_punch.rs b/easytier-core/src/connector/udp_hole_punch.rs index 5e40790b..801d0784 100644 --- a/easytier-core/src/connector/udp_hole_punch.rs +++ b/easytier-core/src/connector/udp_hole_punch.rs @@ -284,7 +284,7 @@ impl UdpHolePunchConnector { }; let peer_id: PeerId = route.peer_id; - let conns = data.peer_mgr.list_peer_conns(&peer_id).await; + let conns = data.peer_mgr.list_peer_conns(peer_id).await; if conns.is_some() && conns.unwrap().len() > 0 { continue; } diff --git a/easytier-core/src/gateway/icmp_proxy.rs b/easytier-core/src/gateway/icmp_proxy.rs index c660b458..03ed0b05 100644 --- a/easytier-core/src/gateway/icmp_proxy.rs +++ b/easytier-core/src/gateway/icmp_proxy.rs @@ -21,10 +21,7 @@ use tracing::Instrument; use crate::{ common::{error::Error, global_ctx::ArcGlobalCtx, PeerId}, - peers::{ - packet, - peer_manager::{PeerManager, PeerPacketFilter}, - }, + peers::{packet, peer_manager::PeerManager, PeerPacketFilter}, }; use super::CidrSet; diff --git a/easytier-core/src/gateway/tcp_proxy.rs b/easytier-core/src/gateway/tcp_proxy.rs index 2f1e2b76..ad30bf21 100644 --- a/easytier-core/src/gateway/tcp_proxy.rs +++ b/easytier-core/src/gateway/tcp_proxy.rs @@ -18,7 +18,8 @@ use crate::common::error::Result; use crate::common::global_ctx::GlobalCtx; use crate::common::netns::NetNS; use crate::peers::packet::{self, ArchivedPacket}; -use crate::peers::peer_manager::{NicPacketFilter, PeerManager, PeerPacketFilter}; +use crate::peers::peer_manager::PeerManager; +use crate::peers::{NicPacketFilter, PeerPacketFilter}; use super::CidrSet; diff --git a/easytier-core/src/gateway/udp_proxy.rs b/easytier-core/src/gateway/udp_proxy.rs index 970629a0..d1fbcf45 100644 --- a/easytier-core/src/gateway/udp_proxy.rs +++ b/easytier-core/src/gateway/udp_proxy.rs @@ -26,10 +26,7 @@ use tracing::Level; use crate::{ common::{error::Error, global_ctx::ArcGlobalCtx, PeerId}, - peers::{ - packet, - peer_manager::{PeerManager, PeerPacketFilter}, - }, + peers::{packet, peer_manager::PeerManager, PeerPacketFilter}, tunnels::common::setup_sokcet2, }; diff --git a/easytier-core/src/peers/mod.rs b/easytier-core/src/peers/mod.rs index f71b80a1..f9b783a8 100644 --- a/easytier-core/src/peers/mod.rs +++ b/easytier-core/src/peers/mod.rs @@ -13,3 +13,26 @@ pub mod foreign_network_manager; #[cfg(test)] pub mod tests; + +use tokio_util::bytes::{Bytes, BytesMut}; + +#[async_trait::async_trait] +#[auto_impl::auto_impl(Arc)] +pub trait PeerPacketFilter { + async fn try_process_packet_from_peer( + &self, + _packet: &packet::ArchivedPacket, + _data: &Bytes, + ) -> Option<()> { + None + } +} + +#[async_trait::async_trait] +#[auto_impl::auto_impl(Arc)] +pub trait NicPacketFilter { + async fn try_process_packet_from_nic(&self, data: BytesMut) -> BytesMut; +} + +type BoxPeerPacketFilter = Box; +type BoxNicPacketFilter = Box; diff --git a/easytier-core/src/peers/peer_manager.rs b/easytier-core/src/peers/peer_manager.rs index a36a197f..63b80de9 100644 --- a/easytier-core/src/peers/peer_manager.rs +++ b/easytier-core/src/peers/peer_manager.rs @@ -1,4 +1,8 @@ -use std::{fmt::Debug, net::Ipv4Addr, sync::Arc}; +use std::{ + fmt::Debug, + net::Ipv4Addr, + sync::{Arc, Weak}, +}; use async_trait::async_trait; use futures::{StreamExt, TryFutureExt}; @@ -19,7 +23,8 @@ use crate::{ PeerId, }, peers::{ - packet, peer_conn::PeerConn, peer_rpc::PeerRpcManagerTransport, route_trait::RouteInterface, + packet, peer_conn::PeerConn, peer_rpc::PeerRpcManagerTransport, + route_trait::RouteInterface, PeerPacketFilter, }, tunnels::{SinkItem, Tunnel, TunnelConnector}, }; @@ -32,12 +37,13 @@ use super::{ peer_rip_route::BasicRoute, peer_rpc::PeerRpcManager, route_trait::{ArcRoute, Route}, + BoxNicPacketFilter, BoxPeerPacketFilter, }; struct RpcTransport { my_peer_id: PeerId, - peers: Arc, - foreign_peers: Mutex>>, + peers: Weak, + foreign_peers: Mutex>>, packet_recv: Mutex>, peer_rpc_tspt_sender: UnboundedSender, @@ -50,15 +56,21 @@ impl PeerRpcManagerTransport for RpcTransport { } async fn send(&self, msg: Bytes, dst_peer_id: PeerId) -> Result<(), Error> { - if let Some(foreign_peers) = self.foreign_peers.lock().await.as_ref() { - if foreign_peers.has_peer(dst_peer_id) { - return foreign_peers.send_msg(msg, dst_peer_id).await; - } - } - self.peers - .send_msg(msg, dst_peer_id) - .map_err(|e| e.into()) + let foreign_peers = self + .foreign_peers + .lock() .await + .as_ref() + .ok_or(Error::Unknown)? + .upgrade() + .ok_or(Error::Unknown)?; + let peers = self.peers.upgrade().ok_or(Error::Unknown)?; + + if foreign_peers.has_peer(dst_peer_id) { + return foreign_peers.send_msg(msg, dst_peer_id).await; + } + + peers.send_msg(msg, dst_peer_id).map_err(|e| e.into()).await } async fn recv(&self) -> Result { @@ -70,25 +82,6 @@ impl PeerRpcManagerTransport for RpcTransport { } } -#[async_trait::async_trait] -#[auto_impl::auto_impl(Arc)] -pub trait PeerPacketFilter { - async fn try_process_packet_from_peer( - &self, - packet: &packet::ArchivedPacket, - data: &Bytes, - ) -> Option<()>; -} - -#[async_trait::async_trait] -#[auto_impl::auto_impl(Arc)] -pub trait NicPacketFilter { - async fn try_process_packet_from_nic(&self, data: BytesMut) -> BytesMut; -} - -type BoxPeerPacketFilter = Box; -type BoxNicPacketFilter = Box; - pub struct PeerManager { my_peer_id: PeerId, @@ -138,7 +131,7 @@ impl PeerManager { let (peer_rpc_tspt_sender, peer_rpc_tspt_recv) = mpsc::unbounded_channel(); let rpc_tspt = Arc::new(RpcTransport { my_peer_id, - peers: peers.clone(), + peers: Arc::downgrade(&peers), foreign_peers: Mutex::new(None), packet_recv: Mutex::new(peer_rpc_tspt_recv), peer_rpc_tspt_sender, @@ -316,10 +309,6 @@ impl PeerManager { })) .await; - // for route - self.add_packet_process_pipeline(Box::new(self.basic_route.clone())) - .await; - // for peer rpc packet struct PeerRpcPacketProcessor { peer_rpc_tspt_sender: UnboundedSender, @@ -348,19 +337,31 @@ impl PeerManager { pub async fn add_route(&self, route: T) where - T: Route + Send + Sync + 'static, + T: Route + PeerPacketFilter + Send + Sync + Clone + 'static, { + // for route + self.add_packet_process_pipeline(Box::new(route.clone())) + .await; + struct Interface { my_peer_id: PeerId, - peers: Arc, - foreign_network_client: Arc, + peers: Weak, + foreign_network_client: Weak, } #[async_trait] impl RouteInterface for Interface { async fn list_peers(&self) -> Vec { - let mut peers = self.foreign_network_client.list_foreign_peers(); - peers.extend(self.peers.list_peers_with_conn().await); + let Some(foreign_client) = self.foreign_network_client.upgrade() else { + return vec![]; + }; + + let Some(peer_map) = self.peers.upgrade() else { + return vec![]; + }; + + let mut peers = foreign_client.list_foreign_peers(); + peers.extend(peer_map.list_peers_with_conn().await); peers } async fn send_route_packet( @@ -369,19 +370,20 @@ impl PeerManager { route_id: u8, dst_peer_id: PeerId, ) -> Result<(), Error> { + let foreign_client = self + .foreign_network_client + .upgrade() + .ok_or(Error::Unknown)?; + let peer_map = self.peers.upgrade().ok_or(Error::Unknown)?; + let packet_bytes: Bytes = packet::Packet::new_route_packet(self.my_peer_id, dst_peer_id, route_id, &msg) .into(); - if self.foreign_network_client.has_next_hop(dst_peer_id) { - return self - .foreign_network_client - .send_msg(packet_bytes, dst_peer_id) - .await; + if foreign_client.has_next_hop(dst_peer_id) { + return foreign_client.send_msg(packet_bytes, dst_peer_id).await; } - self.peers - .send_msg_directly(packet_bytes, dst_peer_id) - .await + peer_map.send_msg_directly(packet_bytes, dst_peer_id).await } fn my_peer_id(&self) -> PeerId { self.my_peer_id @@ -392,8 +394,8 @@ impl PeerManager { let _route_id = route .open(Box::new(Interface { my_peer_id, - peers: self.peers.clone(), - foreign_network_client: self.foreign_network_client.clone(), + peers: Arc::downgrade(&self.peers), + foreign_network_client: Arc::downgrade(&self.foreign_network_client), })) .await .unwrap(); @@ -485,7 +487,7 @@ impl PeerManager { .foreign_peers .lock() .await - .replace(self.foreign_network_client.get_peer_map().clone()); + .replace(Arc::downgrade(&self.foreign_network_client.get_peer_map())); self.foreign_network_manager.run().await; self.foreign_network_client.run().await; @@ -539,3 +541,45 @@ impl PeerManager { self.foreign_network_client.clone() } } + +#[cfg(test)] +mod tests { + + use crate::{ + connector::udp_hole_punch::tests::create_mock_peer_manager_with_mock_stun, + peers::tests::{connect_peer_manager, wait_for_condition, wait_route_appear}, + rpc::NatType, + }; + + #[tokio::test] + async fn drop_peer_manager() { + let peer_mgr_a = create_mock_peer_manager_with_mock_stun(NatType::Unknown).await; + let peer_mgr_b = create_mock_peer_manager_with_mock_stun(NatType::Unknown).await; + let peer_mgr_c = create_mock_peer_manager_with_mock_stun(NatType::Unknown).await; + connect_peer_manager(peer_mgr_a.clone(), peer_mgr_b.clone()).await; + connect_peer_manager(peer_mgr_b.clone(), peer_mgr_c.clone()).await; + connect_peer_manager(peer_mgr_a.clone(), peer_mgr_c.clone()).await; + + wait_route_appear(peer_mgr_a.clone(), peer_mgr_b.my_peer_id()) + .await + .unwrap(); + wait_route_appear(peer_mgr_a.clone(), peer_mgr_c.my_peer_id()) + .await + .unwrap(); + + // wait mgr_a have 2 peers + wait_for_condition( + || async { peer_mgr_a.get_peer_map().list_peers_with_conn().await.len() == 2 }, + std::time::Duration::from_secs(5), + ) + .await; + + drop(peer_mgr_b); + + wait_for_condition( + || async { peer_mgr_a.get_peer_map().list_peers_with_conn().await.len() == 1 }, + std::time::Duration::from_secs(5), + ) + .await; + } +} diff --git a/easytier-core/src/peers/peer_map.rs b/easytier-core/src/peers/peer_map.rs index 84763dc5..8e4aefe6 100644 --- a/easytier-core/src/peers/peer_map.rs +++ b/easytier-core/src/peers/peer_map.rs @@ -164,8 +164,8 @@ impl PeerMap { ret } - pub async fn list_peer_conns(&self, peer_id: &PeerId) -> Option> { - if let Some(p) = self.get_peer_by_id(*peer_id) { + pub async fn list_peer_conns(&self, peer_id: PeerId) -> Option> { + if let Some(p) = self.get_peer_by_id(peer_id) { Some(p.list_peer_conns().await) } else { return None; @@ -206,7 +206,7 @@ impl PeerMap { let mut to_remove = vec![]; for peer_id in self.list_peers().await { - let conns = self.list_peer_conns(&peer_id).await; + let conns = self.list_peer_conns(peer_id).await; if conns.is_none() || conns.as_ref().unwrap().is_empty() { to_remove.push(peer_id); } diff --git a/easytier-core/src/peers/peer_rip_route.rs b/easytier-core/src/peers/peer_rip_route.rs index ae39c759..606e7fe4 100644 --- a/easytier-core/src/peers/peer_rip_route.rs +++ b/easytier-core/src/peers/peer_rip_route.rs @@ -22,7 +22,7 @@ use crate::{ rpc::{NatType, StunInfo}, }; -use super::{packet::CtrlPacketPayload, peer_manager::PeerPacketFilter}; +use super::{packet::CtrlPacketPayload, PeerPacketFilter}; const SEND_ROUTE_PERIOD_SEC: u64 = 60; const SEND_ROUTE_FAST_REPLY_SEC: u64 = 5; diff --git a/easytier-core/src/peers/peer_rpc.rs b/easytier-core/src/peers/peer_rpc.rs index c88fe6bd..5daa581c 100644 --- a/easytier-core/src/peers/peer_rpc.rs +++ b/easytier-core/src/peers/peer_rpc.rs @@ -110,6 +110,11 @@ impl PeerRpcManager { loop { tokio::select! { Some(resp) = client_transport.next() => { + let Some(cur_req_peer_id) = cur_req_peer_id.take() else { + tracing::error!("[PEER RPC MGR] cur_req_peer_id is none, ignore this resp"); + continue; + }; + tracing::trace!(resp = ?resp, "recv packet from client"); if resp.is_err() { tracing::warn!(err = ?resp.err(), @@ -118,12 +123,7 @@ impl PeerRpcManager { } let resp = resp.unwrap(); - if cur_req_peer_id.is_none() { - tracing::error!("[PEER RPC MGR] cur_req_peer_id is none, ignore this resp"); - continue; - } - - let serialized_resp = bincode::serialize(&resp); + let serialized_resp = postcard::to_allocvec(&resp); if serialized_resp.is_err() { tracing::error!(error = ?serialized_resp.err(), "serialize resp failed"); continue; @@ -131,7 +131,7 @@ impl PeerRpcManager { let msg = Packet::new_tarpc_packet( tspt.my_peer_id(), - cur_req_peer_id.take().unwrap(), + cur_req_peer_id, service_id, false, serialized_resp.unwrap(), @@ -154,12 +154,17 @@ impl PeerRpcManager { continue; } + if cur_req_peer_id.is_some() { + tracing::warn!("cur_req_peer_id is not none, ignore this packet"); + continue; + } + assert_eq!(info.service_id, service_id); cur_req_peer_id = Some(packet.from_peer.clone().into()); tracing::trace!("recv packet from peer, packet: {:?}", packet); - let decoded_ret = bincode::deserialize(&info.content.as_slice()); + let decoded_ret = postcard::from_bytes(&info.content.as_slice()); if let Err(e) = decoded_ret { tracing::error!(error = ?e, "decode rpc packet failed"); continue; @@ -294,7 +299,7 @@ impl PeerRpcManager { continue; } - let a = bincode::serialize(&a.unwrap()); + let a = postcard::to_allocvec(&a.unwrap()); if a.is_err() { tracing::error!(error = ?a.err(), "bincode serialize failed"); continue; @@ -326,7 +331,7 @@ impl PeerRpcManager { continue; } - let decoded = bincode::deserialize(&info.unwrap().content.as_slice()); + let decoded = postcard::from_bytes(&info.unwrap().content.as_slice()); if let Err(e) = decoded { tracing::error!(error = ?e, "decode rpc packet failed"); continue; diff --git a/easytier-core/src/peers/rpc_service.rs b/easytier-core/src/peers/rpc_service.rs index 14bb009f..2f763840 100644 --- a/easytier-core/src/peers/rpc_service.rs +++ b/easytier-core/src/peers/rpc_service.rs @@ -25,12 +25,7 @@ impl PeerManagerRpcService { let mut peer_info = PeerInfo::default(); peer_info.peer_id = peer; - if let Some(conns) = self - .peer_manager - .get_peer_map() - .list_peer_conns(&peer) - .await - { + if let Some(conns) = self.peer_manager.get_peer_map().list_peer_conns(peer).await { peer_info.conns = conns; } diff --git a/easytier-core/src/peers/tests.rs b/easytier-core/src/peers/tests.rs index 39f70d8c..dc802bd4 100644 --- a/easytier-core/src/peers/tests.rs +++ b/easytier-core/src/peers/tests.rs @@ -1,5 +1,7 @@ use std::sync::Arc; +use futures::Future; + use crate::{ common::{error::Error, global_ctx::tests::get_mock_global_ctx, PeerId}, tunnels::ring_tunnel::create_ring_tunnel_pair, @@ -48,3 +50,18 @@ pub async fn wait_route_appear_with_cost( pub async fn wait_route_appear(peer_mgr: Arc, node_id: PeerId) -> Result<(), Error> { wait_route_appear_with_cost(peer_mgr, node_id, None).await } + +pub async fn wait_for_condition(mut condition: F, timeout: std::time::Duration) -> () +where + F: FnMut() -> FRet + Send, + FRet: Future, +{ + let now = std::time::Instant::now(); + while now.elapsed() < timeout { + if condition().await { + return; + } + tokio::time::sleep(std::time::Duration::from_millis(50)).await; + } + assert!(condition().await, "Timeout") +}