diff --git a/easytier-core/src/connector/udp_hole_punch.rs b/easytier-core/src/connector/udp_hole_punch.rs index 14bbe1e2..961046c1 100644 --- a/easytier-core/src/connector/udp_hole_punch.rs +++ b/easytier-core/src/connector/udp_hole_punch.rs @@ -463,7 +463,7 @@ impl UdpHolePunchConnector { } #[cfg(test)] -mod tests { +pub mod tests { use std::sync::Arc; use crate::rpc::{NatType, StunInfo}; @@ -499,7 +499,9 @@ mod tests { } } - async fn create_mock_peer_manager_with_mock_stun(udp_nat_type: NatType) -> Arc { + pub async fn create_mock_peer_manager_with_mock_stun( + udp_nat_type: NatType, + ) -> Arc { let p_a = create_mock_peer_manager().await; let collector = Box::new(MockStunInfoCollector { udp_nat_type }); p_a.get_global_ctx().replace_stun_info_collector(collector); diff --git a/easytier-core/src/peers/peer_manager.rs b/easytier-core/src/peers/peer_manager.rs index a1ed2cd3..67e8675c 100644 --- a/easytier-core/src/peers/peer_manager.rs +++ b/easytier-core/src/peers/peer_manager.rs @@ -413,7 +413,7 @@ impl PeerManager { peer_map.close_peer(&peer_id).await.unwrap(); } - tokio::time::sleep(std::time::Duration::from_secs(10)).await; + tokio::time::sleep(std::time::Duration::from_secs(3)).await; } }); } @@ -447,4 +447,8 @@ impl PeerManager { pub fn get_nic_channel(&self) -> mpsc::Sender { self.nic_channel.clone() } + + pub fn get_basic_route(&self) -> Arc { + self.basic_route.clone() + } } diff --git a/easytier-core/src/peers/peer_rip_route.rs b/easytier-core/src/peers/peer_rip_route.rs index 4e1f878b..c1480642 100644 --- a/easytier-core/src/peers/peer_rip_route.rs +++ b/easytier-core/src/peers/peer_rip_route.rs @@ -1,9 +1,16 @@ -use std::{net::Ipv4Addr, sync::Arc, time::Duration}; +use std::{ + net::Ipv4Addr, + sync::{atomic::AtomicU32, Arc}, + time::{Duration, Instant}, +}; use async_trait::async_trait; use dashmap::DashMap; use rkyv::{Archive, Deserialize, Serialize}; -use tokio::{sync::Mutex, task::JoinSet}; +use tokio::{ + sync::{Mutex, RwLock}, + task::JoinSet, +}; use tokio_util::bytes::Bytes; use tracing::Instrument; use uuid::Uuid; @@ -25,7 +32,9 @@ use crate::{ use super::{packet::ArchivedPacketBody, peer_manager::PeerPacketFilter}; -#[derive(Archive, Deserialize, Serialize, Clone, Debug)] +type Version = u32; + +#[derive(Archive, Deserialize, Serialize, Clone, Debug, PartialEq)] #[archive(compare(PartialEq), check_bytes)] // Derives can be passed through to the generated type: #[archive_attr(derive(Debug))] @@ -77,6 +86,12 @@ impl SyncPeerInfo { pub struct SyncPeer { pub myself: SyncPeerInfo, pub neighbors: Vec, + // the route table version of myself + pub version: Version, + // the route table version of peer that we have received last time + pub peer_version: Option, + // if we do not have latest peer version, need_reply is true + pub need_reply: bool, } impl SyncPeer { @@ -85,14 +100,21 @@ impl SyncPeer { _to_peer: UUID, neighbors: Vec, global_ctx: ArcGlobalCtx, + version: Version, + peer_version: Option, + need_reply: bool, ) -> Self { SyncPeer { myself: SyncPeerInfo::new_self(from_peer, &global_ctx), neighbors, + version, + peer_version, + need_reply, } } } +#[derive(Debug)] struct SyncPeerFromRemote { packet: SyncPeer, last_update: std::time::Instant, @@ -100,7 +122,7 @@ struct SyncPeerFromRemote { type SyncPeerFromRemoteMap = Arc>; -#[derive(Clone, Debug)] +#[derive(Debug)] struct RouteTable { route_info: DashMap, ipv4_peer_id_map: DashMap, @@ -137,6 +159,24 @@ impl RouteTable { } } +#[derive(Debug, Clone)] +struct RouteVersion(Arc); + +impl RouteVersion { + fn new() -> Self { + // RouteVersion(Arc::new(AtomicU32::new(rand::random()))) + RouteVersion(Arc::new(AtomicU32::new(0))) + } + + fn get(&self) -> Version { + self.0.load(std::sync::atomic::Ordering::Relaxed) + } + + fn inc(&self) { + self.0.fetch_add(1, std::sync::atomic::Ordering::Relaxed); + } +} + pub struct BasicRoute { my_peer_id: packet::UUID, global_ctx: ArcGlobalCtx, @@ -149,13 +189,17 @@ pub struct BasicRoute { tasks: Mutex>, need_sync_notifier: Arc, + + version: RouteVersion, + myself: Arc>, + last_send_time_map: Arc>, } impl BasicRoute { pub fn new(my_peer_id: Uuid, global_ctx: ArcGlobalCtx) -> Self { BasicRoute { my_peer_id: my_peer_id.into(), - global_ctx, + global_ctx: global_ctx.clone(), interface: Arc::new(Mutex::new(None)), route_table: Arc::new(RouteTable::new()), @@ -164,6 +208,13 @@ impl BasicRoute { tasks: Mutex::new(JoinSet::new()), need_sync_notifier: Arc::new(tokio::sync::Notify::new()), + + version: RouteVersion::new(), + myself: Arc::new(RwLock::new(SyncPeerInfo::new_self( + my_peer_id.into(), + &global_ctx, + ))), + last_send_time_map: Arc::new(DashMap::new()), } } @@ -186,6 +237,16 @@ impl BasicRoute { route_table.copy_from(&new_route_table); } + async fn update_myself(myself: &Arc>, global_ctx: &ArcGlobalCtx) -> bool { + let new_myself = SyncPeerInfo::new_self(global_ctx.get_id().into(), &global_ctx); + if *myself.read().await != new_myself { + *myself.write().await = new_myself; + true + } else { + false + } + } + fn update_route_table_with_req( my_id: packet::UUID, packet: &SyncPeer, @@ -210,7 +271,7 @@ impl BasicRoute { .value() .clone(); - if ret.cost > 32 { + if ret.cost > 6 { log::error!( "cost too large: {}, may lost connection, remove it", ret.cost @@ -260,6 +321,9 @@ impl BasicRoute { global_ctx: ArcGlobalCtx, peer_id: PeerId, route_table: Arc, + my_version: Version, + peer_version: Option, + need_reply: bool, ) -> Result<(), Error> { let mut route_info_copy: Vec = Vec::new(); // copy the route info @@ -267,7 +331,15 @@ impl BasicRoute { let (k, v) = item.pair(); route_info_copy.push(v.clone().clone_for_route_table(&(*k).into(), v.cost, &v)); } - let msg = SyncPeer::new(my_peer_id, peer_id.into(), route_info_copy, global_ctx); + let msg = SyncPeer::new( + my_peer_id, + peer_id.into(), + route_info_copy, + global_ctx, + my_version, + peer_version, + need_reply, + ); // TODO: this may exceed the MTU of the tunnel interface .send_route_packet(encode_to_bytes::<_, 4096>(&msg), 1, &peer_id) @@ -280,19 +352,54 @@ impl BasicRoute { let my_peer_id = self.my_peer_id.clone(); let interface = self.interface.clone(); let notifier = self.need_sync_notifier.clone(); + let sync_peer_from_remote = self.sync_peer_from_remote.clone(); + let myself = self.myself.clone(); + let version = self.version.clone(); + let last_send_time_map = self.last_send_time_map.clone(); self.tasks.lock().await.spawn( async move { loop { + if Self::update_myself(&myself, &global_ctx).await { + version.inc(); + tracing::info!( + my_id = ?my_peer_id, + version = version.get(), + "update route table version when myself changed" + ); + } + let lockd_interface = interface.lock().await; let interface = lockd_interface.as_ref().unwrap(); + let last_send_time_map_new = DashMap::new(); let peers = interface.list_peers().await; for peer in peers.iter() { + let last_send_time = last_send_time_map.get(peer).map(|v| *v).unwrap_or(Instant::now() - Duration::from_secs(3600)); + let my_version_peer_saved = sync_peer_from_remote.get(&peer).and_then(|v| v.packet.peer_version); + let peer_have_latest_version = my_version_peer_saved == Some(version.get()); + if peer_have_latest_version && last_send_time.elapsed().as_secs() < 60 { + last_send_time_map_new.insert(*peer, last_send_time); + continue; + } + + tracing::info!( + my_id = ?my_peer_id, + dst_peer_id = ?peer, + version = version.get(), + ?my_version_peer_saved, + last_send_elapse = ?last_send_time.elapsed().as_secs(), + "need send route info" + ); + last_send_time_map_new.insert(*peer, Instant::now()); + let peer_version_we_saved = sync_peer_from_remote.get(&peer).and_then(|v| Some(v.packet.version)); let ret = Self::send_sync_peer_request( interface, my_peer_id.clone(), global_ctx.clone(), *peer, route_table.clone(), + version.get(), + peer_version_we_saved, + !peer_have_latest_version, ) .await; @@ -312,6 +419,13 @@ impl BasicRoute { } }; } + + last_send_time_map.clear(); + for item in last_send_time_map_new.iter() { + let (k, v) = item.pair(); + last_send_time_map.insert(*k, *v); + } + tokio::select! { _ = notifier.notified() => { log::trace!("sync peer request triggered by notifier"); @@ -333,14 +447,19 @@ impl BasicRoute { let my_peer_id = self.my_peer_id.clone(); let sync_peer_from_remote = self.sync_peer_from_remote.clone(); let notifier = self.need_sync_notifier.clone(); + let interface = self.interface.clone(); + let version = self.version.clone(); self.tasks.lock().await.spawn(async move { loop { let mut need_update_route = false; let now = std::time::Instant::now(); let mut need_remove = Vec::new(); + let connected_peers = interface.lock().await.as_ref().unwrap().list_peers().await; for item in sync_peer_from_remote.iter() { let (k, v) = item.pair(); - if now.duration_since(v.last_update).as_secs() > 5 { + if now.duration_since(v.last_update).as_secs() > 70 + || !connected_peers.contains(k) + { need_update_route = true; need_remove.insert(0, k.clone()); } @@ -357,6 +476,12 @@ impl BasicRoute { sync_peer_from_remote.clone(), route_table.clone(), ); + version.inc(); + tracing::info!( + my_id = ?my_peer_id, + version = version.get(), + "update route table when check expired peer" + ); notifier.notify_one(); } @@ -385,11 +510,13 @@ impl BasicRoute { self.sync_peer_from_remote .entry(packet.myself.peer_id.to_uuid()) .and_modify(|v| { - if v.packet == *packet { + if v.packet.myself == p.myself && v.packet.neighbors == p.neighbors { updated = false; } else { v.packet = p.clone(); } + v.packet.version = p.version; + v.packet.peer_version = p.peer_version; v.last_update = std::time::Instant::now(); }) .or_insert(SyncPeerFromRemote { @@ -403,6 +530,21 @@ impl BasicRoute { self.sync_peer_from_remote.clone(), self.route_table.clone(), ); + self.version.inc(); + tracing::info!( + my_id = ?self.my_peer_id, + ?p, + version = self.version.get(), + "update route table when receive route packet" + ); + } + + if packet.need_reply { + self.last_send_time_map + .remove(&packet.myself.peer_id.to_uuid()); + } + + if updated || packet.need_reply { self.need_sync_notifier.notify_one(); } } @@ -502,3 +644,111 @@ impl PeerPacketFilter for BasicRoute { } } } + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use crate::{ + connector::udp_hole_punch::tests::create_mock_peer_manager_with_mock_stun, + peers::{ + peer_manager::PeerManager, + peer_rip_route::Version, + tests::{connect_peer_manager, wait_route_appear}, + PeerId, + }, + rpc::NatType, + }; + + #[tokio::test] + async fn test_rip_route() { + let peer_mgr_a = create_mock_peer_manager_with_mock_stun(NatType::Unknown).await; + let peer_mgr_b = create_mock_peer_manager_with_mock_stun(NatType::Unknown).await; + let peer_mgr_c = create_mock_peer_manager_with_mock_stun(NatType::Unknown).await; + connect_peer_manager(peer_mgr_a.clone(), peer_mgr_b.clone()).await; + connect_peer_manager(peer_mgr_b.clone(), peer_mgr_c.clone()).await; + wait_route_appear(peer_mgr_a.clone(), peer_mgr_b.my_node_id()) + .await + .unwrap(); + wait_route_appear(peer_mgr_a.clone(), peer_mgr_c.my_node_id()) + .await + .unwrap(); + + let mgrs = vec![peer_mgr_a.clone(), peer_mgr_b.clone(), peer_mgr_c.clone()]; + + tokio::time::sleep(tokio::time::Duration::from_secs(4)).await; + + let check_version = |version: Version, uuid: PeerId, mgrs: &Vec>| { + for mgr in mgrs.iter() { + tracing::warn!( + "check version: {:?}, {:?}, {:?}, {:?}", + version, + uuid, + mgr, + mgr.get_basic_route().sync_peer_from_remote + ); + assert_eq!( + version, + mgr.get_basic_route() + .sync_peer_from_remote + .get(&uuid) + .unwrap() + .packet + .version, + ); + assert_eq!( + mgr.get_basic_route() + .sync_peer_from_remote + .get(&uuid) + .unwrap() + .packet + .peer_version + .unwrap(), + mgr.get_basic_route().version.get() + ); + } + }; + + let check_sanity = || { + // check peer version in other peer mgr are correct. + check_version( + peer_mgr_b.get_basic_route().version.get(), + peer_mgr_b.my_node_id(), + &vec![peer_mgr_a.clone(), peer_mgr_c.clone()], + ); + + check_version( + peer_mgr_a.get_basic_route().version.get(), + peer_mgr_a.my_node_id(), + &vec![peer_mgr_b.clone()], + ); + + check_version( + peer_mgr_c.get_basic_route().version.get(), + peer_mgr_c.my_node_id(), + &vec![peer_mgr_b.clone()], + ); + }; + + check_sanity(); + + let versions = mgrs + .iter() + .map(|x| x.get_basic_route().version.get()) + .collect::>(); + + tokio::time::sleep(tokio::time::Duration::from_secs(5)).await; + + let versions2 = mgrs + .iter() + .map(|x| x.get_basic_route().version.get()) + .collect::>(); + + assert_eq!(versions, versions2); + check_sanity(); + + assert!(peer_mgr_a.get_basic_route().version.get() <= 3); + assert!(peer_mgr_b.get_basic_route().version.get() <= 6); + assert!(peer_mgr_c.get_basic_route().version.get() <= 3); + } +} diff --git a/easytier-core/src/tunnels/udp_tunnel.rs b/easytier-core/src/tunnels/udp_tunnel.rs index ad1789ee..b2d649a1 100644 --- a/easytier-core/src/tunnels/udp_tunnel.rs +++ b/easytier-core/src/tunnels/udp_tunnel.rs @@ -616,7 +616,7 @@ mod tests { #[tokio::test] async fn udp_multiple_conns() { - let mut listener = UdpTunnelListener::new("udp://0.0.0.0:5556".parse().unwrap()); + let mut listener = UdpTunnelListener::new("udp://0.0.0.0:5557".parse().unwrap()); listener.listen().await.unwrap(); let _lis = tokio::spawn(async move { @@ -630,8 +630,8 @@ mod tests { } }); - let mut connector1 = UdpTunnelConnector::new("udp://127.0.0.1:5556".parse().unwrap()); - let mut connector2 = UdpTunnelConnector::new("udp://127.0.0.1:5556".parse().unwrap()); + let mut connector1 = UdpTunnelConnector::new("udp://127.0.0.1:5557".parse().unwrap()); + let mut connector2 = UdpTunnelConnector::new("udp://127.0.0.1:5557".parse().unwrap()); let t1 = connector1.connect().await.unwrap(); let t2 = connector2.connect().await.unwrap();