diff --git a/Cargo.lock b/Cargo.lock index 3c6d7572..1c4413ae 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1112,18 +1112,6 @@ dependencies = [ "thiserror", ] -[[package]] -name = "deprecate-until" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7a3767f826efbbe5a5ae093920b58b43b01734202be697e1354914e862e8e704" -dependencies = [ - "proc-macro2", - "quote", - "semver", - "syn 2.0.48", -] - [[package]] name = "deranged" version = "0.3.11" @@ -1300,8 +1288,8 @@ dependencies = [ "network-interface", "nix 0.27.1", "once_cell", - "pathfinding", "percent-encoding", + "petgraph", "pin-project-lite", "pnet", "postcard", @@ -2342,15 +2330,6 @@ dependencies = [ "cfg-if", ] -[[package]] -name = "integer-sqrt" -version = "0.1.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "276ec31bcb4a9ee45f58bec6f9ec700ae4cf4f4f8f2fa7e06cb406bd5ffdd770" -dependencies = [ - "num-traits", -] - [[package]] name = "ioctl-sys" version = "0.8.0" @@ -3225,21 +3204,6 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8835116a5c179084a830efb3adc117ab007512b535bc1a21c991d3b32a6b44dd" -[[package]] -name = "pathfinding" -version = "4.9.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f0a21c30f03223ae4a4c892f077b3189133689b8a659a84372f8422384ce94c9" -dependencies = [ - "deprecate-until", - "fixedbitset", - "indexmap 2.2.6", - "integer-sqrt", - "num-traits", - "rustc-hash", - "thiserror", -] - [[package]] name = "pbkdf2" version = "0.11.0" @@ -3270,9 +3234,9 @@ checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" [[package]] name = "petgraph" -version = "0.6.4" +version = "0.6.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e1d3afd2628e69da2be385eb6f2fd57c8ac7977ceeff6dc166ff1657b0e386a9" +checksum = "b4c5cc86750666a3ed20bdaf5ca2a0344f9c67674cae0515bec2da16fbaa47db" dependencies = [ "fixedbitset", "indexmap 2.2.6", diff --git a/easytier/Cargo.toml b/easytier/Cargo.toml index 15875a71..09add658 100644 --- a/easytier/Cargo.toml +++ b/easytier/Cargo.toml @@ -137,7 +137,7 @@ async-recursion = "1.0.5" network-interface = "1.1.1" # for ospf route -pathfinding = "4.9.1" +petgraph = "0.6.5" # for encryption boringtun = { git = "https://github.com/EasyTier/boringtun.git", optional = true, rev = "449204c" } diff --git a/easytier/proto/cli.proto b/easytier/proto/cli.proto index bbc08e02..675093d8 100644 --- a/easytier/proto/cli.proto +++ b/easytier/proto/cli.proto @@ -117,16 +117,8 @@ service ConnectorManageRpc { rpc ManageConnector (ManageConnectorRequest) returns (ManageConnectorResponse); } -enum LatencyLevel { - VeryLow = 0; - Low = 1; - Normal = 2; - High = 3; - VeryHigh = 4; -} - message DirectConnectedPeerInfo { - LatencyLevel latency_level = 2; + int32 latency_ms = 1; } message PeerInfoForGlobalMap { diff --git a/easytier/src/common/config.rs b/easytier/src/common/config.rs index 9604e373..76c1dd73 100644 --- a/easytier/src/common/config.rs +++ b/easytier/src/common/config.rs @@ -150,6 +150,8 @@ pub struct Flags { pub enable_ipv6: bool, #[derivative(Default(value = "1420"))] pub mtu: u16, + #[derivative(Default(value = "true"))] + pub latency_first: bool, } #[derive(Debug, Clone, Deserialize, Serialize, PartialEq)] diff --git a/easytier/src/easytier-cli.rs b/easytier/src/easytier-cli.rs index 44dc2230..97541ba2 100644 --- a/easytier/src/easytier-cli.rs +++ b/easytier/src/easytier-cli.rs @@ -331,13 +331,7 @@ async fn main() -> Result<(), Error> { let direct_peers = v .direct_peers .iter() - .map(|(k, v)| { - format!( - "{}:{:?}", - k, - LatencyLevel::try_from(v.latency_level).unwrap() - ) - }) + .map(|(k, v)| format!("{}: {:?}ms", k, v.latency_ms,)) .collect::>(); table_rows.push(PeerCenterTableItem { node_id: node_id.to_string(), diff --git a/easytier/src/easytier-core.rs b/easytier/src/easytier-core.rs index 721b4768..ba7ec006 100644 --- a/easytier/src/easytier-core.rs +++ b/easytier/src/easytier-core.rs @@ -97,10 +97,6 @@ struct Cli { "wg://0.0.0.0:11011".to_string()])] listeners: Vec, - /// specify the linux network namespace, default is the root namespace - #[arg(long)] - net_ns: Option, - #[arg(long, help = "console log level", value_parser = clap::builder::PossibleValuesParser::new(["trace", "debug", "info", "warn", "error", "off"]))] console_log_level: Option, @@ -122,13 +118,6 @@ struct Cli { )] instance_name: String, - #[arg( - short = 'd', - long, - help = "instance uuid to identify this vpn node in whole vpn network example: 123e4567-e89b-12d3-a456-426614174000" - )] - instance_id: Option, - #[arg( long, help = "url that defines the vpn portal, allow other vpn clients to connect. @@ -163,6 +152,13 @@ and the vpn client is in network of 10.14.14.0/24" help = "mtu of the TUN device, default is 1420 for non-encryption, 1400 for encryption" )] mtu: Option, + + #[arg( + long, + help = "path to the log file, if not set, will print to stdout", + default_value = "false" + )] + latency_first: bool, } impl From for TomlConfigLoader { @@ -188,7 +184,6 @@ impl From for TomlConfigLoader { cli.network_secret.clone(), )); - cfg.set_netns(cli.net_ns.clone()); if let Some(ipv4) = &cli.ipv4 { cfg.set_ipv4( ipv4.parse() @@ -307,6 +302,7 @@ impl From for TomlConfigLoader { } f.enable_encryption = !cli.disable_encryption; f.enable_ipv6 = !cli.disable_ipv6; + f.latency_first = cli.latency_first; if let Some(mtu) = cli.mtu { f.mtu = mtu; } diff --git a/easytier/src/instance/instance.rs b/easytier/src/instance/instance.rs index 07ea9124..7ad1e918 100644 --- a/easytier/src/instance/instance.rs +++ b/easytier/src/instance/instance.rs @@ -1,6 +1,7 @@ use std::borrow::BorrowMut; use std::net::Ipv4Addr; use std::pin::Pin; +use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::{Arc, Weak}; use anyhow::Context; @@ -44,6 +45,8 @@ struct IpProxy { tcp_proxy: Arc, icmp_proxy: Arc, udp_proxy: Arc, + global_ctx: ArcGlobalCtx, + started: Arc, } impl IpProxy { @@ -57,10 +60,17 @@ impl IpProxy { tcp_proxy, icmp_proxy, udp_proxy, + global_ctx, + started: Arc::new(AtomicBool::new(false)), }) } async fn start(&self) -> Result<(), Error> { + if self.global_ctx.get_proxy_cidrs().is_empty() || self.started.load(Ordering::Relaxed) { + return Ok(()); + } + + self.started.store(true, Ordering::Relaxed); self.tcp_proxy.start().await?; self.icmp_proxy.start().await?; self.udp_proxy.start().await?; @@ -297,11 +307,16 @@ impl Instance { self.get_global_ctx(), self.get_peer_manager(), )?); - self.ip_proxy.as_ref().unwrap().start().await?; + self.run_ip_proxy().await?; self.udp_hole_puncher.lock().await.run().await?; self.peer_center.init().await; + let route_calc = self.peer_center.get_cost_calculator(); + self.peer_manager + .get_route() + .set_route_cost_fn(route_calc) + .await; self.add_initial_peers().await?; @@ -312,6 +327,14 @@ impl Instance { Ok(()) } + pub async fn run_ip_proxy(&mut self) -> Result<(), Error> { + if self.ip_proxy.is_none() { + return Err(anyhow::anyhow!("ip proxy not enabled.").into()); + } + self.ip_proxy.as_ref().unwrap().start().await?; + Ok(()) + } + pub async fn run_vpn_portal(&mut self) -> Result<(), Error> { if self.global_ctx.get_vpn_portal_cidr().is_none() { return Err(anyhow::anyhow!("vpn portal cidr not set.").into()); diff --git a/easytier/src/peer_center/instance.rs b/easytier/src/peer_center/instance.rs index 50ca5a6f..85f8208b 100644 --- a/easytier/src/peer_center/instance.rs +++ b/easytier/src/peer_center/instance.rs @@ -1,24 +1,23 @@ use std::{ - collections::hash_map::DefaultHasher, - hash::{Hash, Hasher}, - sync::{ - atomic::{AtomicBool, Ordering}, - Arc, - }, - time::{Duration, SystemTime}, + collections::BTreeSet, + sync::Arc, + time::{Duration, Instant, SystemTime}, }; use crossbeam::atomic::AtomicCell; use futures::Future; -use tokio::{ - sync::{Mutex, RwLock}, - task::JoinSet, -}; +use std::sync::RwLock; +use tokio::sync::Mutex; +use tokio::task::JoinSet; use tracing::Instrument; use crate::{ common::PeerId, - peers::{peer_manager::PeerManager, rpc_service::PeerManagerRpcService}, + peers::{ + peer_manager::PeerManager, + route_trait::{RouteCostCalculator, RouteCostCalculatorInterface}, + rpc_service::PeerManagerRpcService, + }, rpc::{GetGlobalPeerMapRequest, GetGlobalPeerMapResponse}, }; @@ -34,7 +33,8 @@ struct PeerCenterBase { lock: Arc>, } -static SERVICE_ID: u32 = 5; +// static SERVICE_ID: u32 = 5; for compatibility with the original code +static SERVICE_ID: u32 = 50; struct PeridicJobCtx { peer_mgr: Arc, @@ -132,7 +132,7 @@ impl PeerCenterBase { pub struct PeerCenterInstanceService { global_peer_map: Arc>, - global_peer_map_digest: Arc>, + global_peer_map_digest: Arc>, } #[tonic::async_trait] @@ -141,7 +141,7 @@ impl crate::rpc::cli::peer_center_rpc_server::PeerCenterRpc for PeerCenterInstan &self, _request: tonic::Request, ) -> Result, tonic::Status> { - let global_peer_map = self.global_peer_map.read().await.clone(); + let global_peer_map = self.global_peer_map.read().unwrap().clone(); Ok(tonic::Response::new(GetGlobalPeerMapResponse { global_peer_map: global_peer_map .map @@ -157,7 +157,8 @@ pub struct PeerCenterInstance { client: Arc, global_peer_map: Arc>, - global_peer_map_digest: Arc>, + global_peer_map_digest: Arc>, + global_peer_map_update_time: Arc>, } impl PeerCenterInstance { @@ -166,7 +167,8 @@ impl PeerCenterInstance { peer_mgr: peer_mgr.clone(), client: Arc::new(PeerCenterBase::new(peer_mgr.clone())), global_peer_map: Arc::new(RwLock::new(GlobalPeerMap::new())), - global_peer_map_digest: Arc::new(RwLock::new(Digest::default())), + global_peer_map_digest: Arc::new(AtomicCell::new(Digest::default())), + global_peer_map_update_time: Arc::new(AtomicCell::new(Instant::now())), } } @@ -179,12 +181,14 @@ impl PeerCenterInstance { async fn init_get_global_info_job(&self) { struct Ctx { global_peer_map: Arc>, - global_peer_map_digest: Arc>, + global_peer_map_digest: Arc>, + global_peer_map_update_time: Arc>, } let ctx = Arc::new(Ctx { global_peer_map: self.global_peer_map.clone(), global_peer_map_digest: self.global_peer_map_digest.clone(), + global_peer_map_update_time: self.global_peer_map_update_time.clone(), }); self.client @@ -193,10 +197,7 @@ impl PeerCenterInstance { rpc_ctx.deadline = SystemTime::now() + Duration::from_secs(3); let ret = client - .get_global_peer_map( - rpc_ctx, - ctx.job_ctx.global_peer_map_digest.read().await.clone(), - ) + .get_global_peer_map(rpc_ctx, ctx.job_ctx.global_peer_map_digest.load()) .await?; let Ok(resp) = ret else { @@ -217,10 +218,13 @@ impl PeerCenterInstance { resp.digest ); - *ctx.job_ctx.global_peer_map.write().await = resp.global_peer_map; - *ctx.job_ctx.global_peer_map_digest.write().await = resp.digest; + *ctx.job_ctx.global_peer_map.write().unwrap() = resp.global_peer_map; + ctx.job_ctx.global_peer_map_digest.store(resp.digest); + ctx.job_ctx + .global_peer_map_update_time + .store(Instant::now()); - Ok(10000) + Ok(5000) }) .await; } @@ -228,67 +232,53 @@ impl PeerCenterInstance { async fn init_report_peers_job(&self) { struct Ctx { service: PeerManagerRpcService, - need_send_peers: AtomicBool, - last_report_peers: Mutex, + + last_report_peers: Mutex>, + last_center_peer: AtomicCell, + last_report_time: AtomicCell, } let ctx = Arc::new(Ctx { service: PeerManagerRpcService::new(self.peer_mgr.clone()), - need_send_peers: AtomicBool::new(true), - last_report_peers: Mutex::new(PeerInfoForGlobalMap::default()), + last_report_peers: Mutex::new(BTreeSet::new()), last_center_peer: AtomicCell::new(PeerId::default()), + last_report_time: AtomicCell::new(Instant::now()), }); self.client .init_periodic_job(ctx, |client, ctx| async move { let my_node_id = ctx.peer_mgr.my_peer_id(); + let peers: PeerInfoForGlobalMap = ctx.job_ctx.service.list_peers().await.into(); + let peer_list = peers.direct_peers.keys().map(|k| *k).collect(); + let job_ctx = &ctx.job_ctx; - // if peers are not same in next 10 seconds, report peers to center server - let mut peers = PeerInfoForGlobalMap::default(); - for _ in 1..10 { - peers = ctx.job_ctx.service.list_peers().await.into(); - if ctx.center_peer.load() != ctx.job_ctx.last_center_peer.load() { - // if center peer changed, report peers immediately - break; - } - if peers == *ctx.job_ctx.last_report_peers.lock().await { - return Ok(3000); - } - tokio::time::sleep(Duration::from_secs(2)).await; + // only report when: + // 1. center peer changed + // 2. last report time is more than 60 seconds + // 3. peers changed + if ctx.center_peer.load() == ctx.job_ctx.last_center_peer.load() + && job_ctx.last_report_time.load().elapsed().as_secs() < 60 + && *job_ctx.last_report_peers.lock().await == peer_list + { + return Ok(5000); } - *ctx.job_ctx.last_report_peers.lock().await = peers.clone(); - let mut hasher = DefaultHasher::new(); - peers.hash(&mut hasher); - - let peers = if ctx.job_ctx.need_send_peers.load(Ordering::Relaxed) { - Some(peers) - } else { - None - }; let mut rpc_ctx = tarpc::context::current(); rpc_ctx.deadline = SystemTime::now() + Duration::from_secs(3); let ret = client - .report_peers( - rpc_ctx, - my_node_id.clone(), - peers, - hasher.finish() as Digest, - ) + .report_peers(rpc_ctx, my_node_id.clone(), peers) .await?; - if matches!(ret.as_ref().err(), Some(Error::DigestMismatch)) { - ctx.job_ctx.need_send_peers.store(true, Ordering::Relaxed); - return Ok(0); - } else if ret.is_err() { + if ret.is_ok() { + ctx.job_ctx.last_center_peer.store(ctx.center_peer.load()); + *ctx.job_ctx.last_report_peers.lock().await = peer_list; + ctx.job_ctx.last_report_time.store(Instant::now()); + } else { tracing::error!("report peers to center server got error result: {:?}", ret); - return Ok(500); } - ctx.job_ctx.last_center_peer.store(ctx.center_peer.load()); - ctx.job_ctx.need_send_peers.store(false, Ordering::Relaxed); - Ok(3000) + Ok(5000) }) .await; } @@ -299,15 +289,61 @@ impl PeerCenterInstance { global_peer_map_digest: self.global_peer_map_digest.clone(), } } + + pub fn get_cost_calculator(&self) -> RouteCostCalculator { + struct RouteCostCalculatorImpl { + global_peer_map: Arc>, + + global_peer_map_clone: GlobalPeerMap, + + last_update_time: AtomicCell, + global_peer_map_update_time: Arc>, + } + + impl RouteCostCalculatorInterface for RouteCostCalculatorImpl { + fn calculate_cost(&self, src: PeerId, dst: PeerId) -> i32 { + let ret = self + .global_peer_map_clone + .map + .get(&src) + .and_then(|src_peer_info| src_peer_info.direct_peers.get(&dst)) + .and_then(|info| Some(info.latency_ms)); + ret.unwrap_or(80) + } + + fn begin_update(&mut self) { + let global_peer_map = self.global_peer_map.read().unwrap(); + self.global_peer_map_clone = global_peer_map.clone(); + } + + fn end_update(&mut self) { + self.last_update_time + .store(self.global_peer_map_update_time.load()); + } + + fn need_update(&self) -> bool { + self.last_update_time.load() < self.global_peer_map_update_time.load() + } + } + + Box::new(RouteCostCalculatorImpl { + global_peer_map: self.global_peer_map.clone(), + global_peer_map_clone: GlobalPeerMap::new(), + last_update_time: AtomicCell::new( + self.global_peer_map_update_time.load() - Duration::from_secs(1), + ), + global_peer_map_update_time: self.global_peer_map_update_time.clone(), + }) + } } #[cfg(test)] mod tests { - use std::ops::Deref; - use crate::{ peer_center::server::get_global_data, - peers::tests::{connect_peer_manager, create_mock_peer_manager, wait_route_appear}, + peers::tests::{ + connect_peer_manager, create_mock_peer_manager, wait_for_condition, wait_route_appear, + }, }; use super::*; @@ -340,43 +376,64 @@ mod tests { let center_data = get_global_data(center_peer); // wait center_data has 3 records for 10 seconds - let now = std::time::Instant::now(); - loop { - if center_data.read().await.global_peer_map.map.len() == 3 { - println!( - "center data ready, {:#?}", - center_data.read().await.global_peer_map - ); - break; - } - if now.elapsed().as_secs() > 60 { - panic!("center data not ready"); - } - tokio::time::sleep(Duration::from_millis(100)).await; - } + wait_for_condition( + || async { + if center_data.global_peer_map.len() == 4 { + println!("center data {:#?}", center_data.global_peer_map); + true + } else { + false + } + }, + Duration::from_secs(10), + ) + .await; let mut digest = None; for pc in peer_centers.iter() { let rpc_service = pc.get_rpc_service(); - let now = std::time::Instant::now(); - while now.elapsed().as_secs() < 10 { - if rpc_service.global_peer_map.read().await.map.len() == 3 { - break; - } - tokio::time::sleep(Duration::from_millis(100)).await; - } - assert_eq!(rpc_service.global_peer_map.read().await.map.len(), 3); + wait_for_condition( + || async { rpc_service.global_peer_map.read().unwrap().map.len() == 3 }, + Duration::from_secs(10), + ) + .await; + println!("rpc service ready, {:#?}", rpc_service.global_peer_map); if digest.is_none() { - digest = Some(rpc_service.global_peer_map_digest.read().await.clone()); + digest = Some(rpc_service.global_peer_map_digest.load()); } else { - let v = rpc_service.global_peer_map_digest.read().await; - assert_eq!(digest.as_ref().unwrap(), v.deref()); + let v = rpc_service.global_peer_map_digest.load(); + assert_eq!(digest.unwrap(), v); } + + let mut route_cost = pc.get_cost_calculator(); + assert!(route_cost.need_update()); + + route_cost.begin_update(); + assert!( + route_cost.calculate_cost(peer_mgr_a.my_peer_id(), peer_mgr_b.my_peer_id()) < 30 + ); + assert!( + route_cost.calculate_cost(peer_mgr_b.my_peer_id(), peer_mgr_a.my_peer_id()) < 30 + ); + assert!( + route_cost.calculate_cost(peer_mgr_b.my_peer_id(), peer_mgr_c.my_peer_id()) < 30 + ); + assert!( + route_cost.calculate_cost(peer_mgr_c.my_peer_id(), peer_mgr_b.my_peer_id()) < 30 + ); + assert!( + route_cost.calculate_cost(peer_mgr_c.my_peer_id(), peer_mgr_a.my_peer_id()) > 50 + ); + assert!( + route_cost.calculate_cost(peer_mgr_a.my_peer_id(), peer_mgr_c.my_peer_id()) > 50 + ); + route_cost.end_update(); + assert!(!route_cost.need_update()); } - let global_digest = get_global_data(center_peer).read().await.digest.clone(); + let global_digest = get_global_data(center_peer).digest.load(); assert_eq!(digest.as_ref().unwrap(), &global_digest); } } diff --git a/easytier/src/peer_center/server.rs b/easytier/src/peer_center/server.rs index 89a929f8..efa6029e 100644 --- a/easytier/src/peer_center/server.rs +++ b/easytier/src/peer_center/server.rs @@ -1,45 +1,48 @@ use std::{ + collections::BinaryHeap, hash::{Hash, Hasher}, sync::Arc, }; +use crossbeam::atomic::AtomicCell; use dashmap::DashMap; use once_cell::sync::Lazy; -use tokio::{sync::RwLock, task::JoinSet}; +use tokio::{task::JoinSet}; -use crate::common::PeerId; +use crate::{common::PeerId, rpc::DirectConnectedPeerInfo}; use super::{ service::{GetGlobalPeerMapResponse, GlobalPeerMap, PeerCenterService, PeerInfoForGlobalMap}, Digest, Error, }; -pub(crate) struct PeerCenterServerGlobalData { - pub global_peer_map: GlobalPeerMap, - pub digest: Digest, - pub update_time: std::time::Instant, - pub peer_update_time: DashMap, +#[derive(Debug, Clone, PartialEq, PartialOrd, Ord, Eq, Hash)] +pub(crate) struct SrcDstPeerPair { + src: PeerId, + dst: PeerId, } -impl PeerCenterServerGlobalData { - fn new() -> Self { - PeerCenterServerGlobalData { - global_peer_map: GlobalPeerMap::new(), - digest: Digest::default(), - update_time: std::time::Instant::now(), - peer_update_time: DashMap::new(), - } - } +#[derive(Debug, Clone)] +pub(crate) struct PeerCenterInfoEntry { + info: DirectConnectedPeerInfo, + update_time: std::time::Instant, +} + +#[derive(Default)] +pub(crate) struct PeerCenterServerGlobalData { + pub(crate) global_peer_map: DashMap, + pub(crate) peer_report_time: DashMap, + pub(crate) digest: AtomicCell, } // a global unique instance for PeerCenterServer -pub(crate) static GLOBAL_DATA: Lazy>>> = +pub(crate) static GLOBAL_DATA: Lazy>> = Lazy::new(DashMap::new); -pub(crate) fn get_global_data(node_id: PeerId) -> Arc> { +pub(crate) fn get_global_data(node_id: PeerId) -> Arc { GLOBAL_DATA .entry(node_id) - .or_insert_with(|| Arc::new(RwLock::new(PeerCenterServerGlobalData::new()))) + .or_insert_with(|| Arc::new(PeerCenterServerGlobalData::default())) .value() .clone() } @@ -48,8 +51,6 @@ pub(crate) fn get_global_data(node_id: PeerId) -> Arc, - tasks: Arc>, } @@ -65,26 +66,32 @@ impl PeerCenterServer { PeerCenterServer { my_node_id, - digest_map: DashMap::new(), - tasks: Arc::new(tasks), } } async fn clean_outdated_peer(my_node_id: PeerId) { let data = get_global_data(my_node_id); - let mut locked_data = data.write().await; - let now = std::time::Instant::now(); - let mut to_remove = Vec::new(); - for kv in locked_data.peer_update_time.iter() { - if now.duration_since(*kv.value()).as_secs() > 20 { - to_remove.push(*kv.key()); - } - } - for peer_id in to_remove { - locked_data.global_peer_map.map.remove(&peer_id); - locked_data.peer_update_time.remove(&peer_id); - } + data.peer_report_time.retain(|_, v| { + std::time::Instant::now().duration_since(*v) < std::time::Duration::from_secs(180) + }); + data.global_peer_map.retain(|_, v| { + std::time::Instant::now().duration_since(v.update_time) + < std::time::Duration::from_secs(180) + }); + } + + fn calc_global_digest(my_node_id: PeerId) -> Digest { + let data = get_global_data(my_node_id); + let mut hasher = std::collections::hash_map::DefaultHasher::new(); + data.global_peer_map + .iter() + .map(|v| v.key().clone()) + .collect::>() + .into_sorted_vec() + .into_iter() + .for_each(|v| v.hash(&mut hasher)); + hasher.finish() } } @@ -95,39 +102,28 @@ impl PeerCenterService for PeerCenterServer { self, _: tarpc::context::Context, my_peer_id: PeerId, - peers: Option, - digest: Digest, + peers: PeerInfoForGlobalMap, ) -> Result<(), Error> { - tracing::trace!("receive report_peers"); + tracing::debug!("receive report_peers"); let data = get_global_data(self.my_node_id); - let mut locked_data = data.write().await; - locked_data - .peer_update_time + data.peer_report_time .insert(my_peer_id, std::time::Instant::now()); - let old_digest = self.digest_map.get(&my_peer_id); - // if digest match, no need to update - if let Some(old_digest) = old_digest { - if *old_digest == digest { - return Ok(()); - } + for (peer_id, peer_info) in peers.direct_peers { + let pair = SrcDstPeerPair { + src: my_peer_id, + dst: peer_id, + }; + let entry = PeerCenterInfoEntry { + info: peer_info, + update_time: std::time::Instant::now(), + }; + data.global_peer_map.insert(pair, entry); } - if peers.is_none() { - return Err(Error::DigestMismatch); - } - - self.digest_map.insert(my_peer_id, digest); - locked_data - .global_peer_map - .map - .insert(my_peer_id, peers.unwrap()); - - let mut hasher = std::collections::hash_map::DefaultHasher::new(); - locked_data.global_peer_map.map.hash(&mut hasher); - locked_data.digest = hasher.finish() as Digest; - locked_data.update_time = std::time::Instant::now(); + data.digest + .store(PeerCenterServer::calc_global_digest(self.my_node_id)); Ok(()) } @@ -138,15 +134,26 @@ impl PeerCenterService for PeerCenterServer { digest: Digest, ) -> Result, Error> { let data = get_global_data(self.my_node_id); - if digest == data.read().await.digest { + if digest == data.digest.load() && digest != 0 { return Ok(None); } - let data = get_global_data(self.my_node_id); - let locked_data = data.read().await; + let mut global_peer_map = GlobalPeerMap::new(); + for item in data.global_peer_map.iter() { + let (pair, entry) = item.pair(); + global_peer_map + .map + .entry(pair.src) + .or_insert_with(|| PeerInfoForGlobalMap { + direct_peers: Default::default(), + }) + .direct_peers + .insert(pair.dst, entry.info.clone()); + } + Ok(Some(GetGlobalPeerMapResponse { - global_peer_map: locked_data.global_peer_map.clone(), - digest: locked_data.digest, + global_peer_map, + digest: data.digest.load(), })) } } diff --git a/easytier/src/peer_center/service.rs b/easytier/src/peer_center/service.rs index 647d3aa7..e6d4d04e 100644 --- a/easytier/src/peer_center/service.rs +++ b/easytier/src/peer_center/service.rs @@ -5,39 +5,23 @@ use crate::{common::PeerId, rpc::DirectConnectedPeerInfo}; use super::{Digest, Error}; use crate::rpc::PeerInfo; -pub type LatencyLevel = crate::rpc::cli::LatencyLevel; - -impl LatencyLevel { - pub const fn from_latency_ms(lat_ms: u32) -> Self { - if lat_ms < 10 { - LatencyLevel::VeryLow - } else if lat_ms < 50 { - LatencyLevel::Low - } else if lat_ms < 100 { - LatencyLevel::Normal - } else if lat_ms < 200 { - LatencyLevel::High - } else { - LatencyLevel::VeryHigh - } - } -} - pub type PeerInfoForGlobalMap = crate::rpc::cli::PeerInfoForGlobalMap; impl From> for PeerInfoForGlobalMap { fn from(peers: Vec) -> Self { let mut peer_map = BTreeMap::new(); for peer in peers { - let min_lat = peer + let Some(min_lat) = peer .conns .iter() .map(|conn| conn.stats.as_ref().unwrap().latency_us) .min() - .unwrap_or(0); + else { + continue; + }; let dp_info = DirectConnectedPeerInfo { - latency_level: LatencyLevel::from_latency_ms(min_lat as u32 / 1000) as i32, + latency_ms: std::cmp::max(1, (min_lat as u32 / 1000) as i32), }; // sort conn info so hash result is stable @@ -73,11 +57,7 @@ pub struct GetGlobalPeerMapResponse { pub trait PeerCenterService { // report center server which peer is directly connected to me // digest is a hash of current peer map, if digest not match, we need to transfer the whole map - async fn report_peers( - my_peer_id: PeerId, - peers: Option, - digest: Digest, - ) -> Result<(), Error>; + async fn report_peers(my_peer_id: PeerId, peers: PeerInfoForGlobalMap) -> Result<(), Error>; async fn get_global_peer_map(digest: Digest) -> Result, Error>; diff --git a/easytier/src/peers/foreign_network_manager.rs b/easytier/src/peers/foreign_network_manager.rs index f0b179b6..35ef2f84 100644 --- a/easytier/src/peers/foreign_network_manager.rs +++ b/easytier/src/peers/foreign_network_manager.rs @@ -29,6 +29,7 @@ use super::{ peer_conn::PeerConn, peer_map::PeerMap, peer_rpc::{PeerRpcManager, PeerRpcManagerTransport}, + route_trait::NextHopPolicy, PacketRecvChan, PacketRecvChanReceiver, }; @@ -66,7 +67,10 @@ impl ForeignNetworkManagerData { .get(&network_name) .ok_or_else(|| Error::RouteError(Some("no peer in network".to_string())))? .clone(); - entry.peer_map.send_msg(msg, dst_peer_id).await + entry + .peer_map + .send_msg(msg, dst_peer_id, NextHopPolicy::LeastHop) + .await } fn get_peer_network(&self, peer_id: PeerId) -> Option { @@ -275,7 +279,10 @@ impl ForeignNetworkManager { } if let Some(entry) = data.get_network_entry(&from_network) { - let ret = entry.peer_map.send_msg(packet_bytes, to_peer_id).await; + let ret = entry + .peer_map + .send_msg(packet_bytes, to_peer_id, NextHopPolicy::LeastHop) + .await; if ret.is_err() { tracing::error!("forward packet to peer failed: {:?}", ret.err()); } diff --git a/easytier/src/peers/peer_manager.rs b/easytier/src/peers/peer_manager.rs index b6fec0f4..e040d250 100644 --- a/easytier/src/peers/peer_manager.rs +++ b/easytier/src/peers/peer_manager.rs @@ -22,7 +22,9 @@ use tokio_util::bytes::Bytes; use crate::{ common::{error::Error, global_ctx::ArcGlobalCtx, PeerId}, peers::{ - peer_conn::PeerConn, peer_rpc::PeerRpcManagerTransport, route_trait::RouteInterface, + peer_conn::PeerConn, + peer_rpc::PeerRpcManagerTransport, + route_trait::{NextHopPolicy, RouteInterface}, PeerPacketFilter, }, tunnel::{ @@ -73,7 +75,10 @@ impl PeerRpcManagerTransport for RpcTransport { .ok_or(Error::Unknown)?; let peers = self.peers.upgrade().ok_or(Error::Unknown)?; - if let Some(gateway_id) = peers.get_gateway_peer_id(dst_peer_id).await { + if let Some(gateway_id) = peers + .get_gateway_peer_id(dst_peer_id, NextHopPolicy::LeastHop) + .await + { tracing::trace!( ?dst_peer_id, ?gateway_id, @@ -320,20 +325,33 @@ impl PeerManager { let my_peer_id = self.my_peer_id; let peers = self.peers.clone(); let pipe_line = self.peer_packet_process_pipeline.clone(); + let foreign_client = self.foreign_network_client.clone(); let encryptor = self.encryptor.clone(); self.tasks.lock().await.spawn(async move { log::trace!("start_peer_recv"); while let Some(mut ret) = recv.next().await { - let Some(hdr) = ret.peer_manager_header() else { + let Some(hdr) = ret.mut_peer_manager_header() else { tracing::warn!(?ret, "invalid packet, skip"); continue; }; - tracing::trace!(?hdr, ?ret, "peer recv a packet..."); + tracing::trace!(?hdr, "peer recv a packet..."); let from_peer_id = hdr.from_peer_id.get(); let to_peer_id = hdr.to_peer_id.get(); if to_peer_id != my_peer_id { + if hdr.forward_counter > 7 { + tracing::warn!(?hdr, "forward counter exceed, drop packet"); + continue; + } + + if hdr.forward_counter > 2 && hdr.is_latency_first() { + tracing::trace!(?hdr, "set_latency_first false because too many hop"); + hdr.set_latency_first(false); + } + + hdr.forward_counter += 1; tracing::trace!(?to_peer_id, ?my_peer_id, "need forward"); - let ret = peers.send_msg(ret, to_peer_id).await; + let ret = + Self::send_msg_internal(&peers, &foreign_client, ret, to_peer_id).await; if ret.is_err() { tracing::error!(?ret, ?to_peer_id, ?from_peer_id, "forward packet error"); } @@ -518,11 +536,31 @@ impl PeerManager { } } + fn get_next_hop_policy(is_first_latency: bool) -> NextHopPolicy { + if is_first_latency { + NextHopPolicy::LeastCost + } else { + NextHopPolicy::LeastHop + } + } + pub async fn send_msg(&self, msg: ZCPacket, dst_peer_id: PeerId) -> Result<(), Error> { - if let Some(gateway) = self.peers.get_gateway_peer_id(dst_peer_id).await { - self.peers.send_msg_directly(msg, gateway).await - } else if self.foreign_network_client.has_next_hop(dst_peer_id) { - self.foreign_network_client.send_msg(msg, dst_peer_id).await + Self::send_msg_internal(&self.peers, &self.foreign_network_client, msg, dst_peer_id).await + } + + async fn send_msg_internal( + peers: &Arc, + foreign_network_client: &Arc, + msg: ZCPacket, + dst_peer_id: PeerId, + ) -> Result<(), Error> { + let policy = + Self::get_next_hop_policy(msg.peer_manager_header().unwrap().is_latency_first()); + + if let Some(gateway) = peers.get_gateway_peer_id(dst_peer_id, policy).await { + peers.send_msg_directly(msg, gateway).await + } else if foreign_network_client.has_next_hop(dst_peer_id) { + foreign_network_client.send_msg(msg, dst_peer_id).await } else { Err(Error::RouteError(None)) } @@ -564,6 +602,12 @@ impl PeerManager { .encrypt(&mut msg) .with_context(|| "encrypt failed")?; + let is_latency_first = self.global_ctx.get_flags().latency_first; + msg.mut_peer_manager_header() + .unwrap() + .set_latency_first(is_latency_first); + let next_hop_policy = Self::get_next_hop_policy(is_latency_first); + let mut errs: Vec = vec![]; let mut msg = Some(msg); @@ -581,7 +625,11 @@ impl PeerManager { .to_peer_id .set(*peer_id); - if let Some(gateway) = self.peers.get_gateway_peer_id(*peer_id).await { + if let Some(gateway) = self + .peers + .get_gateway_peer_id(*peer_id, next_hop_policy.clone()) + .await + { if let Err(e) = self.peers.send_msg_directly(msg, gateway).await { errs.push(e); } diff --git a/easytier/src/peers/peer_map.rs b/easytier/src/peers/peer_map.rs index 821c9035..286b9d10 100644 --- a/easytier/src/peers/peer_map.rs +++ b/easytier/src/peers/peer_map.rs @@ -18,7 +18,7 @@ use crate::{ use super::{ peer::Peer, peer_conn::{PeerConn, PeerConnId}, - route_trait::ArcRoute, + route_trait::{ArcRoute, NextHopPolicy}, PacketRecvChan, }; @@ -94,18 +94,25 @@ impl PeerMap { Ok(()) } - pub async fn get_gateway_peer_id(&self, dst_peer_id: PeerId) -> Option { + pub async fn get_gateway_peer_id( + &self, + dst_peer_id: PeerId, + policy: NextHopPolicy, + ) -> Option { if dst_peer_id == self.my_peer_id { return Some(dst_peer_id); } - if self.has_peer(dst_peer_id) { + if self.has_peer(dst_peer_id) && matches!(policy, NextHopPolicy::LeastHop) { return Some(dst_peer_id); } // get route info for route in self.routes.read().await.iter() { - if let Some(gateway_peer_id) = route.get_next_hop(dst_peer_id).await { + if let Some(gateway_peer_id) = route + .get_next_hop_with_policy(dst_peer_id, policy.clone()) + .await + { // for foreign network, gateway_peer_id may not connect to me if self.has_peer(gateway_peer_id) { return Some(gateway_peer_id); @@ -116,8 +123,13 @@ impl PeerMap { None } - pub async fn send_msg(&self, msg: ZCPacket, dst_peer_id: PeerId) -> Result<(), Error> { - let Some(gateway_peer_id) = self.get_gateway_peer_id(dst_peer_id).await else { + pub async fn send_msg( + &self, + msg: ZCPacket, + dst_peer_id: PeerId, + policy: NextHopPolicy, + ) -> Result<(), Error> { + let Some(gateway_peer_id) = self.get_gateway_peer_id(dst_peer_id, policy).await else { return Err(Error::RouteError(Some(format!( "peer map sengmsg no gateway for dst_peer_id: {}", dst_peer_id diff --git a/easytier/src/peers/peer_ospf_route.rs b/easytier/src/peers/peer_ospf_route.rs index 59bdbf43..758bcb78 100644 --- a/easytier/src/peers/peer_ospf_route.rs +++ b/easytier/src/peers/peer_ospf_route.rs @@ -10,6 +10,11 @@ use std::{ }; use dashmap::DashMap; +use petgraph::{ + algo::{all_simple_paths, astar, dijkstra}, + graph::NodeIndex, + Directed, Graph, +}; use serde::{Deserialize, Serialize}; use tokio::{select, sync::Mutex, task::JoinSet}; @@ -19,7 +24,14 @@ use crate::{ rpc::{NatType, StunInfo}, }; -use super::{peer_rpc::PeerRpcManager, PeerPacketFilter}; +use super::{ + peer_rpc::PeerRpcManager, + route_trait::{ + DefaultRouteCostCalculator, NextHopPolicy, RouteCostCalculator, + RouteCostCalculatorInterface, + }, + PeerPacketFilter, +}; static SERVICE_ID: u32 = 7; static UPDATE_PEER_INFO_PERIOD: Duration = Duration::from_secs(3600); @@ -360,11 +372,15 @@ impl SyncedRouteInfo { } } +type PeerGraph = Graph; +type PeerIdToNodexIdxMap = DashMap; +type NextHopMap = DashMap; + // computed with SyncedRouteInfo. used to get next hop. #[derive(Debug)] struct RouteTable { peer_infos: DashMap, - next_hop_map: DashMap, + next_hop_map: NextHopMap, ipv4_peer_id_map: DashMap, cidr_peer_id_map: DashMap, } @@ -393,7 +409,121 @@ impl RouteTable { .map(|x| NatType::try_from(x.udp_stun_info as i32).unwrap()) } - fn build_from_synced_info(&self, my_peer_id: PeerId, synced_info: &SyncedRouteInfo) { + fn build_peer_graph_from_synced_info( + peers: Vec, + synced_info: &SyncedRouteInfo, + cost_calc: &mut T, + ) -> (PeerGraph, PeerIdToNodexIdxMap) { + let mut graph: PeerGraph = Graph::new(); + let peer_id_to_node_index = PeerIdToNodexIdxMap::new(); + for peer_id in peers.iter() { + peer_id_to_node_index.insert(*peer_id, graph.add_node(*peer_id)); + } + + for peer_id in peers.iter() { + let connected_peers = synced_info + .get_connected_peers(*peer_id) + .unwrap_or(BTreeSet::new()); + for dst_peer_id in connected_peers.iter() { + let Some(dst_idx) = peer_id_to_node_index.get(dst_peer_id) else { + continue; + }; + + graph.add_edge( + *peer_id_to_node_index.get(&peer_id).unwrap(), + *dst_idx, + cost_calc.calculate_cost(*peer_id, *dst_peer_id), + ); + } + } + + (graph, peer_id_to_node_index) + } + + fn gen_next_hop_map_with_least_hop( + my_peer_id: PeerId, + graph: &PeerGraph, + idx_map: &PeerIdToNodexIdxMap, + cost_calc: &mut T, + ) -> NextHopMap { + let res = dijkstra(&graph, *idx_map.get(&my_peer_id).unwrap(), None, |_| 1); + let next_hop_map = NextHopMap::new(); + for (node_idx, cost) in res.iter() { + if *cost == 0 { + continue; + } + let all_paths = all_simple_paths::, _>( + graph, + *idx_map.get(&my_peer_id).unwrap(), + *node_idx, + *cost - 1, + Some(*cost - 1), + ) + .collect::>(); + + assert!(!all_paths.is_empty()); + + // find a path with least cost. + let mut min_cost = i32::MAX; + let mut min_path = Vec::new(); + for path in all_paths.iter() { + let mut cost = 0; + for i in 0..path.len() - 1 { + let src_peer_id = *graph.node_weight(path[i]).unwrap(); + let dst_peer_id = *graph.node_weight(path[i + 1]).unwrap(); + cost += cost_calc.calculate_cost(src_peer_id, dst_peer_id); + } + + if cost <= min_cost { + min_cost = cost; + min_path = path.clone(); + } + } + next_hop_map.insert( + *graph.node_weight(*node_idx).unwrap(), + (*graph.node_weight(min_path[1]).unwrap(), *cost as i32), + ); + } + + next_hop_map + } + + fn gen_next_hop_map_with_least_cost( + my_peer_id: PeerId, + graph: &PeerGraph, + idx_map: &PeerIdToNodexIdxMap, + ) -> NextHopMap { + let next_hop_map = NextHopMap::new(); + for item in idx_map.iter() { + if *item.key() == my_peer_id { + continue; + } + + let dst_peer_node_idx = *item.value(); + + let Some((cost, path)) = astar::astar( + graph, + *idx_map.get(&my_peer_id).unwrap(), + |node_idx| node_idx == dst_peer_node_idx, + |e| *e.weight(), + |_| 0, + ) else { + continue; + }; + + next_hop_map.insert(*item.key(), (*graph.node_weight(path[1]).unwrap(), cost)); + } + + next_hop_map + } + + fn build_from_synced_info( + &self, + my_peer_id: PeerId, + synced_info: &SyncedRouteInfo, + policy: NextHopPolicy, + mut cost_calc: T, + ) { // build peer_infos self.peer_infos.clear(); for item in synced_info.peer_infos.iter() { @@ -410,28 +540,20 @@ impl RouteTable { // build next hop map self.next_hop_map.clear(); self.next_hop_map.insert(my_peer_id, (my_peer_id, 0)); - for item in self.peer_infos.iter() { - let peer_id = *item.key(); - if peer_id == my_peer_id { - continue; - } - let Some(path) = pathfinding::prelude::bfs( - &my_peer_id, - |p| { - synced_info - .get_connected_peers(*p) - .unwrap_or_else(|| BTreeSet::new()) - }, - |x| *x == peer_id, - ) else { - continue; - }; - if !path.is_empty() { - assert!(path.len() >= 2); - self.next_hop_map - .insert(peer_id, (path[1], (path.len() - 1) as i32)); - } + let (graph, idx_map) = Self::build_peer_graph_from_synced_info( + self.peer_infos.iter().map(|x| *x.key()).collect(), + &synced_info, + &mut cost_calc, + ); + let next_hop_map = if matches!(policy, NextHopPolicy::LeastHop) { + Self::gen_next_hop_map_with_least_hop(my_peer_id, &graph, &idx_map, &mut cost_calc) + } else { + Self::gen_next_hop_map_with_least_cost(my_peer_id, &graph, &idx_map) + }; + for item in next_hop_map.iter() { + self.next_hop_map.insert(*item.key(), *item.value()); } + // build graph // build ipv4_peer_id_map, cidr_peer_id_map self.ipv4_peer_id_map.clear(); @@ -563,7 +685,9 @@ struct PeerRouteServiceImpl { interface: Arc>>, + cost_calculator: Arc>>, route_table: RouteTable, + route_table_with_cost: RouteTable, synced_route_info: Arc, cached_local_conn_map: std::sync::Mutex, } @@ -585,9 +709,17 @@ impl PeerRouteServiceImpl { PeerRouteServiceImpl { my_peer_id, global_ctx, - interface: Arc::new(Mutex::new(None)), sessions: DashMap::new(), + + interface: Arc::new(Mutex::new(None)), + + cost_calculator: Arc::new(std::sync::Mutex::new(Some(Box::new( + DefaultRouteCostCalculator, + )))), + route_table: RouteTable::new(), + route_table_with_cost: RouteTable::new(), + synced_route_info: Arc::new(SyncedRouteInfo { peer_infos: DashMap::new(), conn_map: DashMap::new(), @@ -649,8 +781,32 @@ impl PeerRouteServiceImpl { } fn update_route_table(&self) { - self.route_table - .build_from_synced_info(self.my_peer_id, &self.synced_route_info); + let mut calc_locked = self.cost_calculator.lock().unwrap(); + + calc_locked.as_mut().unwrap().begin_update(); + self.route_table.build_from_synced_info( + self.my_peer_id, + &self.synced_route_info, + NextHopPolicy::LeastHop, + calc_locked.as_mut().unwrap(), + ); + + self.route_table_with_cost.build_from_synced_info( + self.my_peer_id, + &self.synced_route_info, + NextHopPolicy::LeastCost, + calc_locked.as_mut().unwrap(), + ); + calc_locked.as_mut().unwrap().end_update(); + } + + fn cost_calculator_need_update(&self) -> bool { + self.cost_calculator + .lock() + .unwrap() + .as_ref() + .map(|x| x.need_update()) + .unwrap_or(false) } fn update_route_table_and_cached_local_conn_bitmap(&self) { @@ -1173,6 +1329,7 @@ impl PeerRoute { session_mgr.maintain_sessions(service_impl).await; } + #[tracing::instrument(skip(session_mgr))] async fn update_my_peer_info_routine( service_impl: Arc, session_mgr: RouteSessionManager, @@ -1183,6 +1340,11 @@ impl PeerRoute { session_mgr.sync_now("update_my_infos"); } + if service_impl.cost_calculator_need_update() { + tracing::debug!("cost_calculator_need_update"); + service_impl.update_route_table(); + } + select! { ev = global_event_receiver.recv() => { tracing::info!(?ev, "global event received in update_my_peer_info_routine"); @@ -1234,6 +1396,19 @@ impl Route for PeerRoute { route_table.get_next_hop(dst_peer_id).map(|x| x.0) } + async fn get_next_hop_with_policy( + &self, + dst_peer_id: PeerId, + policy: NextHopPolicy, + ) -> Option { + let route_table = if matches!(policy, NextHopPolicy::LeastCost) { + &self.service_impl.route_table_with_cost + } else { + &self.service_impl.route_table + }; + route_table.get_next_hop(dst_peer_id).map(|x| x.0) + } + async fn list_routes(&self) -> Vec { let route_table = &self.service_impl.route_table; let mut routes = Vec::new(); @@ -1265,6 +1440,11 @@ impl Route for PeerRoute { tracing::info!(?ipv4_addr, "no peer id for ipv4"); None } + + async fn set_route_cost_fn(&self, _cost_fn: RouteCostCalculator) { + *self.service_impl.cost_calculator.lock().unwrap() = Some(_cost_fn); + self.service_impl.update_route_table(); + } } impl PeerPacketFilter for Arc {} @@ -1282,7 +1462,7 @@ mod tests { connector::udp_hole_punch::tests::replace_stun_info_collector, peers::{ peer_manager::{PeerManager, RouteAlgoType}, - route_trait::Route, + route_trait::{NextHopPolicy, Route, RouteCostCalculatorInterface}, tests::{connect_peer_manager, wait_for_condition}, }, rpc::NatType, @@ -1609,4 +1789,91 @@ mod tests { println!("session: {:?}", r_a.session_mgr.dump_sessions()); check_rpc_counter(&r_a, p_b.my_peer_id(), 2, 2); } + + #[tokio::test] + async fn test_cost_calculator() { + let p_a = create_mock_pmgr().await; + let p_b = create_mock_pmgr().await; + let p_c = create_mock_pmgr().await; + let p_d = create_mock_pmgr().await; + connect_peer_manager(p_a.clone(), p_b.clone()).await; + connect_peer_manager(p_a.clone(), p_c.clone()).await; + connect_peer_manager(p_d.clone(), p_b.clone()).await; + connect_peer_manager(p_d.clone(), p_c.clone()).await; + connect_peer_manager(p_b.clone(), p_c.clone()).await; + + let _r_a = create_mock_route(p_a.clone()).await; + let _r_b = create_mock_route(p_b.clone()).await; + let _r_c = create_mock_route(p_c.clone()).await; + let r_d = create_mock_route(p_d.clone()).await; + + // in normal mode, packet from p_c should directly forward to p_a + wait_for_condition( + || async { r_d.get_next_hop(p_a.my_peer_id()).await != None }, + Duration::from_secs(5), + ) + .await; + + struct TestCostCalculator { + p_a_peer_id: PeerId, + p_b_peer_id: PeerId, + p_c_peer_id: PeerId, + p_d_peer_id: PeerId, + } + + impl RouteCostCalculatorInterface for TestCostCalculator { + fn calculate_cost(&self, src: PeerId, dst: PeerId) -> i32 { + if src == self.p_d_peer_id && dst == self.p_b_peer_id { + return 100; + } + + if src == self.p_d_peer_id && dst == self.p_c_peer_id { + return 1; + } + + if src == self.p_c_peer_id && dst == self.p_a_peer_id { + return 101; + } + + if src == self.p_b_peer_id && dst == self.p_a_peer_id { + return 1; + } + + if src == self.p_c_peer_id && dst == self.p_b_peer_id { + return 2; + } + + 1 + } + } + + r_d.set_route_cost_fn(Box::new(TestCostCalculator { + p_a_peer_id: p_a.my_peer_id(), + p_b_peer_id: p_b.my_peer_id(), + p_c_peer_id: p_c.my_peer_id(), + p_d_peer_id: p_d.my_peer_id(), + })) + .await; + + // after set cost, packet from p_c should forward to p_b first + wait_for_condition( + || async { + r_d.get_next_hop_with_policy(p_a.my_peer_id(), NextHopPolicy::LeastCost) + .await + == Some(p_c.my_peer_id()) + }, + Duration::from_secs(5), + ) + .await; + + wait_for_condition( + || async { + r_d.get_next_hop_with_policy(p_a.my_peer_id(), NextHopPolicy::LeastHop) + .await + == Some(p_b.my_peer_id()) + }, + Duration::from_secs(5), + ) + .await; + } } diff --git a/easytier/src/peers/route_trait.rs b/easytier/src/peers/route_trait.rs index 66364355..ad64df32 100644 --- a/easytier/src/peers/route_trait.rs +++ b/easytier/src/peers/route_trait.rs @@ -5,6 +5,18 @@ use tokio_util::bytes::Bytes; use crate::common::{error::Error, PeerId}; +#[derive(Clone, Debug)] +pub enum NextHopPolicy { + LeastHop, + LeastCost, +} + +impl Default for NextHopPolicy { + fn default() -> Self { + NextHopPolicy::LeastHop + } +} + #[async_trait] pub trait RouteInterface { async fn list_peers(&self) -> Vec; @@ -19,6 +31,31 @@ pub trait RouteInterface { pub type RouteInterfaceBox = Box; +#[auto_impl::auto_impl(Box , &mut)] +pub trait RouteCostCalculatorInterface: Send + Sync { + fn begin_update(&mut self) {} + fn end_update(&mut self) {} + + fn calculate_cost(&self, _src: PeerId, _dst: PeerId) -> i32 { + 1 + } + + fn need_update(&self) -> bool { + false + } + + fn dump(&self) -> String { + "All routes have cost 1".to_string() + } +} + +#[derive(Clone, Debug, Default)] +pub struct DefaultRouteCostCalculator; + +impl RouteCostCalculatorInterface for DefaultRouteCostCalculator {} + +pub type RouteCostCalculator = Box; + #[async_trait] #[auto_impl::auto_impl(Box, Arc)] pub trait Route { @@ -26,11 +63,21 @@ pub trait Route { async fn close(&self); async fn get_next_hop(&self, peer_id: PeerId) -> Option; + async fn get_next_hop_with_policy( + &self, + peer_id: PeerId, + _policy: NextHopPolicy, + ) -> Option { + self.get_next_hop(peer_id).await + } + async fn list_routes(&self) -> Vec; async fn get_peer_id_by_ipv4(&self, _ipv4: &Ipv4Addr) -> Option { None } + + async fn set_route_cost_fn(&self, _cost_fn: RouteCostCalculator) {} } pub type ArcRoute = Arc>; diff --git a/easytier/src/tests/three_node.rs b/easytier/src/tests/three_node.rs index 1ecdee33..279b166a 100644 --- a/easytier/src/tests/three_node.rs +++ b/easytier/src/tests/three_node.rs @@ -187,12 +187,13 @@ pub async fn basic_three_node_test(#[values("tcp", "udp", "wg", "ws", "wss")] pr pub async fn tcp_proxy_three_node_test(#[values("tcp", "udp", "wg")] proto: &str) { use crate::tunnel::{common::tests::_tunnel_pingpong_netns, tcp::TcpTunnelListener}; - let insts = init_three_node(proto).await; + let mut insts = init_three_node(proto).await; insts[2] .get_global_ctx() .add_proxy_cidr("10.1.2.0/24".parse().unwrap()) .unwrap(); + insts[2].run_ip_proxy().await.unwrap(); assert_eq!(insts[2].get_global_ctx().get_proxy_cidrs().len(), 1); wait_proxy_route_appear( @@ -222,12 +223,13 @@ pub async fn tcp_proxy_three_node_test(#[values("tcp", "udp", "wg")] proto: &str #[tokio::test] #[serial_test::serial] pub async fn icmp_proxy_three_node_test(#[values("tcp", "udp", "wg")] proto: &str) { - let insts = init_three_node(proto).await; + let mut insts = init_three_node(proto).await; insts[2] .get_global_ctx() .add_proxy_cidr("10.1.2.0/24".parse().unwrap()) .unwrap(); + insts[2].run_ip_proxy().await.unwrap(); assert_eq!(insts[2].get_global_ctx().get_proxy_cidrs().len(), 1); wait_proxy_route_appear( @@ -318,12 +320,13 @@ pub async fn proxy_three_node_disconnect_test(#[values("tcp", "wg")] proto: &str pub async fn udp_proxy_three_node_test(#[values("tcp", "udp", "wg")] proto: &str) { use crate::tunnel::{common::tests::_tunnel_pingpong_netns, udp::UdpTunnelListener}; - let insts = init_three_node(proto).await; + let mut insts = init_three_node(proto).await; insts[2] .get_global_ctx() .add_proxy_cidr("10.1.2.0/24".parse().unwrap()) .unwrap(); + insts[2].run_ip_proxy().await.unwrap(); assert_eq!(insts[2].get_global_ctx().get_proxy_cidrs().len(), 1); wait_proxy_route_appear( diff --git a/easytier/src/tunnel/packet_def.rs b/easytier/src/tunnel/packet_def.rs index 036b40d8..567bfa65 100644 --- a/easytier/src/tunnel/packet_def.rs +++ b/easytier/src/tunnel/packet_def.rs @@ -59,6 +59,7 @@ pub enum PacketType { bitflags::bitflags! { struct PeerManagerHeaderFlags: u8 { const ENCRYPTED = 0b0000_0001; + const LATENCY_FIRST = 0b0000_0010; } } @@ -69,7 +70,8 @@ pub struct PeerManagerHeader { pub to_peer_id: U32, pub packet_type: u8, pub flags: u8, - reserved: U16, + pub forward_counter: u8, + reserved: u8, pub len: U32, } pub const PEER_MANAGER_HEADER_SIZE: usize = std::mem::size_of::(); @@ -90,6 +92,22 @@ impl PeerManagerHeader { } self.flags = flags.bits(); } + + pub fn is_latency_first(&self) -> bool { + PeerManagerHeaderFlags::from_bits(self.flags) + .unwrap() + .contains(PeerManagerHeaderFlags::LATENCY_FIRST) + } + + pub fn set_latency_first(&mut self, latency_first: bool) { + let mut flags = PeerManagerHeaderFlags::from_bits(self.flags).unwrap(); + if latency_first { + flags.insert(PeerManagerHeaderFlags::LATENCY_FIRST); + } else { + flags.remove(PeerManagerHeaderFlags::LATENCY_FIRST); + } + self.flags = flags.bits(); + } } // reserve the space for aes tag and nonce @@ -362,6 +380,7 @@ impl ZCPacket { hdr.to_peer_id.set(to_peer_id); hdr.packet_type = packet_type; hdr.flags = 0; + hdr.forward_counter = 1; hdr.len.set(payload_len as u32); }