From 72f86025bdda4e22c8c7e929daa7d05232c8da19 Mon Sep 17 00:00:00 2001 From: "sijie.sun" Date: Sun, 12 May 2024 11:48:13 +0800 Subject: [PATCH] support custom cost calculate func when generating route table --- easytier/src/peers/peer_ospf_route.rs | 146 ++++++++++++++++++++++++-- easytier/src/peers/route_trait.rs | 44 ++++++++ 2 files changed, 181 insertions(+), 9 deletions(-) diff --git a/easytier/src/peers/peer_ospf_route.rs b/easytier/src/peers/peer_ospf_route.rs index 59bdbf43..0e9acd71 100644 --- a/easytier/src/peers/peer_ospf_route.rs +++ b/easytier/src/peers/peer_ospf_route.rs @@ -19,7 +19,14 @@ use crate::{ rpc::{NatType, StunInfo}, }; -use super::{peer_rpc::PeerRpcManager, PeerPacketFilter}; +use super::{ + peer_rpc::PeerRpcManager, + route_trait::{ + DefaultRouteCostCalculator, NextHopPolicy, RouteCostCalculator, + RouteCostCalculatorInterface, + }, + PeerPacketFilter, +}; static SERVICE_ID: u32 = 7; static UPDATE_PEER_INFO_PERIOD: Duration = Duration::from_secs(3600); @@ -393,7 +400,12 @@ impl RouteTable { .map(|x| NatType::try_from(x.udp_stun_info as i32).unwrap()) } - fn build_from_synced_info(&self, my_peer_id: PeerId, synced_info: &SyncedRouteInfo) { + fn build_from_synced_info( + &self, + my_peer_id: PeerId, + synced_info: &SyncedRouteInfo, + cost_calc: T, + ) { // build peer_infos self.peer_infos.clear(); for item in synced_info.peer_infos.iter() { @@ -415,12 +427,18 @@ impl RouteTable { if peer_id == my_peer_id { continue; } - let Some(path) = pathfinding::prelude::bfs( + let Some((path, _cost)): Option<(Vec, i32)> = pathfinding::prelude::dijkstra( &my_peer_id, - |p| { + |src_peer| { synced_info - .get_connected_peers(*p) + .get_connected_peers(*src_peer) .unwrap_or_else(|| BTreeSet::new()) + .into_iter() + .map(|dst_peer| { + let cost = cost_calc.calculate_cost(*src_peer, dst_peer); + (dst_peer, cost) + }) + .collect::>() }, |x| *x == peer_id, ) else { @@ -563,7 +581,9 @@ struct PeerRouteServiceImpl { interface: Arc>>, + cost_calculator: Arc>>, route_table: RouteTable, + route_table_with_cost: RouteTable, synced_route_info: Arc, cached_local_conn_map: std::sync::Mutex, } @@ -585,9 +605,17 @@ impl PeerRouteServiceImpl { PeerRouteServiceImpl { my_peer_id, global_ctx, - interface: Arc::new(Mutex::new(None)), sessions: DashMap::new(), + + interface: Arc::new(Mutex::new(None)), + + cost_calculator: Arc::new(std::sync::Mutex::new(Some(Box::new( + DefaultRouteCostCalculator, + )))), + route_table: RouteTable::new(), + route_table_with_cost: RouteTable::new(), + synced_route_info: Arc::new(SyncedRouteInfo { peer_infos: DashMap::new(), conn_map: DashMap::new(), @@ -649,8 +677,31 @@ impl PeerRouteServiceImpl { } fn update_route_table(&self) { - self.route_table - .build_from_synced_info(self.my_peer_id, &self.synced_route_info); + self.route_table.build_from_synced_info( + self.my_peer_id, + &self.synced_route_info, + DefaultRouteCostCalculator::default(), + ); + + let calc_locked = self.cost_calculator.lock().unwrap(); + if calc_locked.is_none() { + return; + } + + self.route_table_with_cost.build_from_synced_info( + self.my_peer_id, + &self.synced_route_info, + &calc_locked.as_ref().unwrap(), + ); + } + + fn cost_calculator_need_update(&self) -> bool { + self.cost_calculator + .lock() + .unwrap() + .as_ref() + .map(|x| x.need_update()) + .unwrap_or(false) } fn update_route_table_and_cached_local_conn_bitmap(&self) { @@ -1183,6 +1234,10 @@ impl PeerRoute { session_mgr.sync_now("update_my_infos"); } + if service_impl.cost_calculator_need_update() { + service_impl.update_route_table(); + } + select! { ev = global_event_receiver.recv() => { tracing::info!(?ev, "global event received in update_my_peer_info_routine"); @@ -1234,6 +1289,19 @@ impl Route for PeerRoute { route_table.get_next_hop(dst_peer_id).map(|x| x.0) } + async fn get_next_hop_with_policy( + &self, + dst_peer_id: PeerId, + policy: NextHopPolicy, + ) -> Option { + let route_table = if matches!(policy, NextHopPolicy::LeastCost) { + &self.service_impl.route_table_with_cost + } else { + &self.service_impl.route_table + }; + route_table.get_next_hop(dst_peer_id).map(|x| x.0) + } + async fn list_routes(&self) -> Vec { let route_table = &self.service_impl.route_table; let mut routes = Vec::new(); @@ -1265,6 +1333,11 @@ impl Route for PeerRoute { tracing::info!(?ipv4_addr, "no peer id for ipv4"); None } + + async fn set_route_cost_fn(&self, _cost_fn: RouteCostCalculator) { + *self.service_impl.cost_calculator.lock().unwrap() = Some(_cost_fn); + self.service_impl.update_route_table(); + } } impl PeerPacketFilter for Arc {} @@ -1282,7 +1355,7 @@ mod tests { connector::udp_hole_punch::tests::replace_stun_info_collector, peers::{ peer_manager::{PeerManager, RouteAlgoType}, - route_trait::Route, + route_trait::{NextHopPolicy, Route, RouteCostCalculatorInterface}, tests::{connect_peer_manager, wait_for_condition}, }, rpc::NatType, @@ -1609,4 +1682,59 @@ mod tests { println!("session: {:?}", r_a.session_mgr.dump_sessions()); check_rpc_counter(&r_a, p_b.my_peer_id(), 2, 2); } + + #[tokio::test] + async fn test_cost_calculator() { + let p_a = create_mock_pmgr().await; + let p_b = create_mock_pmgr().await; + let p_c = create_mock_pmgr().await; + connect_peer_manager(p_a.clone(), p_b.clone()).await; + connect_peer_manager(p_c.clone(), p_b.clone()).await; + connect_peer_manager(p_a.clone(), p_c.clone()).await; + + let _r_a = create_mock_route(p_a.clone()).await; + let _r_b = create_mock_route(p_b.clone()).await; + let r_c = create_mock_route(p_c.clone()).await; + + // in normal mode, packet from p_c should directly forward to p_a + wait_for_condition( + || async { r_c.get_next_hop(p_a.my_peer_id()).await == Some(p_a.my_peer_id()) }, + Duration::from_secs(5), + ) + .await; + + struct TestCostCalculator { + p_a_peer_id: PeerId, + p_b_peer_id: PeerId, + p_c_peer_id: PeerId, + } + + impl RouteCostCalculatorInterface for TestCostCalculator { + fn calculate_cost(&self, src: PeerId, dst: PeerId) -> i32 { + if src == self.p_c_peer_id && dst == self.p_a_peer_id { + return 100; + } + + 1 + } + } + + r_c.set_route_cost_fn(Box::new(TestCostCalculator { + p_a_peer_id: p_a.my_peer_id(), + p_b_peer_id: p_b.my_peer_id(), + p_c_peer_id: p_c.my_peer_id(), + })) + .await; + + // after set cost, packet from p_c should forward to p_b first + wait_for_condition( + || async { + r_c.get_next_hop_with_policy(p_a.my_peer_id(), NextHopPolicy::LeastCost) + .await + == Some(p_b.my_peer_id()) + }, + Duration::from_secs(5), + ) + .await; + } } diff --git a/easytier/src/peers/route_trait.rs b/easytier/src/peers/route_trait.rs index 66364355..d232ab47 100644 --- a/easytier/src/peers/route_trait.rs +++ b/easytier/src/peers/route_trait.rs @@ -5,6 +5,18 @@ use tokio_util::bytes::Bytes; use crate::common::{error::Error, PeerId}; +#[derive(Clone, Debug)] +pub enum NextHopPolicy { + LeastHop, + LeastCost, +} + +impl Default for NextHopPolicy { + fn default() -> Self { + NextHopPolicy::LeastHop + } +} + #[async_trait] pub trait RouteInterface { async fn list_peers(&self) -> Vec; @@ -19,6 +31,28 @@ pub trait RouteInterface { pub type RouteInterfaceBox = Box; +#[auto_impl::auto_impl(Box, Arc, &)] +pub trait RouteCostCalculatorInterface: Send + Sync { + fn calculate_cost(&self, _src: PeerId, _dst: PeerId) -> i32 { + 1 + } + + fn need_update(&self) -> bool { + false + } + + fn dump(&self) -> String { + "All routes have cost 1".to_string() + } +} + +#[derive(Clone, Debug, Default)] +pub struct DefaultRouteCostCalculator; + +impl RouteCostCalculatorInterface for DefaultRouteCostCalculator {} + +pub type RouteCostCalculator = Box; + #[async_trait] #[auto_impl::auto_impl(Box, Arc)] pub trait Route { @@ -26,11 +60,21 @@ pub trait Route { async fn close(&self); async fn get_next_hop(&self, peer_id: PeerId) -> Option; + async fn get_next_hop_with_policy( + &self, + peer_id: PeerId, + _policy: NextHopPolicy, + ) -> Option { + self.get_next_hop(peer_id).await + } + async fn list_routes(&self) -> Vec; async fn get_peer_id_by_ipv4(&self, _ipv4: &Ipv4Addr) -> Option { None } + + async fn set_route_cost_fn(&self, _cost_fn: RouteCostCalculator) {} } pub type ArcRoute = Arc>;