use workspace, prepare for config server and gui (#48)

This commit is contained in:
Sijie.Sun
2024-04-04 10:33:53 +08:00
committed by GitHub
parent bb4ae71869
commit 4eb7efe5fc
77 changed files with 162 additions and 195 deletions
+2
View File
@@ -0,0 +1,2 @@
// Windows-specific socket/interface helpers; compiled only on Windows targets.
#[cfg(target_os = "windows")]
pub mod windows;
+145
View File
@@ -0,0 +1,145 @@
use std::{
ffi::c_void,
io::{self, ErrorKind},
mem,
net::SocketAddr,
os::windows::io::AsRawSocket,
ptr,
};
use network_interface::NetworkInterfaceConfig;
use windows_sys::{
core::PCSTR,
Win32::{
Foundation::{BOOL, FALSE},
Networking::WinSock::{
htonl, setsockopt, WSAGetLastError, WSAIoctl, IPPROTO_IP, IPPROTO_IPV6,
IPV6_UNICAST_IF, IP_UNICAST_IF, SIO_UDP_CONNRESET, SOCKET, SOCKET_ERROR,
},
},
};
/// Turns off the WSAECONNRESET reporting behaviour of Windows UDP sockets.
///
/// On Windows, `UdpSocket::recv_from` may surface WSAECONNRESET when an
/// earlier `send_to` targeted an endpoint that no longer exists. That
/// condition is harmless here, so it is disabled via the SIO_UDP_CONNRESET
/// ioctl; otherwise the spurious error could crash the server loop.
///
/// References:
/// - https://github.com/shadowsocks/shadowsocks-rust/issues/179
/// - https://stackoverflow.com/questions/30749423/is-winsock-error-10054-wsaeconnreset-normal-with-udp-to-from-localhost
pub fn disable_connection_reset<S: AsRawSocket>(socket: &S) -> io::Result<()> {
    let raw_sock = socket.as_raw_socket() as SOCKET;
    // FALSE == "do not report connection resets on this socket".
    let report_reset: BOOL = FALSE;
    let mut out_len: u32 = 0;
    let rc = unsafe {
        WSAIoctl(
            raw_sock,
            SIO_UDP_CONNRESET,
            &report_reset as *const _ as *const c_void,
            mem::size_of_val(&report_reset) as u32,
            ptr::null_mut(),
            0,
            &mut out_len as *mut _,
            ptr::null_mut(),
            None,
        )
    };
    if rc == SOCKET_ERROR {
        let code = unsafe { WSAGetLastError() };
        return Err(io::Error::from_raw_os_error(code));
    }
    Ok(())
}
/// Looks up the OS interface index for the interface named `iface_name`.
///
/// # Errors
/// Returns `ErrorKind::NotFound` both when the interface list cannot be
/// obtained and when no interface with the given name exists.
pub fn find_interface_index(iface_name: &str) -> io::Result<u32> {
    let ifaces = network_interface::NetworkInterface::show().map_err(|e| {
        io::Error::new(
            ErrorKind::NotFound,
            format!("Failed to get interfaces. {}, error: {}", iface_name, e),
        )
    })?;
    if let Some(iface) = ifaces.iter().find(|iface| iface.name == iface_name) {
        return Ok(iface.index);
    }
    tracing::error!("Failed to find interface index for {}", iface_name);
    // Previously the error carried only the bare interface name; include a
    // description so the failure is understandable without the log line.
    Err(io::Error::new(
        ErrorKind::NotFound,
        format!("interface not found: {}", iface_name),
    ))
}
pub fn set_ip_unicast_if<S: AsRawSocket>(
socket: &S,
addr: &SocketAddr,
iface: &str,
) -> io::Result<()> {
let handle = socket.as_raw_socket() as SOCKET;
let if_index = find_interface_index(iface)?;
unsafe {
// https://docs.microsoft.com/en-us/windows/win32/winsock/ipproto-ip-socket-options
let ret = match addr {
SocketAddr::V4(..) => {
// Interface index is in network byte order for IPPROTO_IP.
let if_index = htonl(if_index);
setsockopt(
handle,
IPPROTO_IP as i32,
IP_UNICAST_IF as i32,
&if_index as *const _ as PCSTR,
mem::size_of_val(&if_index) as i32,
)
}
SocketAddr::V6(..) => {
// Interface index is in host byte order for IPPROTO_IPV6.
setsockopt(
handle,
IPPROTO_IPV6 as i32,
IPV6_UNICAST_IF as i32,
&if_index as *const _ as PCSTR,
mem::size_of_val(&if_index) as i32,
)
}
};
if ret == SOCKET_ERROR {
let err = io::Error::from_raw_os_error(WSAGetLastError());
tracing::error!(
"set IP_UNICAST_IF / IPV6_UNICAST_IF interface: {}, index: {}, error: {}",
iface,
if_index,
err
);
return Err(err);
}
}
Ok(())
}
/// Applies the Windows-specific socket tweaks used by this crate.
///
/// For UDP sockets, WSAECONNRESET reporting is disabled; when a bind device
/// is given, the socket's unicast traffic is pinned to that interface.
pub fn setup_socket_for_win<S: AsRawSocket>(
    socket: &S,
    bind_addr: &SocketAddr,
    bind_dev: Option<String>,
    is_udp: bool,
) -> io::Result<()> {
    if is_udp {
        disable_connection_reset(socket)?;
    }
    match bind_dev {
        Some(iface) => set_ip_unicast_if(socket, bind_addr, iface.as_str()),
        None => Ok(()),
    }
}
+425
View File
@@ -0,0 +1,425 @@
use std::{
net::SocketAddr,
sync::{Arc, Mutex},
};
use anyhow::Context;
use serde::{Deserialize, Serialize};
/// Abstraction over the instance configuration.
///
/// All methods take `&self`; implementors are expected to use interior
/// mutability (see `TomlConfigLoader`). The `auto_impl` attribute generates
/// forwarding impls for `Box<dyn ConfigLoader>` and `&dyn ConfigLoader`.
#[auto_impl::auto_impl(Box, &)]
pub trait ConfigLoader: Send + Sync {
    // Stable instance id.
    fn get_id(&self) -> uuid::Uuid;
    fn get_inst_name(&self) -> String;
    fn set_inst_name(&self, name: String);
    // Network namespace the instance should run in, if any.
    fn get_netns(&self) -> Option<String>;
    fn set_netns(&self, ns: Option<String>);
    fn get_ipv4(&self) -> Option<std::net::Ipv4Addr>;
    fn set_ipv4(&self, addr: std::net::Ipv4Addr);
    // CIDRs proxied into the virtual network.
    fn add_proxy_cidr(&self, cidr: cidr::IpCidr);
    fn remove_proxy_cidr(&self, cidr: cidr::IpCidr);
    fn get_proxy_cidrs(&self) -> Vec<cidr::IpCidr>;
    fn get_network_identity(&self) -> NetworkIdentity;
    fn set_network_identity(&self, identity: NetworkIdentity);
    fn get_listener_uris(&self) -> Vec<url::Url>;
    fn get_file_logger_config(&self) -> FileLoggerConfig;
    fn set_file_logger_config(&self, config: FileLoggerConfig);
    fn get_console_logger_config(&self) -> ConsoleLoggerConfig;
    fn set_console_logger_config(&self, config: ConsoleLoggerConfig);
    // Initial peers to connect to.
    fn get_peers(&self) -> Vec<PeerConfig>;
    fn set_peers(&self, peers: Vec<PeerConfig>);
    fn get_listeners(&self) -> Vec<url::Url>;
    fn set_listeners(&self, listeners: Vec<url::Url>);
    fn get_rpc_portal(&self) -> Option<SocketAddr>;
    fn set_rpc_portal(&self, addr: SocketAddr);
    fn get_vpn_portal_config(&self) -> Option<VpnPortalConfig>;
    fn set_vpn_portal_config(&self, config: VpnPortalConfig);
    fn get_flags(&self) -> Flags;
    fn set_flags(&self, flags: Flags);
    // Serializes the current configuration (e.g. back to TOML).
    fn dump(&self) -> String;
}
/// Identity of a virtual network: its name plus the shared secret used to
/// authenticate peers of the same network.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
pub struct NetworkIdentity {
    pub network_name: String,
    pub network_secret: String,
}
impl NetworkIdentity {
    /// Creates an identity from an explicit name/secret pair.
    pub fn new(network_name: String, network_secret: String) -> Self {
        NetworkIdentity {
            network_name,
            network_secret,
        }
    }
}
impl Default for NetworkIdentity {
    /// The "default" network with an empty secret.
    ///
    /// Implemented as the standard `Default` trait instead of an inherent
    /// `fn default()` (which shadows the trait — clippy
    /// `should_implement_trait`). Existing `NetworkIdentity::default()`
    /// callers keep working unchanged because `Default` is in the prelude.
    fn default() -> Self {
        Self::new("default".to_string(), "".to_string())
    }
}
/// A single configured peer, addressed by its connection URI
/// (e.g. `tcp://host:port` or `udp://host:port`).
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
pub struct PeerConfig {
    pub uri: url::Url,
}
/// A proxied network entry: the CIDR to share, plus an optional protocol
/// allow-list (see the `allow = ["tcp", "udp", "icmp"]` examples in tests).
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
pub struct NetworkConfig {
    pub cidr: String,
    pub allow: Option<Vec<String>>,
}
/// File logging options; all fields optional so a partial TOML table works.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Default)]
pub struct FileLoggerConfig {
    // Log level name (e.g. "info").
    pub level: Option<String>,
    // Log file name.
    pub file: Option<String>,
    // Directory the log file is written to.
    pub dir: Option<String>,
}
/// Console logging options; only the level is configurable.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Default)]
pub struct ConsoleLoggerConfig {
    pub level: Option<String>,
}
/// Configuration of the WireGuard VPN portal: the client address pool and
/// the address the WireGuard endpoint listens on.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
pub struct VpnPortalConfig {
    pub client_cidr: cidr::Ipv4Cidr,
    pub wireguard_listen: SocketAddr,
}
// Flags is used to control the behavior of the program
#[derive(derivative::Derivative, Deserialize, Serialize)]
#[derivative(Debug, Clone, PartialEq, Default)]
pub struct Flags {
    // Transport protocol used when a peer/listener URI does not pick one;
    // defaults to "tcp" via the derivative attribute below.
    #[derivative(Default(value = "\"tcp\".to_string()"))]
    pub default_protocol: String,
}
/// On-disk TOML schema. Every field is optional so a minimal (even empty)
/// config file parses; accessors in `TomlConfigLoader` supply defaults.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
struct Config {
    netns: Option<String>,
    instance_name: Option<String>,
    // Stored as a string; parsed to `uuid::Uuid` on access.
    instance_id: Option<String>,
    // Stored as a string; parsed to `Ipv4Addr` on access.
    ipv4: Option<String>,
    network_identity: Option<NetworkIdentity>,
    listeners: Option<Vec<url::Url>>,
    // Named `peer` (singular) to match the `[[peer]]` TOML array-of-tables.
    peer: Option<Vec<PeerConfig>>,
    proxy_network: Option<Vec<NetworkConfig>>,
    file_logger: Option<FileLoggerConfig>,
    console_logger: Option<ConsoleLoggerConfig>,
    rpc_portal: Option<SocketAddr>,
    vpn_portal_config: Option<VpnPortalConfig>,
    flags: Option<Flags>,
}
/// A `ConfigLoader` backed by a TOML document, kept behind `Arc<Mutex<..>>`
/// so clones share (and mutate) the same underlying configuration.
#[derive(Debug, Clone)]
pub struct TomlConfigLoader {
    config: Arc<Mutex<Config>>,
}
impl Default for TomlConfigLoader {
    fn default() -> Self {
        // An empty TOML document is valid and yields a config with every
        // field `None`, so this unwrap cannot fail.
        TomlConfigLoader::new_from_str("").unwrap()
    }
}
impl TomlConfigLoader {
    /// Parses a TOML document into a loader.
    ///
    /// # Errors
    /// Fails when `config_str` is not valid TOML for [`Config`].
    pub fn new_from_str(config_str: &str) -> Result<Self, anyhow::Error> {
        let config = toml::de::from_str::<Config>(config_str).with_context(|| {
            // Bug fix: the previous message interpolated `config_str` twice
            // ("{}\n{}" with the same argument); include it only once.
            format!("failed to parse config file: {}", config_str)
        })?;
        Ok(TomlConfigLoader {
            config: Arc::new(Mutex::new(config)),
        })
    }
    /// Reads and parses the TOML file at `config_path`.
    ///
    /// # Errors
    /// Fails when the file cannot be read or does not parse.
    pub fn new(config_path: &str) -> Result<Self, anyhow::Error> {
        let config_str = std::fs::read_to_string(config_path)
            .with_context(|| format!("failed to read config file: {}", config_path))?;
        Self::new_from_str(&config_str)
    }
}
impl ConfigLoader for TomlConfigLoader {
    fn get_inst_name(&self) -> String {
        self.config
            .lock()
            .unwrap()
            .instance_name
            .clone()
            // build the fallback lazily (clippy: or_fun_call)
            .unwrap_or_else(|| "default".to_string())
    }
    fn set_inst_name(&self, name: String) {
        self.config.lock().unwrap().instance_name = Some(name);
    }
    fn get_netns(&self) -> Option<String> {
        self.config.lock().unwrap().netns.clone()
    }
    fn set_netns(&self, ns: Option<String>) {
        self.config.lock().unwrap().netns = ns;
    }
    fn get_ipv4(&self) -> Option<std::net::Ipv4Addr> {
        let locked_config = self.config.lock().unwrap();
        // `and_then` replaces the former `map(..).flatten()` (clippy:
        // map_flatten); an unparsable address silently yields `None`.
        locked_config.ipv4.as_ref().and_then(|s| s.parse().ok())
    }
    fn set_ipv4(&self, addr: std::net::Ipv4Addr) {
        self.config.lock().unwrap().ipv4 = Some(addr.to_string());
    }
    /// Inserts `cidr` into the proxy list, keeping entries unique.
    fn add_proxy_cidr(&self, cidr: cidr::IpCidr) {
        let mut locked_config = self.config.lock().unwrap();
        // `get_or_insert_with` replaces the manual is_none()/Some(vec![])
        // dance and the subsequent `as_ref().unwrap()` calls.
        let proxy_cidrs = locked_config.proxy_network.get_or_insert_with(Vec::new);
        let cidr_str = cidr.to_string();
        // insert if no duplicate
        if !proxy_cidrs.iter().any(|c| c.cidr == cidr_str) {
            proxy_cidrs.push(NetworkConfig {
                cidr: cidr_str,
                allow: None,
            });
        }
    }
    fn remove_proxy_cidr(&self, cidr: cidr::IpCidr) {
        let mut locked_config = self.config.lock().unwrap();
        if let Some(proxy_cidrs) = &mut locked_config.proxy_network {
            let cidr_str = cidr.to_string();
            proxy_cidrs.retain(|c| c.cidr != cidr_str);
        }
    }
    /// # Panics
    /// Panics when a stored proxy CIDR string does not parse; entries are
    /// normally written by `add_proxy_cidr` and therefore valid.
    fn get_proxy_cidrs(&self) -> Vec<cidr::IpCidr> {
        self.config
            .lock()
            .unwrap()
            .proxy_network
            .as_ref()
            .map(|v| {
                v.iter()
                    .map(|c| c.cidr.parse().unwrap())
                    .collect::<Vec<cidr::IpCidr>>()
            })
            .unwrap_or_default()
    }
    /// Returns the persistent instance id, generating and storing a fresh
    /// UUID on first use.
    ///
    /// # Panics
    /// Panics when a stored `instance_id` is not a valid UUID; the panic
    /// message suggests a replacement id.
    fn get_id(&self) -> uuid::Uuid {
        let mut locked_config = self.config.lock().unwrap();
        if locked_config.instance_id.is_none() {
            let id = uuid::Uuid::new_v4();
            locked_config.instance_id = Some(id.to_string());
            id
        } else {
            uuid::Uuid::parse_str(locked_config.instance_id.as_ref().unwrap())
                .with_context(|| {
                    format!(
                        "failed to parse instance id as uuid: {}, you can use this id: {}",
                        locked_config.instance_id.as_ref().unwrap(),
                        uuid::Uuid::new_v4()
                    )
                })
                .unwrap()
        }
    }
    fn get_network_identity(&self) -> NetworkIdentity {
        self.config
            .lock()
            .unwrap()
            .network_identity
            .clone()
            .unwrap_or_else(NetworkIdentity::default)
    }
    fn set_network_identity(&self, identity: NetworkIdentity) {
        self.config.lock().unwrap().network_identity = Some(identity);
    }
    // NOTE(review): identical to get_listeners below; both read the same
    // `listeners` field and are kept for interface compatibility.
    fn get_listener_uris(&self) -> Vec<url::Url> {
        self.config
            .lock()
            .unwrap()
            .listeners
            .clone()
            .unwrap_or_default()
    }
    fn get_file_logger_config(&self) -> FileLoggerConfig {
        self.config
            .lock()
            .unwrap()
            .file_logger
            .clone()
            .unwrap_or_default()
    }
    fn set_file_logger_config(&self, config: FileLoggerConfig) {
        self.config.lock().unwrap().file_logger = Some(config);
    }
    fn get_console_logger_config(&self) -> ConsoleLoggerConfig {
        self.config
            .lock()
            .unwrap()
            .console_logger
            .clone()
            .unwrap_or_default()
    }
    fn set_console_logger_config(&self, config: ConsoleLoggerConfig) {
        self.config.lock().unwrap().console_logger = Some(config);
    }
    fn get_peers(&self) -> Vec<PeerConfig> {
        self.config.lock().unwrap().peer.clone().unwrap_or_default()
    }
    fn set_peers(&self, peers: Vec<PeerConfig>) {
        self.config.lock().unwrap().peer = Some(peers);
    }
    fn get_listeners(&self) -> Vec<url::Url> {
        self.config
            .lock()
            .unwrap()
            .listeners
            .clone()
            .unwrap_or_default()
    }
    fn set_listeners(&self, listeners: Vec<url::Url>) {
        self.config.lock().unwrap().listeners = Some(listeners);
    }
    fn get_rpc_portal(&self) -> Option<SocketAddr> {
        self.config.lock().unwrap().rpc_portal
    }
    fn set_rpc_portal(&self, addr: SocketAddr) {
        self.config.lock().unwrap().rpc_portal = Some(addr);
    }
    fn get_vpn_portal_config(&self) -> Option<VpnPortalConfig> {
        self.config.lock().unwrap().vpn_portal_config.clone()
    }
    fn set_vpn_portal_config(&self, config: VpnPortalConfig) {
        self.config.lock().unwrap().vpn_portal_config = Some(config);
    }
    fn get_flags(&self) -> Flags {
        self.config
            .lock()
            .unwrap()
            .flags
            .clone()
            .unwrap_or_default()
    }
    fn set_flags(&self, flags: Flags) {
        self.config.lock().unwrap().flags = Some(flags);
    }
    /// Serializes the current state back to pretty-printed TOML.
    fn dump(&self) -> String {
        toml::to_string_pretty(&*self.config.lock().unwrap()).unwrap()
    }
}
#[cfg(test)]
pub mod tests {
    use super::*;
    // Round-trip test: parse a config exercising every top-level table,
    // then check a couple of accessors and that `dump()` succeeds.
    #[tokio::test]
    async fn full_example_test() {
        let config_str = r#"
instance_name = "default"
instance_id = "87ede5a2-9c3d-492d-9bbe-989b9d07e742"
ipv4 = "10.144.144.10"
listeners = [ "tcp://0.0.0.0:11010", "udp://0.0.0.0:11010" ]
[network_identity]
network_name = "default"
network_secret = ""
[[peer]]
uri = "tcp://public.kkrainbow.top:11010"
[[peer]]
uri = "udp://192.168.94.33:11010"
[[proxy_network]]
cidr = "10.147.223.0/24"
allow = ["tcp", "udp", "icmp"]
[[proxy_network]]
cidr = "10.1.1.0/24"
allow = ["tcp", "icmp"]
[file_logger]
level = "info"
file = "easytier"
dir = "/tmp/easytier"
[console_logger]
level = "warn"
"#;
        let ret = TomlConfigLoader::new_from_str(config_str);
        // print the outcome before asserting so failures show the parse error
        if let Err(e) = &ret {
            println!("{}", e);
        } else {
            println!("{:?}", ret.as_ref().unwrap());
        }
        assert!(ret.is_ok());
        let ret = ret.unwrap();
        assert_eq!("10.144.144.10", ret.get_ipv4().unwrap().to_string());
        assert_eq!(
            vec!["tcp://0.0.0.0:11010", "udp://0.0.0.0:11010"],
            ret.get_listener_uris()
                .iter()
                .map(|u| u.to_string())
                .collect::<Vec<String>>()
        );
        println!("{}", ret.dump());
    }
}
+24
View File
@@ -0,0 +1,24 @@
/// Declares a module-level global of type `$type`, initialized lazily with
/// `$init` and protected by a tokio `Mutex` for async access.
macro_rules! define_global_var {
    ($name:ident, $type:ty, $init:expr) => {
        pub static $name: once_cell::sync::Lazy<tokio::sync::Mutex<$type>> =
            once_cell::sync::Lazy::new(|| tokio::sync::Mutex::new($init));
    };
}
/// Reads a copy of a global declared with `define_global_var!`.
/// Expands to an `.await`, so it can only be used inside async code.
#[macro_export]
macro_rules! use_global_var {
    ($name:ident) => {
        crate::common::constants::$name.lock().await.to_owned()
    };
}
/// Overwrites a global declared with `define_global_var!`.
/// Expands to an `.await`, so it can only be used inside async code.
#[macro_export]
macro_rules! set_global_var {
    ($name:ident, $val:expr) => {
        *crate::common::constants::$name.lock().await = $val
    };
}
// Interval (ms) between reconnect attempts of the manual connector.
define_global_var!(MANUAL_CONNECTOR_RECONNECT_INTERVAL_MS, u64, 1000);
// Service id registered for the UDP hole-punching connector.
pub const UDP_HOLE_PUNCH_CONNECTOR_SERVICE_ID: u32 = 2;
+45
View File
@@ -0,0 +1,45 @@
use std::{io, result};
use thiserror::Error;
use crate::tunnels;
use super::PeerId;
/// Crate-wide error type; wraps I/O, tun, tunnel, RPC and timeout failures
/// so they can all travel through the single `Result` alias below.
#[derive(Error, Debug)]
pub enum Error {
    #[error("io error")]
    IOError(#[from] io::Error),
    #[error("rust tun error {0}")]
    TunError(#[from] tun::Error),
    #[error("tunnel error {0}")]
    TunnelError(#[from] tunnels::TunnelError),
    // Raised when a peer exists but has no usable connection.
    #[error("Peer has no conn, PeerId: {0}")]
    PeerNoConnectionError(PeerId),
    #[error("RouteError: {0:?}")]
    RouteError(Option<String>),
    #[error("Not found")]
    NotFound,
    #[error("Invalid Url: {0}")]
    InvalidUrl(String),
    // Carries stdout+stderr of a failed shell command (see ifcfg).
    #[error("Shell Command error: {0}")]
    ShellCommandError(String),
    // #[error("Rpc listen error: {0}")]
    // RpcListenError(String),
    #[error("Rpc connect error: {0}")]
    RpcConnectError(String),
    #[error("Rpc error: {0}")]
    RpcClientError(#[from] tarpc::client::RpcError),
    #[error("Timeout error: {0}")]
    Timeout(#[from] tokio::time::error::Elapsed),
    #[error("url in blacklist")]
    UrlInBlacklist,
    #[error("unknown data store error")]
    Unknown,
    #[error("anyhow error: {0}")]
    AnyhowError(#[from] anyhow::Error),
}
/// Convenience alias used throughout the crate.
pub type Result<T> = result::Result<T, Error>;
// TODO: add `From` conversions for further std error types as needed.
+256
View File
@@ -0,0 +1,256 @@
use std::sync::{Arc, Mutex};
use crate::rpc::PeerConnInfo;
use crossbeam::atomic::AtomicCell;
use super::{
config::{ConfigLoader, Flags},
netns::NetNS,
network::IPCollector,
stun::{StunInfoCollector, StunInfoCollectorTrait},
PeerId,
};
// Re-exported so users of global_ctx don't need to import the config module.
pub type NetworkIdentity = crate::common::config::NetworkIdentity;
/// Events broadcast on the instance-wide event bus (see `GlobalCtx`).
#[derive(Debug, Clone, PartialEq)]
pub enum GlobalCtxEvent {
    // carries the tun device name
    TunDeviceReady(String),
    PeerAdded(PeerId),
    PeerRemoved(PeerId),
    PeerConnAdded(PeerConnInfo),
    PeerConnRemoved(PeerConnInfo),
    ListenerAdded(url::Url),
    ConnectionAccepted(String, String), // (local url, remote url)
    ConnectionError(String, String, String), // (local url, remote url, error message)
    Connecting(url::Url),
    ConnectError(String, String), // (dst, error message)
    VpnPortalClientConnected(String, String), // (portal, client ip)
    VpnPortalClientDisconnected(String, String), // (portal, client ip)
}
// Broadcast channel carrying GlobalCtxEvent to any number of subscribers.
type EventBus = tokio::sync::broadcast::Sender<GlobalCtxEvent>;
type EventBusSubscriber = tokio::sync::broadcast::Receiver<GlobalCtxEvent>;
/// Per-instance shared context: configuration, network namespace, event bus
/// and a few caches. Shared as `ArcGlobalCtx` across the whole instance.
pub struct GlobalCtx {
    pub inst_name: String,
    pub id: uuid::Uuid,
    pub config: Box<dyn ConfigLoader>,
    pub net_ns: NetNS,
    pub network: NetworkIdentity,
    event_bus: EventBus,
    // Cache of config.get_ipv4(); invalidated by set_ipv4().
    cached_ipv4: AtomicCell<Option<std::net::Ipv4Addr>>,
    // Cache of config.get_proxy_cidrs(); invalidated by add/remove.
    cached_proxy_cidrs: AtomicCell<Option<Vec<cidr::IpCidr>>>,
    ip_collector: Arc<IPCollector>,
    // NOTE(review): field name is a typo for "hostname"; renaming would also
    // touch the impl below, so it is only flagged here.
    hotname: AtomicCell<Option<String>>,
    stun_info_collection: Box<dyn StunInfoCollectorTrait>,
    running_listeners: Mutex<Vec<url::Url>>,
}
// Manual Debug: only a summary of the context is printed; trait objects and
// caches other than the ipv4 cache are intentionally omitted.
impl std::fmt::Debug for GlobalCtx {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("GlobalCtx")
            .field("inst_name", &self.inst_name)
            .field("id", &self.id)
            .field("net_ns", &self.net_ns.name())
            .field("event_bus", &"EventBus")
            .field("ipv4", &self.cached_ipv4)
            .finish()
    }
}
/// Shared handle to the per-instance context.
pub type ArcGlobalCtx = std::sync::Arc<GlobalCtx>;
impl GlobalCtx {
    /// Builds a context from a config loader, capturing the instance name,
    /// id, network identity and network namespace at construction time.
    pub fn new(config_fs: impl ConfigLoader + 'static + Send + Sync) -> Self {
        let id = config_fs.get_id();
        let network = config_fs.get_network_identity();
        let net_ns = NetNS::new(config_fs.get_netns());
        // Capacity 100: slow subscribers past that lag and lose events.
        let (event_bus, _) = tokio::sync::broadcast::channel(100);
        GlobalCtx {
            inst_name: config_fs.get_inst_name(),
            id,
            config: Box::new(config_fs),
            net_ns: net_ns.clone(),
            network,
            event_bus,
            cached_ipv4: AtomicCell::new(None),
            cached_proxy_cidrs: AtomicCell::new(None),
            ip_collector: Arc::new(IPCollector::new(net_ns)),
            hotname: AtomicCell::new(None),
            stun_info_collection: Box::new(StunInfoCollector::new_with_default_servers()),
            running_listeners: Mutex::new(Vec::new()),
        }
    }
    /// Subscribes to the event bus; only events issued after this call are
    /// delivered to the returned receiver.
    pub fn subscribe(&self) -> EventBusSubscriber {
        self.event_bus.subscribe()
    }
    /// Broadcasts `event` to all current subscribers; logs a warning when
    /// nobody is listening.
    pub fn issue_event(&self, event: GlobalCtxEvent) {
        // Bug fix: `send` fails exactly when there is no receiver. The old
        // `receiver_count() != 0` check followed by `unwrap()` could still
        // panic if the last subscriber dropped between the check and the
        // send; handling the send error directly removes that race.
        if let Err(e) = self.event_bus.send(event) {
            log::warn!("No subscriber for event: {:?}", e.0);
        }
    }
    /// Returns this node's virtual IPv4, caching the config lookup.
    pub fn get_ipv4(&self) -> Option<std::net::Ipv4Addr> {
        if let Some(ret) = self.cached_ipv4.load() {
            return Some(ret);
        }
        let addr = self.config.get_ipv4();
        // `Option<Ipv4Addr>` is Copy; the former `.clone()` was redundant.
        self.cached_ipv4.store(addr);
        addr
    }
    /// Updates the configured IPv4 and invalidates the cache.
    pub fn set_ipv4(&mut self, addr: std::net::Ipv4Addr) {
        self.config.set_ipv4(addr);
        self.cached_ipv4.store(None);
    }
    /// Adds a proxied CIDR to the config and invalidates the cidr cache.
    pub fn add_proxy_cidr(&self, cidr: cidr::IpCidr) -> Result<(), std::io::Error> {
        self.config.add_proxy_cidr(cidr);
        self.cached_proxy_cidrs.store(None);
        Ok(())
    }
    /// Removes a proxied CIDR from the config and invalidates the cidr cache.
    pub fn remove_proxy_cidr(&self, cidr: cidr::IpCidr) -> Result<(), std::io::Error> {
        self.config.remove_proxy_cidr(cidr);
        self.cached_proxy_cidrs.store(None);
        Ok(())
    }
    /// Returns the proxied CIDRs, caching the config lookup.
    pub fn get_proxy_cidrs(&self) -> Vec<cidr::IpCidr> {
        // take/store is not atomic: a concurrent caller may observe an empty
        // cache and re-read the config, which is harmless (just extra work).
        if let Some(proxy_cidrs) = self.cached_proxy_cidrs.take() {
            self.cached_proxy_cidrs.store(Some(proxy_cidrs.clone()));
            return proxy_cidrs;
        }
        let ret = self.config.get_proxy_cidrs();
        self.cached_proxy_cidrs.store(Some(ret.clone()));
        ret
    }
    pub fn get_id(&self) -> uuid::Uuid {
        self.config.get_id()
    }
    pub fn get_network_identity(&self) -> NetworkIdentity {
        self.config.get_network_identity()
    }
    pub fn get_ip_collector(&self) -> Arc<IPCollector> {
        self.ip_collector.clone()
    }
    /// Returns the machine hostname, cached after the first call.
    pub fn get_hostname(&self) -> Option<String> {
        if let Some(hostname) = self.hotname.take() {
            self.hotname.store(Some(hostname.clone()));
            return Some(hostname);
        }
        let hostname = gethostname::gethostname().to_string_lossy().to_string();
        self.hotname.store(Some(hostname.clone()));
        Some(hostname)
    }
    pub fn get_stun_info_collector(&self) -> impl StunInfoCollectorTrait + '_ {
        self.stun_info_collection.as_ref()
    }
    #[cfg(test)]
    pub fn replace_stun_info_collector(&self, collector: Box<dyn StunInfoCollectorTrait>) {
        // force replace the stun_info_collection without mut and drop the old one
        //
        // SAFETY(review): this writes through a shared reference, which is
        // undefined behavior if any other thread touches the field
        // concurrently. It is confined to #[cfg(test)] and relied on only in
        // single-threaded test setup; do not use this pattern elsewhere.
        let ptr = &self.stun_info_collection as *const Box<dyn StunInfoCollectorTrait>;
        let ptr = ptr as *mut Box<dyn StunInfoCollectorTrait>;
        unsafe {
            std::ptr::drop_in_place(ptr);
            #[allow(invalid_reference_casting)]
            std::ptr::write(ptr, collector);
        }
    }
    pub fn get_running_listeners(&self) -> Vec<url::Url> {
        self.running_listeners.lock().unwrap().clone()
    }
    pub fn add_running_listener(&self, url: url::Url) {
        self.running_listeners.lock().unwrap().push(url);
    }
    /// The VPN portal client pool, if a portal is configured.
    pub fn get_vpn_portal_cidr(&self) -> Option<cidr::Ipv4Cidr> {
        self.config.get_vpn_portal_config().map(|x| x.client_cidr)
    }
    pub fn get_flags(&self) -> Flags {
        self.config.get_flags()
    }
}
#[cfg(test)]
pub mod tests {
    use crate::common::{config::TomlConfigLoader, new_peer_id};
    use super::*;
    // Issues four events and checks the subscriber receives them in order.
    #[tokio::test]
    async fn test_global_ctx() {
        let config = TomlConfigLoader::default();
        let global_ctx = GlobalCtx::new(config);
        let mut subscriber = global_ctx.subscribe();
        let peer_id = new_peer_id();
        global_ctx.issue_event(GlobalCtxEvent::PeerAdded(peer_id.clone()));
        global_ctx.issue_event(GlobalCtxEvent::PeerRemoved(peer_id.clone()));
        global_ctx.issue_event(GlobalCtxEvent::PeerConnAdded(PeerConnInfo::default()));
        global_ctx.issue_event(GlobalCtxEvent::PeerConnRemoved(PeerConnInfo::default()));
        assert_eq!(
            subscriber.recv().await.unwrap(),
            GlobalCtxEvent::PeerAdded(peer_id.clone())
        );
        assert_eq!(
            subscriber.recv().await.unwrap(),
            GlobalCtxEvent::PeerRemoved(peer_id.clone())
        );
        assert_eq!(
            subscriber.recv().await.unwrap(),
            GlobalCtxEvent::PeerConnAdded(PeerConnInfo::default())
        );
        assert_eq!(
            subscriber.recv().await.unwrap(),
            GlobalCtxEvent::PeerConnRemoved(PeerConnInfo::default())
        );
    }
    // Shared helper for other modules' tests: a context with a random
    // instance name and the given (or default) network identity.
    pub fn get_mock_global_ctx_with_network(
        network_identy: Option<NetworkIdentity>,
    ) -> ArcGlobalCtx {
        let config_fs = TomlConfigLoader::default();
        config_fs.set_inst_name(format!("test_{}", config_fs.get_id()));
        config_fs.set_network_identity(network_identy.unwrap_or(NetworkIdentity::default()));
        std::sync::Arc::new(GlobalCtx::new(config_fs))
    }
    pub fn get_mock_global_ctx() -> ArcGlobalCtx {
        get_mock_global_ctx_with_network(None)
    }
}
+358
View File
@@ -0,0 +1,358 @@
use std::net::Ipv4Addr;
use async_trait::async_trait;
use tokio::process::Command;
use super::error::Error;
/// Platform-neutral interface configuration operations (routes, addresses,
/// link state), implemented per OS below by shelling out to system tools.
#[async_trait]
pub trait IfConfiguerTrait {
    async fn add_ipv4_route(
        &self,
        name: &str,
        address: Ipv4Addr,
        cidr_prefix: u8,
    ) -> Result<(), Error>;
    async fn remove_ipv4_route(
        &self,
        name: &str,
        address: Ipv4Addr,
        cidr_prefix: u8,
    ) -> Result<(), Error>;
    async fn add_ipv4_ip(
        &self,
        name: &str,
        address: Ipv4Addr,
        cidr_prefix: u8,
    ) -> Result<(), Error>;
    async fn set_link_status(&self, name: &str, up: bool) -> Result<(), Error>;
    // `ip == None` removes/flushes all addresses of the interface.
    async fn remove_ip(&self, name: &str, ip: Option<Ipv4Addr>) -> Result<(), Error>;
    // Default: assume the interface is immediately visible; Windows overrides
    // this to poll until the adapter shows up.
    async fn wait_interface_show(&self, _name: &str) -> Result<(), Error> {
        return Ok(());
    }
}
/// Converts a CIDR prefix length (0..=32) into a dotted-quad IPv4 netmask.
///
/// # Panics
/// Panics when `prefix_length` exceeds 32.
fn cidr_to_subnet_mask(prefix_length: u8) -> Ipv4Addr {
    assert!(prefix_length <= 32, "Invalid CIDR prefix length");
    // `checked_shl` covers the shift-by-32 case (prefix 0) that a plain `<<`
    // would reject; it must yield 0.0.0.0.
    let mask_bits = (!0u32)
        .checked_shl(32 - u32::from(prefix_length))
        .unwrap_or(0);
    // `Ipv4Addr: From<u32>` interprets the integer in big-endian order,
    // matching the manual per-byte extraction it replaces.
    Ipv4Addr::from(mask_bits)
}
/// Runs `cmd` through the platform shell (`cmd /C` on Windows, `sh -c`
/// elsewhere), logging the outcome and returning `ShellCommandError` with
/// the combined stdout+stderr when the command exits unsuccessfully.
async fn run_shell_cmd(cmd: &str) -> Result<(), Error> {
    let (shell, flag) = if cfg!(target_os = "windows") {
        ("cmd", "/C")
    } else {
        ("sh", "-c")
    };
    let cmd_out = Command::new(shell).arg(flag).arg(cmd).output().await?;
    let stdout = String::from_utf8_lossy(cmd_out.stdout.as_slice());
    let stderr = String::from_utf8_lossy(cmd_out.stderr.as_slice());
    let ec = cmd_out.status.code();
    let succ = cmd_out.status.success();
    tracing::info!(?cmd, ?ec, ?succ, ?stdout, ?stderr, "run shell cmd");
    if !cmd_out.status.success() {
        return Err(Error::ShellCommandError(
            stdout.to_string() + &stderr.to_string(),
        ));
    }
    Ok(())
}
/// macOS implementation backed by `route` and `ifconfig` shell commands.
pub struct MacIfConfiger {}
#[async_trait]
impl IfConfiguerTrait for MacIfConfiger {
    /// Adds a route for `address/cidr_prefix` via `name` with a high hop
    /// count so it does not shadow more specific routes.
    async fn add_ipv4_route(
        &self,
        name: &str,
        address: Ipv4Addr,
        cidr_prefix: u8,
    ) -> Result<(), Error> {
        run_shell_cmd(
            format!(
                "route -n add {} -netmask {} -interface {} -hopcount 7",
                address,
                cidr_to_subnet_mask(cidr_prefix),
                name
            )
            .as_str(),
        )
        .await
    }
    async fn remove_ipv4_route(
        &self,
        name: &str,
        address: Ipv4Addr,
        cidr_prefix: u8,
    ) -> Result<(), Error> {
        run_shell_cmd(
            format!(
                "route -n delete {} -netmask {} -interface {}",
                address,
                cidr_to_subnet_mask(cidr_prefix),
                name
            )
            .as_str(),
        )
        .await
    }
    /// Assigns `address/cidr_prefix` to `name` and brings the link up.
    async fn add_ipv4_ip(
        &self,
        name: &str,
        address: Ipv4Addr,
        cidr_prefix: u8,
    ) -> Result<(), Error> {
        // `{}` (Display) replaces the previous `{:?}` (Debug); for Ipv4Addr
        // and u8 the rendered text is identical, but Display is the correct
        // formatter to build a shell command from.
        // NOTE(review): 10.8.8.8 appears to be the point-to-point peer
        // address required by macOS ifconfig — confirm.
        run_shell_cmd(format!("ifconfig {} {}/{} 10.8.8.8 up", name, address, cidr_prefix).as_str())
            .await
    }
    async fn set_link_status(&self, name: &str, up: bool) -> Result<(), Error> {
        run_shell_cmd(format!("ifconfig {} {}", name, if up { "up" } else { "down" }).as_str())
            .await
    }
    /// Removes one address (`Some`) or the interface's inet address (`None`).
    async fn remove_ip(&self, name: &str, ip: Option<Ipv4Addr>) -> Result<(), Error> {
        match ip {
            None => run_shell_cmd(format!("ifconfig {} inet delete", name).as_str()).await,
            Some(ip) => {
                // `{}` already invokes Display; the former `.to_string()`
                // was a redundant allocation.
                run_shell_cmd(format!("ifconfig {} inet {} delete", name, ip).as_str()).await
            }
        }
    }
}
/// Linux implementation backed by `ip` shell commands.
pub struct LinuxIfConfiger {}
#[async_trait]
impl IfConfiguerTrait for LinuxIfConfiger {
    /// Adds a route for `address/cidr_prefix` via `name` with the maximum
    /// metric so it never shadows more specific routes.
    async fn add_ipv4_route(
        &self,
        name: &str,
        address: Ipv4Addr,
        cidr_prefix: u8,
    ) -> Result<(), Error> {
        run_shell_cmd(
            format!(
                "ip route add {}/{} dev {} metric 65535",
                address, cidr_prefix, name
            )
            .as_str(),
        )
        .await
    }
    async fn remove_ipv4_route(
        &self,
        name: &str,
        address: Ipv4Addr,
        cidr_prefix: u8,
    ) -> Result<(), Error> {
        run_shell_cmd(format!("ip route del {}/{} dev {}", address, cidr_prefix, name).as_str())
            .await
    }
    /// Assigns `address/cidr_prefix` to interface `name`.
    async fn add_ipv4_ip(
        &self,
        name: &str,
        address: Ipv4Addr,
        cidr_prefix: u8,
    ) -> Result<(), Error> {
        // `{}` (Display) replaces `{:?}` (Debug); for Ipv4Addr and u8 the
        // output is identical, but Display is the correct choice here.
        run_shell_cmd(format!("ip addr add {}/{} dev {}", address, cidr_prefix, name).as_str())
            .await
    }
    async fn set_link_status(&self, name: &str, up: bool) -> Result<(), Error> {
        run_shell_cmd(format!("ip link set {} {}", name, if up { "up" } else { "down" }).as_str())
            .await
    }
    /// Removes one address (`Some`) or flushes all addresses (`None`).
    async fn remove_ip(&self, name: &str, ip: Option<Ipv4Addr>) -> Result<(), Error> {
        match ip {
            None => run_shell_cmd(format!("ip addr flush dev {}", name).as_str()).await,
            Some(ip) => {
                // Bug fix: the old code Debug-formatted a `String`
                // (`{:?}` of `ip.to_string()`), which embedded literal
                // double quotes into the command ("ip addr del \"1.2.3.4\"");
                // it only worked because `sh -c` stripped the quotes again.
                // Format the Ipv4Addr with Display directly.
                run_shell_cmd(format!("ip addr del {} dev {}", ip, name).as_str()).await
            }
        }
    }
}
/// Windows implementation backed by `route` and `netsh` shell commands.
#[cfg(target_os = "windows")]
pub struct WindowsIfConfiger {}
#[cfg(target_os = "windows")]
impl WindowsIfConfiger {
    /// Resolves the adapter index for `name`, or `None` when not found.
    pub fn get_interface_index(name: &str) -> Option<u32> {
        crate::arch::windows::find_interface_index(name).ok()
    }
    /// Lists the IPv4 addresses currently assigned to interface `name`.
    async fn list_ipv4(name: &str) -> Result<Vec<Ipv4Addr>, Error> {
        use anyhow::Context;
        use network_interface::NetworkInterfaceConfig;
        use std::net::IpAddr;
        let ret = network_interface::NetworkInterface::show().with_context(|| "show interface")?;
        // `filter` + `flat_map` over the address list replaces the previous
        // `filter_map(..Some(x.addr.clone()))` + `flat_map(|x| x)` identity
        // (clippy: flat_map over identity).
        let addrs = ret
            .iter()
            .filter(|x| x.name == name)
            .flat_map(|x| x.addr.clone())
            .map(|x| x.ip())
            .filter_map(|x| {
                if let IpAddr::V4(ipv4) = x {
                    Some(ipv4)
                } else {
                    None
                }
            })
            .collect::<Vec<_>>();
        Ok(addrs)
    }
    /// Removes a single IPv4 address from `name` via netsh.
    async fn remove_one_ipv4(name: &str, ip: Ipv4Addr) -> Result<(), Error> {
        // `{}` invokes Display directly; the former `.to_string()` was a
        // redundant allocation.
        run_shell_cmd(
            format!("netsh interface ipv4 delete address {} address={}", name, ip).as_str(),
        )
        .await
    }
}
#[cfg(target_os = "windows")]
#[async_trait]
impl IfConfiguerTrait for WindowsIfConfiger {
    /// Adds a route for `address/cidr_prefix` through the adapter's index,
    /// with a high metric so it does not shadow more specific routes.
    // NOTE(review): 10.1.1.1 is passed as the gateway argument of
    // `route ADD` — confirm why this fixed address is used.
    async fn add_ipv4_route(
        &self,
        name: &str,
        address: Ipv4Addr,
        cidr_prefix: u8,
    ) -> Result<(), Error> {
        let Some(idx) = Self::get_interface_index(name) else {
            return Err(Error::NotFound);
        };
        run_shell_cmd(
            format!(
                "route ADD {} MASK {} 10.1.1.1 IF {} METRIC 255",
                address,
                cidr_to_subnet_mask(cidr_prefix),
                idx
            )
            .as_str(),
        )
        .await
    }
    async fn remove_ipv4_route(
        &self,
        name: &str,
        address: Ipv4Addr,
        cidr_prefix: u8,
    ) -> Result<(), Error> {
        let Some(idx) = Self::get_interface_index(name) else {
            return Err(Error::NotFound);
        };
        run_shell_cmd(
            format!(
                "route DELETE {} MASK {} IF {}",
                address,
                cidr_to_subnet_mask(cidr_prefix),
                idx
            )
            .as_str(),
        )
        .await
    }
    /// Assigns `address` with the netmask derived from `cidr_prefix`.
    async fn add_ipv4_ip(
        &self,
        name: &str,
        address: Ipv4Addr,
        cidr_prefix: u8,
    ) -> Result<(), Error> {
        run_shell_cmd(
            format!(
                "netsh interface ipv4 add address {} address={} mask={}",
                name,
                address,
                cidr_to_subnet_mask(cidr_prefix)
            )
            .as_str(),
        )
        .await
    }
    async fn set_link_status(&self, name: &str, up: bool) -> Result<(), Error> {
        run_shell_cmd(
            format!(
                "netsh interface set interface {} {}",
                name,
                if up { "enable" } else { "disable" }
            )
            .as_str(),
        )
        .await
    }
    /// Removes one address (`Some`) or every IPv4 address (`None`).
    async fn remove_ip(&self, name: &str, ip: Option<Ipv4Addr>) -> Result<(), Error> {
        if ip.is_none() {
            for ip in Self::list_ipv4(name).await?.iter() {
                Self::remove_one_ipv4(name, *ip).await?;
            }
            Ok(())
        } else {
            Self::remove_one_ipv4(name, ip.unwrap()).await
        }
    }
    /// Polls every 100 ms until the adapter appears, giving up (with
    /// `Error::Timeout` via the outer `?`) after 10 seconds.
    async fn wait_interface_show(&self, name: &str) -> Result<(), Error> {
        Ok(
            tokio::time::timeout(std::time::Duration::from_secs(10), async move {
                loop {
                    if let Some(idx) = Self::get_interface_index(name) {
                        tracing::info!(?name, ?idx, "Interface found");
                        break;
                    }
                    tokio::time::sleep(std::time::Duration::from_millis(100)).await;
                }
                Ok::<(), Error>(())
            })
            .await??,
        )
    }
}
// Platform-selected alias: exactly one of these is active per target OS.
#[cfg(target_os = "macos")]
pub type IfConfiger = MacIfConfiger;
#[cfg(target_os = "linux")]
pub type IfConfiger = LinuxIfConfiger;
#[cfg(target_os = "windows")]
pub type IfConfiger = WindowsIfConfiger;
+120
View File
@@ -0,0 +1,120 @@
use std::{
fmt::Debug,
future,
sync::{Arc, Mutex},
};
use tokio::task::JoinSet;
use tracing::Instrument;
// Shared building blocks used across the daemon (config, context, netns,
// interface configuration, STUN helpers, ...).
pub mod config;
pub mod constants;
pub mod error;
pub mod global_ctx;
pub mod ifcfg;
pub mod netns;
pub mod network;
pub mod rkyv_util;
pub mod stun;
pub mod stun_codec_ext;
/// Builds a timestamp formatter for tracing-subscriber that uses the local
/// UTC offset, falling back to UTC when the local offset is unavailable.
pub fn get_logger_timer<F: time::formatting::Formattable>(
    format: F,
) -> tracing_subscriber::fmt::time::OffsetTime<F> {
    // SAFETY(review): `set_soundness(Unsound)` is the `time` crate's escape
    // hatch to allow querying the local offset in multi-threaded programs;
    // it trades soundness guarantees on some platforms for availability.
    unsafe {
        time::util::local_offset::set_soundness(time::util::local_offset::Soundness::Unsound)
    };
    let local_offset = time::UtcOffset::current_local_offset()
        // `UtcOffset::UTC` replaces `from_whole_seconds(0).unwrap()`:
        // the same value without the eager construction and unwrap.
        .unwrap_or(time::UtcOffset::UTC);
    tracing_subscriber::fmt::time::OffsetTime::new(local_offset, format)
}
/// Convenience wrapper: logger timer with RFC 3339 timestamp formatting.
pub fn get_logger_timer_rfc3339(
) -> tracing_subscriber::fmt::time::OffsetTime<time::format_description::well_known::Rfc3339> {
    get_logger_timer(time::format_description::well_known::Rfc3339)
}
/// Identifier of a peer inside a network.
pub type PeerId = u32;
/// Generates a random peer id (collisions are possible but unlikely in u32).
pub fn new_peer_id() -> PeerId {
    rand::random()
}
/// Spawns a background task that periodically reaps finished tasks from the
/// shared `JoinSet` so it does not accumulate completed entries. The task
/// holds only a `Weak` reference and exits once all owners drop the set.
/// `origin` labels the tracing span for diagnostics.
pub fn join_joinset_background<T: Debug + Send + Sync + 'static>(
    js: Arc<Mutex<JoinSet<T>>>,
    origin: String,
) {
    // Downgrade so this background task never keeps the joinset alive.
    let js = Arc::downgrade(&js);
    tokio::spawn(
        async move {
            loop {
                tokio::time::sleep(std::time::Duration::from_secs(1)).await;
                // Weak::weak_count() returns 0 once all strong references
                // are gone, i.e. the owner dropped the joinset.
                if js.weak_count() == 0 {
                    tracing::info!("joinset task exit");
                    break;
                }
                future::poll_fn(|cx| {
                    tracing::debug!("try join joinset tasks");
                    let Some(js) = js.upgrade() else {
                        return std::task::Poll::Ready(());
                    };
                    let mut js = js.lock().unwrap();
                    // Drain every already-finished task; stop polling as
                    // soon as the next one is still pending.
                    while !js.is_empty() {
                        let ret = js.poll_join_next(cx);
                        if ret.is_pending() {
                            return std::task::Poll::Pending;
                        }
                    }
                    std::task::Poll::Ready(())
                })
                .await;
            }
        }
        .instrument(tracing::info_span!(
            "join_joinset_background",
            origin = origin
        )),
    );
}
#[cfg(test)]
mod tests {
    use super::*;
    // Exercises the reaper: tasks finishing at different times must be
    // drained, and the background task must exit once the set is dropped.
    #[tokio::test]
    async fn test_join_joinset_backgroud() {
        let js = Arc::new(Mutex::new(JoinSet::<()>::new()));
        join_joinset_background(js.clone(), "TEST".to_owned());
        js.try_lock().unwrap().spawn(async {
            tokio::time::sleep(std::time::Duration::from_secs(1)).await;
        });
        tokio::time::sleep(std::time::Duration::from_secs(2)).await;
        assert!(js.try_lock().unwrap().is_empty());
        for _ in 0..5 {
            js.try_lock().unwrap().spawn(async {
                tokio::time::sleep(std::time::Duration::from_secs(3)).await;
            });
            tokio::task::yield_now().await;
        }
        tokio::time::sleep(std::time::Duration::from_secs(2)).await;
        for _ in 0..5 {
            js.try_lock().unwrap().spawn(async {
                tokio::time::sleep(std::time::Duration::from_secs(1)).await;
            });
            tokio::task::yield_now().await;
        }
        tokio::time::sleep(std::time::Duration::from_secs(2)).await;
        assert!(js.try_lock().unwrap().is_empty());
        // Dropping the last strong reference must let the reaper exit.
        let weak_js = Arc::downgrade(&js);
        drop(js);
        tokio::time::sleep(std::time::Duration::from_secs(2)).await;
        assert_eq!(weak_js.weak_count(), 0);
    }
}
+114
View File
@@ -0,0 +1,114 @@
use futures::Future;
#[cfg(target_os = "linux")]
use nix::sched::{setns, CloneFlags};
#[cfg(target_os = "linux")]
use std::os::fd::AsFd;
/// RAII guard that switches the current thread into a network namespace and
/// restores the previous one on drop. Only effective on Linux; a no-op on
/// other platforms.
pub struct NetNSGuard {
    #[cfg(target_os = "linux")]
    // File handle of the namespace to return to in Drop.
    old_ns: Option<std::fs::File>,
}
/// Pseudo name that refers to the root (pid 1) network namespace.
pub static ROOT_NETNS_NAME: &str = "_root_ns";
#[cfg(target_os = "linux")]
impl NetNSGuard {
    /// Enters the namespace `ns` (if given) and remembers the current one so
    /// it can be restored when the guard is dropped.
    ///
    /// # Panics
    /// Panics when `/proc/self/ns/net` or the target namespace file cannot
    /// be opened, or when `setns` fails.
    pub fn new(ns: Option<String>) -> Box<Self> {
        let old_ns = if ns.is_some() {
            // This impl is already gated on `#[cfg(target_os = "linux")]`,
            // so the former `cfg!(target_os = "linux")` check here was
            // always true; its dead `else { None }` branch has been removed.
            let old_ns = Some(std::fs::File::open("/proc/self/ns/net").unwrap());
            Self::switch_ns(ns);
            old_ns
        } else {
            None
        };
        Box::new(NetNSGuard { old_ns })
    }
    /// Switches the calling thread into the named namespace.
    fn switch_ns(name: Option<String>) {
        let Some(name) = name else {
            return;
        };
        // ROOT_NETNS_NAME maps to pid 1's namespace; everything else is
        // resolved under /var/run/netns, like `ip netns` does.
        let ns_path = if name == ROOT_NETNS_NAME {
            "/proc/1/ns/net".to_string()
        } else {
            format!("/var/run/netns/{}", name)
        };
        let ns = std::fs::File::open(ns_path).unwrap();
        log::info!(
            "[INIT NS] switching to new ns_name: {:?}, ns_file: {:?}",
            name,
            ns
        );
        setns(ns.as_fd(), CloneFlags::CLONE_NEWNET).unwrap();
    }
}
#[cfg(target_os = "linux")]
impl Drop for NetNSGuard {
    /// Restores the network namespace that was active when the guard was
    /// created. No-op when the guard never switched namespaces.
    fn drop(&mut self) {
        if self.old_ns.is_none() {
            return;
        }
        log::info!("[INIT NS] switching back to old ns, ns: {:?}", self.old_ns);
        setns(
            self.old_ns.as_ref().unwrap().as_fd(),
            CloneFlags::CLONE_NEWNET,
        )
        .unwrap();
    }
}
// Non-linux targets have no network-namespace support: the guard is inert.
#[cfg(not(target_os = "linux"))]
impl NetNSGuard {
    pub fn new(_ns: Option<String>) -> Box<Self> {
        Box::new(NetNSGuard {})
    }
}
/// Handle to a (possibly unnamed) network namespace; cheap to clone.
/// `name == None` means "stay in the current namespace".
#[derive(Clone)]
pub struct NetNS {
    name: Option<String>,
}
impl NetNS {
    pub fn new(name: Option<String>) -> Self {
        NetNS { name }
    }
    /// Runs an async closure with this namespace entered.
    ///
    /// NOTE(review): the guard is held across `.await`, so the namespace is
    /// only guaranteed for work done before the future first suspends on
    /// this thread — confirm callers only rely on it for socket creation.
    pub async fn run_async<F, Fut, Ret>(&self, f: F) -> Ret
    where
        F: FnOnce() -> Fut,
        Fut: Future<Output = Ret>,
    {
        // TODO: do we really need this lock
        // let _lock = LOCK.lock().await;
        let _guard = NetNSGuard::new(self.name.clone());
        f().await
    }
    /// Runs a synchronous closure with this namespace entered.
    pub fn run<F, Ret>(&self, f: F) -> Ret
    where
        F: FnOnce() -> Ret,
    {
        let _guard = NetNSGuard::new(self.name.clone());
        f()
    }
    /// Enters the namespace and hands the guard to the caller.
    pub fn guard(&self) -> Box<NetNSGuard> {
        NetNSGuard::new(self.name.clone())
    }
    pub fn name(&self) -> Option<String> {
        self.name.clone()
    }
}
+243
View File
@@ -0,0 +1,243 @@
use std::{ops::Deref, sync::Arc};
use crate::rpc::peer::GetIpListResponse;
use pnet::datalink::NetworkInterface;
use tokio::{
sync::{Mutex, RwLock},
task::JoinSet,
};
use super::netns::NetNS;
// How long a collected IP list stays fresh before the background refresh.
pub const CACHED_IP_LIST_TIMEOUT_SEC: u64 = 60;
/// Decides whether a network interface is usable for address collection.
struct InterfaceFilter {
    iface: NetworkInterface,
}
#[cfg(target_os = "linux")]
impl InterfaceFilter {
    // A tun/tap device exposes /sys/class/net/<name>/tun_flags.
    async fn is_tun_tap_device(&self) -> bool {
        let path = format!("/sys/class/net/{}/tun_flags", self.iface.name);
        tokio::fs::metadata(&path).await.is_ok()
    }
    // True when the interface carries at least one non-loopback,
    // non-unspecified, non-multicast address.
    async fn has_valid_ip(&self) -> bool {
        self.iface
            .ips
            .iter()
            .map(|ip| ip.ip())
            .any(|ip| !ip.is_loopback() && !ip.is_unspecified() && !ip.is_multicast())
    }
    /// Keep physical-looking interfaces that are up and have a usable address.
    async fn filter_iface(&self) -> bool {
        tracing::trace!(
            "filter linux iface: {:?}, is_point_to_point: {}, is_loopback: {}, is_up: {}, is_lower_up: {}, is_tun: {}, has_valid_ip: {}",
            self.iface,
            self.iface.is_point_to_point(),
            self.iface.is_loopback(),
            self.iface.is_up(),
            self.iface.is_lower_up(),
            self.is_tun_tap_device().await,
            self.has_valid_ip().await
        );
        !self.iface.is_point_to_point()
            && !self.iface.is_loopback()
            && self.iface.is_up()
            && self.iface.is_lower_up()
            && !self.is_tun_tap_device().await
            && self.has_valid_ip().await
    }
}
#[cfg(target_os = "macos")]
impl InterfaceFilter {
    /// Returns true when `interface_name` is backed by a physical hardware
    /// port, as reported by `networksetup -listallhardwareports`.
    async fn is_interface_physical(interface_name: &str) -> bool {
        let output = tokio::process::Command::new("networksetup")
            .args(&["-listallhardwareports"])
            .output()
            .await
            .expect("Failed to execute command");
        let stdout = std::str::from_utf8(&output.stdout).expect("Invalid UTF-8");
        let lines: Vec<&str> = stdout.lines().collect();
        for i in 0..lines.len() {
            let line = lines[i];
            if line.contains("Device:") && line.contains(interface_name) {
                // The line after "Device:" describes the port type. Use a
                // checked lookup: the original `lines[i + 1]` panicked when
                // the match was the last line of the output; treat a missing
                // follow-up line as "not physical" instead.
                return match lines.get(i + 1) {
                    Some(next_line) => !next_line.contains("Virtual Interface"),
                    None => false,
                };
            }
        }
        false
    }
    /// Keep non-loopback, non-p2p interfaces that are up and physical.
    async fn filter_iface(&self) -> bool {
        !self.iface.is_point_to_point()
            && !self.iface.is_loopback()
            && self.iface.is_up()
            && Self::is_interface_physical(&self.iface.name).await
    }
}
#[cfg(target_os = "windows")]
impl InterfaceFilter {
    /// Keep non-loopback, non-p2p interfaces with a usable address and a
    /// non-zero MAC (filters out most virtual adapters).
    /// NOTE(review): `is_up` is logged but not enforced here, unlike the
    /// linux/macos filters — confirm that is intentional.
    async fn filter_iface(&self) -> bool {
        tracing::debug!(
            "iface_name: {:?}, p2p: {:?}, is_up: {:?}, iface: {:?}",
            self.iface.name,
            self.iface.is_point_to_point(),
            self.iface.is_up(),
            self.iface
        );
        !self.iface.is_point_to_point()
            && !self.iface.is_loopback()
            && self
                .iface
                .ips
                .iter()
                .map(|ip| ip.ip())
                .any(|ip| !ip.is_loopback() && !ip.is_unspecified() && !ip.is_multicast())
            && self.iface.mac.map(|mac| !mac.is_zero()).unwrap_or(false)
    }
}
/// Best-effort detection of the local IPv4 address: bind a UDP socket,
/// "connect" it towards a public address (no packet is actually sent) and
/// read back the source address the OS picked for that route.
pub async fn local_ipv4() -> std::io::Result<std::net::Ipv4Addr> {
    let probe = tokio::net::UdpSocket::bind("0.0.0.0:0").await?;
    probe.connect("8.8.8.8:80").await?;
    if let std::net::IpAddr::V4(ip) = probe.local_addr()?.ip() {
        Ok(ip)
    } else {
        Err(std::io::Error::new(
            std::io::ErrorKind::AddrNotAvailable,
            "no ipv4 address",
        ))
    }
}
/// Best-effort detection of the local IPv6 address, mirroring `local_ipv4`:
/// "connecting" a UDP socket to a public IPv6 address (no packet is sent)
/// lets the OS pick the outgoing source address for that route.
pub async fn local_ipv6() -> std::io::Result<std::net::Ipv6Addr> {
    let socket = tokio::net::UdpSocket::bind("[::]:0").await?;
    socket
        .connect("[2001:4860:4860:0000:0000:0000:0000:8888]:80")
        .await?;
    let addr = socket.local_addr()?;
    match addr.ip() {
        std::net::IpAddr::V6(ip) => Ok(ip),
        // Bug fix: this is the IPv6 probe, so report a missing *IPv6*
        // address (the message previously said "no ipv4 address").
        std::net::IpAddr::V4(_) => Err(std::io::Error::new(
            std::io::ErrorKind::AddrNotAvailable,
            "no ipv6 address",
        )),
    }
}
/// Periodically collects the machine's interface / public IP addresses,
/// caching the result for `CACHED_IP_LIST_TIMEOUT_SEC` seconds.
pub struct IPCollector {
    // Last collected result, shared with the background refresh task.
    cached_ip_list: Arc<RwLock<GetIpListResponse>>,
    // Holds the background refresh task (spawned lazily on first use).
    collect_ip_task: Mutex<JoinSet<()>>,
    // Namespace in which interfaces are enumerated.
    net_ns: NetNS,
}
impl IPCollector {
    pub fn new(net_ns: NetNS) -> Self {
        Self {
            cached_ip_list: Arc::new(RwLock::new(GetIpListResponse::new())),
            collect_ip_task: Mutex::new(JoinSet::new()),
            net_ns,
        }
    }
    /// Returns the cached address list, starting the refresh task on first
    /// call. The first call performs one inline collection (without public
    /// IPs) so the cache is never returned empty.
    pub async fn collect_ip_addrs(&self) -> GetIpListResponse {
        let mut task = self.collect_ip_task.lock().await;
        if task.is_empty() {
            let cached_ip_list = self.cached_ip_list.clone();
            *cached_ip_list.write().await =
                Self::do_collect_ip_addrs(false, self.net_ns.clone()).await;
            let net_ns = self.net_ns.clone();
            task.spawn(async move {
                loop {
                    // The periodic refresh also queries public addresses.
                    let ip_addrs = Self::do_collect_ip_addrs(true, net_ns.clone()).await;
                    *cached_ip_list.write().await = ip_addrs;
                    tokio::time::sleep(std::time::Duration::from_secs(CACHED_IP_LIST_TIMEOUT_SEC))
                        .await;
                }
            });
        }
        return self.cached_ip_list.read().await.deref().clone();
    }
    /// Lists interfaces (inside `net_ns`) that pass the platform-specific
    /// `InterfaceFilter`.
    pub async fn collect_interfaces(net_ns: NetNS) -> Vec<NetworkInterface> {
        let _g = net_ns.guard();
        let ifaces = pnet::datalink::interfaces();
        let mut ret = vec![];
        for iface in ifaces {
            let f = InterfaceFilter {
                iface: iface.clone(),
            };
            if !f.filter_iface().await {
                continue;
            }
            ret.push(iface);
        }
        ret
    }
    /// Gathers interface addresses (and optionally public addresses) into a
    /// `GetIpListResponse`. Loopback and multicast addresses are skipped.
    #[tracing::instrument(skip(net_ns))]
    async fn do_collect_ip_addrs(with_public: bool, net_ns: NetNS) -> GetIpListResponse {
        let mut ret = crate::rpc::peer::GetIpListResponse::new();
        if with_public {
            if let Some(v4_addr) =
                public_ip::addr_with(public_ip::http::ALL, public_ip::Version::V4).await
            {
                ret.public_ipv4 = v4_addr.to_string();
            }
            if let Some(v6_addr) = public_ip::addr_v6().await {
                ret.public_ipv6 = v6_addr.to_string();
            }
        }
        let ifaces = Self::collect_interfaces(net_ns.clone()).await;
        let _g = net_ns.guard();
        for iface in ifaces {
            for ip in iface.ips {
                let ip: std::net::IpAddr = ip.ip();
                if ip.is_loopback() || ip.is_multicast() {
                    continue;
                }
                if ip.is_ipv4() {
                    ret.interface_ipv4s.push(ip.to_string());
                } else if ip.is_ipv6() {
                    ret.interface_ipv6s.push(ip.to_string());
                }
            }
        }
        // Also probe the route-based local addresses; they may not belong to
        // any interface that survived the filter.
        if let Ok(v4_addr) = local_ipv4().await {
            tracing::trace!("got local ipv4: {}", v4_addr);
            if !ret.interface_ipv4s.contains(&v4_addr.to_string()) {
                ret.interface_ipv4s.push(v4_addr.to_string());
            }
        }
        if let Ok(v6_addr) = local_ipv6().await {
            tracing::trace!("got local ipv6: {}", v6_addr);
            if !ret.interface_ipv6s.contains(&v6_addr.to_string()) {
                ret.interface_ipv6s.push(v6_addr.to_string());
            }
        }
        ret
    }
}
+72
View File
@@ -0,0 +1,72 @@
use rkyv::{
string::ArchivedString,
validation::{validators::DefaultValidator, CheckTypeError},
vec::ArchivedVec,
Archive, CheckBytes, Serialize,
};
use tokio_util::bytes::{Bytes, BytesMut};
/// Returns a validated view of `bytes` as the archived form of `T`.
/// Validation makes this safe for untrusted input, at some CPU cost.
pub fn decode_from_bytes_checked<'a, T: Archive>(
    bytes: &'a [u8],
) -> Result<&'a T::Archived, CheckTypeError<T::Archived, DefaultValidator<'a>>>
where
    T::Archived: CheckBytes<DefaultValidator<'a>>,
{
    rkyv::check_archived_root::<T>(bytes)
}
/// Returns an *unchecked* view of `bytes` as the archived form of `T`.
///
/// SAFETY(review): `archived_root` performs no validation; callers must only
/// pass buffers produced by `encode_to_bytes` for the same `T`, otherwise
/// this is undefined behavior. Use `decode_from_bytes_checked` for any
/// untrusted data.
pub fn decode_from_bytes<'a, T: Archive>(
    bytes: &'a [u8],
) -> Result<&'a T::Archived, CheckTypeError<T::Archived, DefaultValidator<'a>>>
where
    T::Archived: CheckBytes<DefaultValidator<'a>>,
{
    // rkyv::check_archived_root::<T>(bytes)
    unsafe { Ok(rkyv::archived_root::<T>(bytes)) }
}
/// Serializes `val` with rkyv into an immutable `Bytes` buffer.
/// `N` is the scratch-space size hint for rkyv's `AllocSerializer`.
pub fn encode_to_bytes<T, const N: usize>(val: &T) -> Bytes
where
    T: Serialize<rkyv::ser::serializers::AllocSerializer<N>>,
{
    // `into_boxed_slice` trims excess capacity before the zero-copy
    // conversion into `Bytes`.
    rkyv::to_bytes::<_, N>(val)
        .unwrap()
        .into_boxed_slice()
        .into()
}
pub fn extract_bytes_from_archived_vec(raw_data: &Bytes, archived_data: &ArchivedVec<u8>) -> Bytes {
let ptr_range = archived_data.as_ptr_range();
let offset = ptr_range.start as usize - raw_data.as_ptr() as usize;
let len = ptr_range.end as usize - ptr_range.start as usize;
return raw_data.slice(offset..offset + len);
}
/// Returns the sub-slice of `raw_data` that backs `archived_data`, without
/// copying. Returns an empty buffer when the string does not lie inside
/// `raw_data`.
pub fn extract_bytes_from_archived_string(
    raw_data: &Bytes,
    archived_data: &ArchivedString,
) -> Bytes {
    let offset = archived_data.as_ptr() as usize - raw_data.as_ptr() as usize;
    let len = archived_data.len();
    if offset + len > raw_data.len() {
        return Bytes::new();
    }
    // Reuse the already-computed `len` instead of re-reading the length.
    raw_data.slice(offset..offset + len)
}
/// Splits out of `raw_data` the mutable region backing `archived_data`.
/// Note the side effect: `split_off` truncates `raw_data` at `offset`, so
/// the caller keeps only the bytes *before* the archived vec.
pub fn extract_bytes_mut_from_archived_vec(
    raw_data: &mut BytesMut,
    archived_data: &ArchivedVec<u8>,
) -> BytesMut {
    let ptr_range = archived_data.as_ptr_range();
    let offset = ptr_range.start as usize - raw_data.as_ptr() as usize;
    let len = ptr_range.end as usize - ptr_range.start as usize;
    raw_data.split_off(offset).split_to(len)
}
/// Converts a byte vector into a `String`.
///
/// Valid UTF-8 input is converted in place without copying; invalid byte
/// sequences are replaced with U+FFFD. The previous implementation used
/// `from_utf8_unchecked`, which is undefined behavior on invalid input.
pub fn vec_to_string(vec: Vec<u8>) -> String {
    match String::from_utf8(vec) {
        Ok(s) => s,
        Err(e) => String::from_utf8_lossy(e.as_bytes()).into_owned(),
    }
}
+550
View File
@@ -0,0 +1,550 @@
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
use crate::rpc::{NatType, StunInfo};
use anyhow::Context;
use crossbeam::atomic::AtomicCell;
use tokio::net::{lookup_host, UdpSocket};
use tokio::sync::RwLock;
use tokio::task::JoinSet;
use tracing::Level;
use bytecodec::{DecodeExt, EncodeExt};
use stun_codec::rfc5389::methods::BINDING;
use stun_codec::rfc5780::attributes::ChangeRequest;
use stun_codec::{Message, MessageClass, MessageDecoder, MessageEncoder};
use crate::common::error::Error;
use super::stun_codec_ext::*;
/// Resolves a list of hostnames lazily and yields the resulting socket
/// addresses one at a time.
struct HostResolverIter {
    // Hostnames not yet resolved.
    hostnames: Vec<String>,
    // Addresses already resolved but not yet yielded.
    ips: Vec<SocketAddr>,
}
impl HostResolverIter {
    fn new(hostnames: Vec<String>) -> Self {
        Self {
            hostnames,
            ips: vec![],
        }
    }
    /// Returns the next resolved address, resolving further hostnames on
    /// demand. Hosts that fail to resolve — or resolve to zero addresses —
    /// are skipped; `None` means every hostname has been exhausted.
    async fn next(&mut self) -> Option<SocketAddr> {
        // Iterative instead of recursive (the async_recursion attribute is
        // no longer needed). This also fixes a panic in the recursive
        // version: a lookup that succeeded with an empty result fell through
        // to `self.ips.remove(0)` on an empty vec.
        while self.ips.is_empty() {
            if self.hostnames.is_empty() {
                return None;
            }
            let host = self.hostnames.remove(0);
            match lookup_host(&host).await {
                Ok(ips) => {
                    self.ips = ips.collect();
                }
                Err(e) => {
                    tracing::warn!(?host, ?e, "lookup host for stun failed");
                }
            };
        }
        Some(self.ips.remove(0))
    }
}
/// Outcome of a single STUN BINDING request.
#[derive(Debug, Clone, Copy)]
struct BindRequestResponse {
    // Local address the request was sent from.
    source_addr: SocketAddr,
    // Server address the request was sent to.
    send_to_addr: SocketAddr,
    // Address the response actually came from.
    recv_from_addr: SocketAddr,
    // Our address as seen by the server ((XOR-)MAPPED-ADDRESS), if present.
    mapped_socket_addr: Option<SocketAddr>,
    // The server's alternate address (OTHER-ADDRESS / CHANGED-ADDRESS).
    changed_socket_addr: Option<SocketAddr>,
    // What we asked the server to change via CHANGE-REQUEST...
    ip_changed: bool,
    port_changed: bool,
    // ...and what actually differed in the response's source address.
    real_ip_changed: bool,
    real_port_changed: bool,
}
impl BindRequestResponse {
    // Mapped address; panics when the response carried none.
    pub fn get_mapped_addr_no_check(&self) -> &SocketAddr {
        self.mapped_socket_addr.as_ref().unwrap()
    }
}
/// Minimal STUN client bound to one server.
#[derive(Debug, Clone)]
struct Stun {
    stun_server: SocketAddr,
    // How many times each request is re-sent to tolerate packet loss.
    req_repeat: u8,
    resp_timeout: Duration,
}
impl Stun {
pub fn new(stun_server: SocketAddr) -> Self {
Self {
stun_server,
req_repeat: 5,
resp_timeout: Duration::from_millis(3000),
}
}
#[tracing::instrument(skip(self, buf))]
async fn wait_stun_response<'a, const N: usize>(
&self,
buf: &'a mut [u8; N],
udp: &UdpSocket,
tids: &Vec<u128>,
expected_ip_changed: bool,
expected_port_changed: bool,
stun_host: &SocketAddr,
) -> Result<(Message<Attribute>, SocketAddr), Error> {
let mut now = tokio::time::Instant::now();
let deadline = now + self.resp_timeout;
while now < deadline {
let mut udp_buf = [0u8; 1500];
let (len, remote_addr) =
tokio::time::timeout(deadline - now, udp.recv_from(udp_buf.as_mut_slice()))
.await??;
now = tokio::time::Instant::now();
if len < 20 {
continue;
}
// TODO:: we cannot borrow `buf` directly in udp recv_from, so we copy it here
unsafe { std::ptr::copy(udp_buf.as_ptr(), buf.as_ptr() as *mut u8, len) };
let mut decoder = MessageDecoder::<Attribute>::new();
let Ok(msg) = decoder
.decode_from_bytes(&buf[..len])
.with_context(|| format!("decode stun msg {:?}", buf))?
else {
continue;
};
tracing::debug!(b = ?&udp_buf[..len], ?tids, ?remote_addr, ?stun_host, "recv stun response, msg: {:#?}", msg);
if msg.class() != MessageClass::SuccessResponse
|| msg.method() != BINDING
|| !tids.contains(&tid_to_u128(&msg.transaction_id()))
{
continue;
}
// some stun server use changed socket even we don't ask for.
if expected_ip_changed && stun_host.ip() == remote_addr.ip() {
continue;
}
if expected_port_changed
&& stun_host.ip() == remote_addr.ip()
&& stun_host.port() == remote_addr.port()
{
continue;
}
return Ok((msg, remote_addr));
}
Err(Error::Unknown)
}
fn extrace_mapped_addr(msg: &Message<Attribute>) -> Option<SocketAddr> {
let mut mapped_addr = None;
for x in msg.attributes() {
match x {
Attribute::MappedAddress(addr) => {
if mapped_addr.is_none() {
let _ = mapped_addr.insert(addr.address());
}
}
Attribute::XorMappedAddress(addr) => {
if mapped_addr.is_none() {
let _ = mapped_addr.insert(addr.address());
}
}
_ => {}
}
}
mapped_addr
}
fn extract_changed_addr(msg: &Message<Attribute>) -> Option<SocketAddr> {
let mut changed_addr = None;
for x in msg.attributes() {
match x {
Attribute::OtherAddress(m) => {
if changed_addr.is_none() {
let _ = changed_addr.insert(m.address());
}
}
Attribute::ChangedAddress(m) => {
if changed_addr.is_none() {
let _ = changed_addr.insert(m.address());
}
}
_ => {}
}
}
changed_addr
}
#[tracing::instrument(ret, err, level = Level::DEBUG)]
pub async fn bind_request(
&self,
source_port: u16,
change_ip: bool,
change_port: bool,
) -> Result<BindRequestResponse, Error> {
let stun_host = self.stun_server;
let udp = UdpSocket::bind(format!("0.0.0.0:{}", source_port)).await?;
// repeat req in case of packet loss
let mut tids = vec![];
for _ in 0..self.req_repeat {
let tid = rand::random::<u32>();
let mut buf = [0u8; 28];
// memset buf
unsafe { std::ptr::write_bytes(buf.as_mut_ptr(), 0, buf.len()) };
let mut message =
Message::<Attribute>::new(MessageClass::Request, BINDING, u128_to_tid(tid as u128));
message.add_attribute(ChangeRequest::new(change_ip, change_port));
// Encodes the message
let mut encoder = MessageEncoder::new();
let msg = encoder
.encode_into_bytes(message.clone())
.with_context(|| "encode stun message")?;
tids.push(tid as u128);
tracing::trace!(?message, ?msg, tid, "send stun request");
udp.send_to(msg.as_slice().into(), &stun_host).await?;
}
tracing::trace!("waiting stun response");
let mut buf = [0; 1620];
let (msg, recv_addr) = self
.wait_stun_response(&mut buf, &udp, &tids, change_ip, change_port, &stun_host)
.await?;
let changed_socket_addr = Self::extract_changed_addr(&msg);
let real_ip_changed = stun_host.ip() != recv_addr.ip();
let real_port_changed = stun_host.port() != recv_addr.port();
let resp = BindRequestResponse {
source_addr: udp.local_addr()?,
send_to_addr: stun_host,
recv_from_addr: recv_addr,
mapped_socket_addr: Self::extrace_mapped_addr(&msg),
changed_socket_addr,
ip_changed: change_ip,
port_changed: change_port,
real_ip_changed,
real_port_changed,
};
tracing::debug!(
?stun_host,
?recv_addr,
?changed_socket_addr,
"finish stun bind request"
);
Ok(resp)
}
}
/// Classifies the local UDP NAT behavior using a set of STUN servers
/// (classic rfc3489-style tests; needs at least two reachable servers).
pub struct UdpNatTypeDetector {
    stun_servers: Vec<String>,
}
impl UdpNatTypeDetector {
    pub fn new(stun_servers: Vec<String>) -> Self {
        Self { stun_servers }
    }
    /// Detects the UDP NAT type for `source_port` (0 lets the OS pick one
    /// first). Like classic STUN (rfc3489), modified to work with at least
    /// two STUN servers:
    /// - test 1 (run against two servers): plain bind request, comparing
    ///   mapped addresses;
    /// - test 2: request asking the server to change ip+port;
    /// - test 3: request asking it to change port only.
    pub async fn get_udp_nat_type(&self, mut source_port: u16) -> NatType {
        let mut ret_test1_1 = None;
        let mut ret_test1_2 = None;
        let mut ret_test2 = None;
        let mut ret_test3 = None;
        if source_port == 0 {
            let udp = UdpSocket::bind("0.0.0.0:0").await.unwrap();
            source_port = udp.local_addr().unwrap().port();
        }
        let mut succ = false;
        let mut ips = HostResolverIter::new(self.stun_servers.clone());
        while let Some(server_ip) = ips.next().await {
            let stun = Stun::new(server_ip.clone());
            let ret = stun.bind_request(source_port, false, false).await;
            if ret.is_err() {
                // Try another STUN server
                continue;
            }
            if ret_test1_1.is_none() {
                ret_test1_1 = ret.ok();
                continue;
            }
            ret_test1_2 = ret.ok();
            let ret = stun.bind_request(source_port, true, true).await;
            if let Ok(resp) = ret {
                if !resp.real_ip_changed || !resp.real_port_changed {
                    tracing::debug!(
                        ?server_ip,
                        ?ret,
                        "stun bind request return with unchanged ip and port"
                    );
                    // Try another STUN server
                    continue;
                }
            }
            ret_test2 = ret.ok();
            ret_test3 = stun.bind_request(source_port, false, true).await.ok();
            tracing::debug!(?ret_test3, "stun bind request with changed port");
            succ = true;
            break;
        }
        if !succ {
            return NatType::Unknown;
        }
        tracing::debug!(
            ?ret_test1_1,
            ?ret_test1_2,
            ?ret_test2,
            ?ret_test3,
            "finish stun test, try to detect nat type"
        );
        let ret_test1_1 = ret_test1_1.unwrap();
        let ret_test1_2 = ret_test1_2.unwrap();
        // Different mapped address from two servers => symmetric NAT.
        if ret_test1_1.mapped_socket_addr != ret_test1_2.mapped_socket_addr {
            return NatType::Symmetric;
        }
        if ret_test1_1.mapped_socket_addr.is_some()
            && ret_test1_1.source_addr == ret_test1_1.mapped_socket_addr.unwrap()
        {
            // Not behind a NAT; test 2 distinguishes an open host from a
            // symmetric UDP firewall.
            if ret_test2.is_some() {
                return NatType::OpenInternet;
            } else {
                return NatType::SymUdpFirewall;
            }
        } else {
            if let Some(ret_test2) = ret_test2 {
                if source_port == ret_test2.get_mapped_addr_no_check().port()
                    && source_port == ret_test1_1.get_mapped_addr_no_check().port()
                {
                    return NatType::NoPat;
                } else {
                    return NatType::FullCone;
                }
            } else {
                // Port-change-only response distinguishes the two restricted
                // cone types.
                if ret_test3.is_some() {
                    return NatType::Restricted;
                } else {
                    return NatType::PortRestricted;
                }
            }
        }
    }
}
/// Read access to collected STUN / NAT information.
#[async_trait::async_trait]
#[auto_impl::auto_impl(&, Arc, Box)]
pub trait StunInfoCollectorTrait: Send + Sync {
    fn get_stun_info(&self) -> StunInfo;
    /// Resolves the public mapping of a local UDP port via STUN.
    async fn get_udp_port_mapping(&self, local_port: u16) -> Result<SocketAddr, Error>;
}
/// Background NAT-type detector; keeps the latest result in `udp_nat_type`.
pub struct StunInfoCollector {
    stun_servers: Arc<RwLock<Vec<String>>>,
    // Latest detected NAT type and when it was stored.
    udp_nat_type: Arc<AtomicCell<(NatType, std::time::Instant)>>,
    // Wakes the detection loop for an immediate re-check.
    redetect_notify: Arc<tokio::sync::Notify>,
    tasks: JoinSet<()>,
}
#[async_trait::async_trait]
impl StunInfoCollectorTrait for StunInfoCollector {
    /// Snapshot of the last detection result. tcp_nat_type is not detected
    /// yet and is always reported as 0.
    fn get_stun_info(&self) -> StunInfo {
        let (typ, time) = self.udp_nat_type.load();
        StunInfo {
            udp_nat_type: typ as i32,
            tcp_nat_type: 0,
            last_update_time: time.elapsed().as_secs() as i64,
        }
    }
    /// Asks the configured STUN servers (in order) for the public mapping of
    /// `local_port`; returns the first mapped address obtained.
    async fn get_udp_port_mapping(&self, local_port: u16) -> Result<SocketAddr, Error> {
        let stun_servers = self.stun_servers.read().await.clone();
        let mut ips = HostResolverIter::new(stun_servers.clone());
        while let Some(server) = ips.next().await {
            let stun = Stun::new(server.clone());
            let Ok(ret) = stun.bind_request(local_port, false, false).await else {
                tracing::warn!(?server, "stun bind request failed");
                continue;
            };
            if let Some(mapped_addr) = ret.mapped_socket_addr {
                return Ok(mapped_addr);
            }
        }
        Err(Error::NotFound)
    }
}
impl StunInfoCollector {
    /// Creates the collector and immediately starts the background
    /// detection routine.
    pub fn new(stun_servers: Vec<String>) -> Self {
        let mut ret = Self {
            stun_servers: Arc::new(RwLock::new(stun_servers)),
            udp_nat_type: Arc::new(AtomicCell::new((
                NatType::Unknown,
                std::time::Instant::now(),
            ))),
            redetect_notify: Arc::new(tokio::sync::Notify::new()),
            tasks: JoinSet::new(),
        };
        ret.start_stun_routine();
        ret
    }
    pub fn new_with_default_servers() -> Self {
        Self::new(Self::get_default_servers())
    }
    /// Built-in list of public STUN servers used when none are configured.
    pub fn get_default_servers() -> Vec<String> {
        // NOTICE: we may need to choose stun stun server based on geo location
        // stun server cross nation may return a external ip address with high latency and loss rate
        vec![
            "stun.miwifi.com:3478".to_string(),
            "stun.qq.com:3478".to_string(),
            // "stun.chat.bilibili.com:3478".to_string(), // bilibili's stun server doesn't respond to change_ip and change_port
            "fwa.lifesizecloud.com:3478".to_string(),
            "stun.isp.net.au:3478".to_string(),
            "stun.nextcloud.com:3478".to_string(),
            "stun.freeswitch.org:3478".to_string(),
            "stun.voip.blackberry.com:3478".to_string(),
            "stunserver.stunprotocol.org:3478".to_string(),
            "stun.sipnet.com:3478".to_string(),
            "stun.radiojar.com:3478".to_string(),
            "stun.sonetel.com:3478".to_string(),
            "stun.voipgate.com:3478".to_string(),
            "stun.counterpath.com:3478".to_string(),
            "180.235.108.91:3478".to_string(),
            "193.22.2.248:3478".to_string(),
        ]
    }
    /// Spawns the detection loop: re-detects the NAT type periodically
    /// (every 60s, or 15s while still unknown) or when `redetect_notify`
    /// fires.
    fn start_stun_routine(&mut self) {
        let stun_servers = self.stun_servers.clone();
        let udp_nat_type = self.udp_nat_type.clone();
        let redetect_notify = self.redetect_notify.clone();
        self.tasks.spawn(async move {
            loop {
                let detector = UdpNatTypeDetector::new(stun_servers.read().await.clone());
                let old_nat_type = udp_nat_type.load().0;
                let mut ret = NatType::Unknown;
                for _ in 1..5 {
                    // if nat type degrade, sleep and retry. so result can be relatively stable.
                    ret = detector.get_udp_nat_type(0).await;
                    if ret == NatType::Unknown || ret <= old_nat_type {
                        break;
                    }
                    tokio::time::sleep(Duration::from_secs(5)).await;
                }
                udp_nat_type.store((ret, std::time::Instant::now()));
                // Retry sooner while the type is still unknown.
                let sleep_sec = match ret {
                    NatType::Unknown => 15,
                    _ => 60,
                };
                tracing::info!(?ret, ?sleep_sec, "finish udp nat type detect");
                tokio::select! {
                    _ = redetect_notify.notified() => {}
                    _ = tokio::time::sleep(Duration::from_secs(sleep_sec)) => {}
                }
            }
        });
    }
    /// Requests an immediate re-detection.
    pub fn update_stun_info(&self) {
        self.redetect_notify.notify_one();
    }
    pub async fn set_stun_servers(&self, stun_servers: Vec<String>) {
        *self.stun_servers.write().await = stun_servers;
        self.update_stun_info();
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    // Helper to turn on verbose tracing when debugging these tests manually.
    pub fn enable_log() {
        let filter = tracing_subscriber::EnvFilter::builder()
            .with_default_directive(tracing::level_filters::LevelFilter::TRACE.into())
            .from_env()
            .unwrap()
            .add_directive("tarpc=error".parse().unwrap());
        tracing_subscriber::fmt::fmt()
            .pretty()
            .with_env_filter(filter)
            .init();
    }
    // Requires network access; exercises a plain bind request only.
    #[tokio::test]
    async fn test_stun_bind_request() {
        // miwifi / qq seems not correctly respond to change_ip and change_port, they always try to change the src ip and port.
        let mut ips = HostResolverIter::new(vec!["stun1.l.google.com:19302".to_string()]);
        let stun = Stun::new(ips.next().await.unwrap());
        // let stun = Stun::new("180.235.108.91:3478".to_string());
        // let stun = Stun::new("193.22.2.248:3478".to_string());
        // let stun = Stun::new("stun.chat.bilibili.com:3478".to_string());
        // let stun = Stun::new("stun.miwifi.com:3478".to_string());
        // github actions are port restricted nat, so we only test last one.
        // let rs = stun.bind_request(12345, true, true).await.unwrap();
        // assert!(rs.ip_changed);
        // assert!(rs.port_changed);
        // let rs = stun.bind_request(12345, true, false).await.unwrap();
        // assert!(rs.ip_changed);
        // assert!(!rs.port_changed);
        // let rs = stun.bind_request(12345, false, true).await.unwrap();
        // assert!(!rs.ip_changed);
        // assert!(rs.port_changed);
        let rs = stun.bind_request(12345, false, false).await.unwrap();
        assert!(!rs.ip_changed);
        assert!(!rs.port_changed);
    }
    // Requires network access; only checks that some NAT type is detected.
    #[tokio::test]
    async fn test_udp_nat_type_detect() {
        let detector = UdpNatTypeDetector::new(vec![
            "stun.counterpath.com:3478".to_string(),
            "180.235.108.91:3478".to_string(),
        ]);
        let ret = detector.get_udp_nat_type(0).await;
        assert_ne!(ret, NatType::Unknown);
    }
}
+229
View File
@@ -0,0 +1,229 @@
use std::net::SocketAddr;
use stun_codec::net::{socket_addr_xor, SocketAddrDecoder, SocketAddrEncoder};
use stun_codec::rfc5389::attributes::{
MappedAddress, Software, XorMappedAddress, XorMappedAddress2,
};
use stun_codec::rfc5780::attributes::{ChangeRequest, OtherAddress, ResponseOrigin};
use stun_codec::{define_attribute_enums, AttributeType, Message, TransactionId};
use bytecodec::{ByteCount, Decode, Encode, Eos, Result, SizedEncode, TryTaggedDecode};
use stun_codec::macros::track;
// Implements bytecodec's Decode + TryTaggedDecode for a newtype decoder that
// delegates to its inner decoder and converts the decoded item via
// `$and_then`.
macro_rules! impl_decode {
    ($decoder:ty, $item:ident, $and_then:expr) => {
        impl Decode for $decoder {
            type Item = $item;
            fn decode(&mut self, buf: &[u8], eos: Eos) -> Result<usize> {
                track!(self.0.decode(buf, eos))
            }
            fn finish_decoding(&mut self) -> Result<Self::Item> {
                track!(self.0.finish_decoding()).and_then($and_then)
            }
            fn requiring_bytes(&self) -> ByteCount {
                self.0.requiring_bytes()
            }
            fn is_idle(&self) -> bool {
                self.0.is_idle()
            }
        }
        impl TryTaggedDecode for $decoder {
            type Tag = AttributeType;
            // Only accept the attribute type whose codepoint matches $item.
            fn try_start_decoding(&mut self, attr_type: Self::Tag) -> Result<bool> {
                Ok(attr_type.as_u16() == $item::CODEPOINT)
            }
        }
    };
}
// Implements bytecodec's Encode + SizedEncode for a newtype encoder,
// mapping the item into the inner encoder's type via `$map_from`.
macro_rules! impl_encode {
    ($encoder:ty, $item:ty, $map_from:expr) => {
        impl Encode for $encoder {
            type Item = $item;
            fn encode(&mut self, buf: &mut [u8], eos: Eos) -> Result<usize> {
                track!(self.0.encode(buf, eos))
            }
            #[allow(clippy::redundant_closure_call)]
            fn start_encoding(&mut self, item: Self::Item) -> Result<()> {
                track!(self.0.start_encoding($map_from(item)))
            }
            fn requiring_bytes(&self) -> ByteCount {
                self.0.requiring_bytes()
            }
            fn is_idle(&self) -> bool {
                self.0.is_idle()
            }
        }
        impl SizedEncode for $encoder {
            fn exact_requiring_bytes(&self) -> u64 {
                self.0.exact_requiring_bytes()
            }
        }
    };
}
/// Classic-STUN (RFC 3489) CHANGED-ADDRESS attribute (codepoint 0x0005).
///
/// NOTE(review): this impl XORs the address with the transaction id in
/// before_encode/after_decode like XOR-MAPPED-ADDRESS, while RFC 3489
/// transmits CHANGED-ADDRESS un-XOR-ed — confirm against the servers used.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ChangedAddress(SocketAddr);
impl ChangedAddress {
    /// The codepoint of the type of the attribute.
    pub const CODEPOINT: u16 = 0x0005;
    pub fn new(addr: SocketAddr) -> Self {
        ChangedAddress(addr)
    }
    /// Returns the address of this instance.
    pub fn address(&self) -> SocketAddr {
        self.0
    }
}
impl stun_codec::Attribute for ChangedAddress {
    type Decoder = ChangedAddressDecoder;
    type Encoder = ChangedAddressEncoder;
    fn get_type(&self) -> AttributeType {
        AttributeType::new(Self::CODEPOINT)
    }
    fn before_encode<A: stun_codec::Attribute>(
        &mut self,
        message: &Message<A>,
    ) -> bytecodec::Result<()> {
        self.0 = socket_addr_xor(self.0, message.transaction_id());
        Ok(())
    }
    fn after_decode<A: stun_codec::Attribute>(
        &mut self,
        message: &Message<A>,
    ) -> bytecodec::Result<()> {
        self.0 = socket_addr_xor(self.0, message.transaction_id());
        Ok(())
    }
}
// Decoder/encoder newtypes wired up via the impl_decode/impl_encode macros.
#[derive(Debug, Default)]
pub struct ChangedAddressDecoder(SocketAddrDecoder);
impl ChangedAddressDecoder {
    pub fn new() -> Self {
        Self::default()
    }
}
impl_decode!(ChangedAddressDecoder, ChangedAddress, |item| Ok(
    ChangedAddress(item)
));
#[derive(Debug, Default)]
pub struct ChangedAddressEncoder(SocketAddrEncoder);
impl ChangedAddressEncoder {
    pub fn new() -> Self {
        Self::default()
    }
}
impl_encode!(ChangedAddressEncoder, ChangedAddress, |item: Self::Item| {
    item.0
});
/// Classic-STUN (RFC 3489) SOURCE-ADDRESS attribute (codepoint 0x0004).
///
/// NOTE(review): XOR-ed on encode/decode like ChangedAddress above, which
/// deviates from RFC 3489's plain encoding — confirm against the servers
/// in use.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct SourceAddress(SocketAddr);
impl SourceAddress {
    /// The codepoint of the type of the attribute.
    pub const CODEPOINT: u16 = 0x0004;
    pub fn new(addr: SocketAddr) -> Self {
        SourceAddress(addr)
    }
    /// Returns the address of this instance.
    pub fn address(&self) -> SocketAddr {
        self.0
    }
}
impl stun_codec::Attribute for SourceAddress {
    type Decoder = SourceAddressDecoder;
    type Encoder = SourceAddressEncoder;
    fn get_type(&self) -> AttributeType {
        AttributeType::new(Self::CODEPOINT)
    }
    fn before_encode<A: stun_codec::Attribute>(
        &mut self,
        message: &Message<A>,
    ) -> bytecodec::Result<()> {
        self.0 = socket_addr_xor(self.0, message.transaction_id());
        Ok(())
    }
    fn after_decode<A: stun_codec::Attribute>(
        &mut self,
        message: &Message<A>,
    ) -> bytecodec::Result<()> {
        self.0 = socket_addr_xor(self.0, message.transaction_id());
        Ok(())
    }
}
// Decoder/encoder newtypes wired up via the impl_decode/impl_encode macros.
#[derive(Debug, Default)]
pub struct SourceAddressDecoder(SocketAddrDecoder);
impl SourceAddressDecoder {
    pub fn new() -> Self {
        Self::default()
    }
}
impl_decode!(SourceAddressDecoder, SourceAddress, |item| Ok(
    SourceAddress(item)
));
#[derive(Debug, Default)]
pub struct SourceAddressEncoder(SocketAddrEncoder);
impl SourceAddressEncoder {
    pub fn new() -> Self {
        Self::default()
    }
}
impl_encode!(SourceAddressEncoder, SourceAddress, |item: Self::Item| {
    item.0
});
/// Packs a 12-byte STUN transaction id into the low bytes of a u128
/// (little-endian; the upper 4 bytes stay zero).
pub fn tid_to_u128(tid: &TransactionId) -> u128 {
    let mut tid_buf = [0u8; 16];
    // copy bytes from msg_tid to tid_buf
    tid_buf[..tid.as_bytes().len()].copy_from_slice(tid.as_bytes());
    u128::from_le_bytes(tid_buf)
}
/// Inverse of `tid_to_u128`: keeps the 12 low little-endian bytes, so the
/// value must fit in 96 bits to round-trip.
pub fn u128_to_tid(tid: u128) -> TransactionId {
    let tid_buf = tid.to_le_bytes();
    let mut tid_arr = [0u8; 12];
    tid_arr.copy_from_slice(&tid_buf[..12]);
    TransactionId::new(tid_arr)
}
// Attribute set understood by this STUN codec: RFC 5389 basics plus the
// RFC 5780 / classic-STUN attributes needed for NAT-type detection.
define_attribute_enums!(
    Attribute,
    AttributeDecoder,
    AttributeEncoder,
    [
        Software,
        MappedAddress,
        XorMappedAddress,
        XorMappedAddress2,
        OtherAddress,
        ChangeRequest,
        ChangedAddress,
        SourceAddress,
        ResponseOrigin
    ]
);
+411
View File
@@ -0,0 +1,411 @@
// try connect peers directly, with either its public ip or lan ip
use std::sync::Arc;
use crate::{
common::{error::Error, global_ctx::ArcGlobalCtx, PeerId},
peers::{peer_manager::PeerManager, peer_rpc::PeerRpcManager},
};
use crate::rpc::{peer::GetIpListResponse, PeerConnInfo};
use tokio::{task::JoinSet, time::timeout};
use tracing::Instrument;
use super::create_connector_by_url;
// RPC service id for the direct-connector service on the peer RPC manager.
pub const DIRECT_CONNECTOR_SERVICE_ID: u32 = 1;
// How long a failed destination stays blacklisted before being retried.
pub const DIRECT_CONNECTOR_BLACKLIST_TIMEOUT_SEC: u64 = 300;
// RPC exposed to peers: returns the addresses / listeners of this node so
// the remote side can attempt a direct connection back.
#[tarpc::service]
pub trait DirectConnectorRpc {
    async fn get_ip_list() -> GetIpListResponse;
}
/// Minimal view of PeerManager needed by the direct connector.
#[async_trait::async_trait]
pub trait PeerManagerForDirectConnector {
    async fn list_peers(&self) -> Vec<PeerId>;
    async fn list_peer_conns(&self, peer_id: PeerId) -> Option<Vec<PeerConnInfo>>;
    fn get_peer_rpc_mgr(&self) -> Arc<PeerRpcManager>;
}
#[async_trait::async_trait]
impl PeerManagerForDirectConnector for PeerManager {
    /// Peer ids of every peer currently known via routing.
    async fn list_peers(&self) -> Vec<PeerId> {
        // Iterator form of the previous manual push loop.
        self.list_routes().await.iter().map(|r| r.peer_id).collect()
    }
    async fn list_peer_conns(&self, peer_id: PeerId) -> Option<Vec<PeerConnInfo>> {
        self.get_peer_map().list_peer_conns(peer_id).await
    }
    /// Delegates to the inherent PeerManager accessor of the same name.
    fn get_peer_rpc_mgr(&self) -> Arc<PeerRpcManager> {
        self.get_peer_rpc_mgr()
    }
}
/// Server side of DirectConnectorRpc.
#[derive(Clone)]
struct DirectConnectorManagerRpcServer {
    // TODO: this only cache for one src peer, should make it global
    global_ctx: ArcGlobalCtx,
}
#[tarpc::server]
impl DirectConnectorRpc for DirectConnectorManagerRpcServer {
    // Collects local/public addresses plus the currently running listeners.
    async fn get_ip_list(self, _: tarpc::context::Context) -> GetIpListResponse {
        let mut ret = self.global_ctx.get_ip_collector().collect_ip_addrs().await;
        ret.listeners = self.global_ctx.get_running_listeners();
        ret
    }
}
impl DirectConnectorManagerRpcServer {
    pub fn new(global_ctx: ArcGlobalCtx) -> Self {
        Self { global_ctx }
    }
}
// Blacklist keys: a concrete destination URL, or a whole URL scheme, that
// recently failed for the given peer.
#[derive(Hash, Eq, PartialEq, Clone)]
struct DstBlackListItem(PeerId, String);
#[derive(Hash, Eq, PartialEq, Clone)]
struct DstSchemeBlackListItem(PeerId, String);
/// State shared between the direct-connector tasks.
struct DirectConnectorManagerData {
    global_ctx: ArcGlobalCtx,
    peer_manager: Arc<PeerManager>,
    // Destinations that failed recently; entries expire automatically.
    dst_blacklist: timedmap::TimedMap<DstBlackListItem, ()>,
    dst_sceme_blacklist: timedmap::TimedMap<DstSchemeBlackListItem, ()>,
}
impl DirectConnectorManagerData {
    pub fn new(global_ctx: ArcGlobalCtx, peer_manager: Arc<PeerManager>) -> Self {
        Self {
            global_ctx,
            peer_manager,
            dst_blacklist: timedmap::TimedMap::new(),
            dst_sceme_blacklist: timedmap::TimedMap::new(),
        }
    }
}
impl std::fmt::Debug for DirectConnectorManagerData {
    // Manual impl: only peer_manager is shown; the remaining fields are not
    // useful in debug output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("DirectConnectorManagerData")
            .field("peer_manager", &self.peer_manager)
            .finish()
    }
}
/// Tries to establish direct (non-relayed) connections to known peers, and
/// serves the RPC that lets peers do the same towards us.
pub struct DirectConnectorManager {
    global_ctx: ArcGlobalCtx,
    data: Arc<DirectConnectorManagerData>,
    tasks: JoinSet<()>,
}
impl DirectConnectorManager {
    pub fn new(global_ctx: ArcGlobalCtx, peer_manager: Arc<PeerManager>) -> Self {
        Self {
            global_ctx: global_ctx.clone(),
            data: Arc::new(DirectConnectorManagerData::new(global_ctx, peer_manager)),
            tasks: JoinSet::new(),
        }
    }
    /// Starts both the server (RPC) and client (probing) roles.
    pub fn run(&mut self) {
        self.run_as_server();
        self.run_as_client();
    }
    /// Registers the get_ip_list RPC service on the peer RPC manager.
    pub fn run_as_server(&mut self) {
        self.data.peer_manager.get_peer_rpc_mgr().run_service(
            DIRECT_CONNECTOR_SERVICE_ID,
            DirectConnectorManagerRpcServer::new(self.global_ctx.clone()).serve(),
        );
    }
    /// Spawns the probe loop: every 5 seconds, tries a direct connection to
    /// every known peer (except ourselves) in parallel.
    pub fn run_as_client(&mut self) {
        let data = self.data.clone();
        let my_peer_id = self.data.peer_manager.my_peer_id();
        self.tasks.spawn(
            async move {
                loop {
                    let peers = data.peer_manager.list_peers().await;
                    let mut tasks = JoinSet::new();
                    for peer_id in peers {
                        if peer_id == my_peer_id {
                            continue;
                        }
                        tasks.spawn(Self::do_try_direct_connect(data.clone(), peer_id));
                    }
                    // Drain all probe results before starting the next round.
                    while let Some(task_ret) = tasks.join_next().await {
                        tracing::trace!(?task_ret, "direct connect task ret");
                    }
                    tokio::time::sleep(std::time::Duration::from_secs(5)).await;
                }
            }
            .instrument(
                tracing::info_span!("direct_connector_client", my_id = ?self.global_ctx.id),
            ),
        );
    }
async fn do_try_connect_to_ip(
data: Arc<DirectConnectorManagerData>,
dst_peer_id: PeerId,
addr: String,
) -> Result<(), Error> {
data.dst_blacklist.cleanup();
if data
.dst_blacklist
.contains(&DstBlackListItem(dst_peer_id.clone(), addr.clone()))
{
tracing::trace!("try_connect_to_ip failed, addr in blacklist: {}", addr);
return Err(Error::UrlInBlacklist);
}
let connector = create_connector_by_url(&addr, &data.global_ctx).await?;
let (peer_id, conn_id) = timeout(
std::time::Duration::from_secs(5),
data.peer_manager.try_connect(connector),
)
.await??;
// let (peer_id, conn_id) = data.peer_manager.try_connect(connector).await?;
if peer_id != dst_peer_id {
tracing::info!(
"connect to ip succ: {}, but peer id mismatch, expect: {}, actual: {}",
addr,
dst_peer_id,
peer_id
);
data.peer_manager
.get_peer_map()
.close_peer_conn(peer_id, &conn_id)
.await?;
return Err(Error::InvalidUrl(addr));
}
Ok(())
}
#[tracing::instrument]
async fn try_connect_to_ip(
data: Arc<DirectConnectorManagerData>,
dst_peer_id: PeerId,
addr: String,
) -> Result<(), Error> {
let ret = Self::do_try_connect_to_ip(data.clone(), dst_peer_id, addr.clone()).await;
if let Err(e) = ret {
if !matches!(e, Error::UrlInBlacklist) {
tracing::info!(
"try_connect_to_ip failed: {:?}, peer_id: {}",
e,
dst_peer_id
);
data.dst_blacklist.insert(
DstBlackListItem(dst_peer_id.clone(), addr.clone()),
(),
std::time::Duration::from_secs(DIRECT_CONNECTOR_BLACKLIST_TIMEOUT_SEC),
);
}
return Err(e);
} else {
log::info!("try_connect_to_ip success, peer_id: {}", dst_peer_id);
return Ok(());
}
}
#[tracing::instrument]
async fn do_try_direct_connect_internal(
data: Arc<DirectConnectorManagerData>,
dst_peer_id: PeerId,
ip_list: GetIpListResponse,
) -> Result<(), Error> {
let available_listeners = ip_list
.listeners
.iter()
.filter_map(|l| if l.scheme() != "ring" { Some(l) } else { None })
.filter(|l| l.port().is_some())
.filter(|l| {
!data.dst_sceme_blacklist.contains(&DstSchemeBlackListItem(
dst_peer_id.clone(),
l.scheme().to_string(),
))
})
.collect::<Vec<_>>();
let mut listener = available_listeners.get(0).ok_or(anyhow::anyhow!(
"peer {} have no valid listener",
dst_peer_id
))?;
// if have default listener, use it first
listener = available_listeners
.iter()
.find(|l| l.scheme() == data.global_ctx.get_flags().default_protocol)
.unwrap_or(listener);
let mut tasks = JoinSet::new();
ip_list.interface_ipv4s.iter().for_each(|ip| {
let addr = format!(
"{}://{}:{}",
listener.scheme(),
ip,
listener.port().unwrap_or(11010)
);
tasks.spawn(Self::try_connect_to_ip(
data.clone(),
dst_peer_id.clone(),
addr,
));
});
let addr = format!(
"{}://{}:{}",
listener.scheme(),
ip_list.public_ipv4.clone(),
listener.port().unwrap_or(11010)
);
tasks.spawn(Self::try_connect_to_ip(
data.clone(),
dst_peer_id.clone(),
addr,
));
let mut has_succ = false;
while let Some(ret) = tasks.join_next().await {
if let Err(e) = ret {
log::error!("join direct connect task failed: {:?}", e);
} else if let Ok(Ok(_)) = ret {
has_succ = true;
}
}
if !has_succ {
data.dst_sceme_blacklist.insert(
DstSchemeBlackListItem(dst_peer_id.clone(), listener.scheme().to_string()),
(),
std::time::Duration::from_secs(DIRECT_CONNECTOR_BLACKLIST_TIMEOUT_SEC),
);
}
Ok(())
}
#[tracing::instrument]
async fn do_try_direct_connect(
data: Arc<DirectConnectorManagerData>,
dst_peer_id: PeerId,
) -> Result<(), Error> {
let peer_manager = data.peer_manager.clone();
// check if we have direct connection with dst_peer_id
if let Some(c) = peer_manager.list_peer_conns(dst_peer_id).await {
// currently if we have any type of direct connection (udp or tcp), we will not try to connect
if !c.is_empty() {
return Ok(());
}
}
log::trace!("try direct connect to peer: {}", dst_peer_id);
let ip_list = peer_manager
.get_peer_rpc_mgr()
.do_client_rpc_scoped(1, dst_peer_id, |c| async {
let client =
DirectConnectorRpcClient::new(tarpc::client::Config::default(), c).spawn();
let ip_list = client.get_ip_list(tarpc::context::current()).await;
tracing::info!(ip_list = ?ip_list, dst_peer_id = ?dst_peer_id, "got ip list");
ip_list
})
.await?;
Self::do_try_direct_connect_internal(data, dst_peer_id, ip_list).await
}
}
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use crate::{
        connector::direct::{
            DirectConnectorManager, DirectConnectorManagerData, DstBlackListItem,
            DstSchemeBlackListItem,
        },
        instance::listeners::ListenerManager,
        peers::tests::{
            connect_peer_manager, create_mock_peer_manager, wait_route_appear,
            wait_route_appear_with_cost,
        },
        rpc::peer::GetIpListResponse,
    };
    // End-to-end test: A -- B -- C relay chain; A runs the direct-connector
    // client and C the server, and the test passes once A reaches C with a
    // route cost of 1 (i.e. a direct connection was established).
    #[rstest::rstest]
    #[tokio::test]
    async fn direct_connector_basic_test(#[values("tcp", "udp", "wg")] proto: &str) {
        let p_a = create_mock_peer_manager().await;
        let p_b = create_mock_peer_manager().await;
        let p_c = create_mock_peer_manager().await;
        connect_peer_manager(p_a.clone(), p_b.clone()).await;
        connect_peer_manager(p_b.clone(), p_c.clone()).await;
        wait_route_appear(p_a.clone(), p_c.clone()).await.unwrap();
        let mut dm_a = DirectConnectorManager::new(p_a.get_global_ctx(), p_a.clone());
        let mut dm_c = DirectConnectorManager::new(p_c.get_global_ctx(), p_c.clone());
        dm_a.run_as_client();
        dm_c.run_as_server();
        // Distinct ports per proto so parameterized cases do not collide.
        let port = if proto == "wg" { 11040 } else { 11041 };
        p_c.get_global_ctx()
            .config
            .set_listeners(vec![format!("{}://0.0.0.0:{}", proto, port)
                .parse()
                .unwrap()]);
        let mut lis_c = ListenerManager::new(p_c.get_global_ctx(), p_c.clone());
        lis_c.prepare_listeners().await.unwrap();
        lis_c.run().await.unwrap();
        wait_route_appear_with_cost(p_a.clone(), p_c.my_peer_id(), Some(1))
            .await
            .unwrap();
    }
    // When every connect attempt fails, both the (peer, scheme) and the
    // (peer, address) blacklists must end up containing the failed target.
    #[tokio::test]
    async fn direct_connector_scheme_blacklist() {
        let p_a = create_mock_peer_manager().await;
        let data = Arc::new(DirectConnectorManagerData::new(
            p_a.get_global_ctx(),
            p_a.clone(),
        ));
        let mut ip_list = GetIpListResponse::new();
        ip_list
            .listeners
            .push("tcp://127.0.0.1:10222".parse().unwrap());
        ip_list.interface_ipv4s.push("127.0.0.1".to_string());
        DirectConnectorManager::do_try_direct_connect_internal(data.clone(), 1, ip_list.clone())
            .await
            .unwrap();
        assert!(data
            .dst_sceme_blacklist
            .contains(&DstSchemeBlackListItem(1, "tcp".into())));
        assert!(data
            .dst_blacklist
            .contains(&DstBlackListItem(1, ip_list.listeners[0].to_string())));
    }
}
+390
View File
@@ -0,0 +1,390 @@
use std::{collections::BTreeSet, sync::Arc};
use dashmap::{DashMap, DashSet};
use tokio::{
sync::{broadcast::Receiver, mpsc, Mutex},
task::JoinSet,
time::timeout,
};
use crate::{common::PeerId, peers::peer_conn::PeerConnId, rpc as easytier_rpc};
use crate::{
common::{
error::Error,
global_ctx::{ArcGlobalCtx, GlobalCtxEvent},
netns::NetNS,
},
connector::set_bind_addr_for_peer_connector,
peers::peer_manager::PeerManager,
rpc::{
connector_manage_rpc_server::ConnectorManageRpc, Connector, ConnectorStatus,
ListConnectorRequest, ManageConnectorRequest,
},
tunnels::{Tunnel, TunnelConnector},
use_global_var,
};
use super::create_connector_by_url;
// Map from remote URL string to the boxed connector responsible for it.
type ConnectorMap = Arc<DashMap<String, Box<dyn TunnelConnector + Send + Sync>>>;
// Outcome of a successful reconnect attempt: which URL was revived and the
// resulting peer/connection identifiers.
#[derive(Debug, Clone)]
struct ReconnResult {
    dead_url: String,
    peer_id: PeerId,
    conn_id: PeerConnId,
}
// State shared between the public API and the background reconnect routine.
struct ConnectorManagerData {
    connectors: ConnectorMap,
    // URLs currently being reconnected; entries are moved out of `connectors`
    // while a reconnect is in flight and moved back when it finishes.
    reconnecting: DashSet<String>,
    peer_manager: Arc<PeerManager>,
    // URLs whose tunnel is currently alive, maintained from global events.
    alive_conn_urls: Arc<Mutex<BTreeSet<String>>>,
    // user removed connector urls
    removed_conn_urls: Arc<DashSet<String>>,
    net_ns: NetNS,
    global_ctx: ArcGlobalCtx,
}
// Owns user-configured connectors: keeps them alive, reconnects dead ones,
// and exposes add/remove/list operations (also used by the RPC service).
pub struct ManualConnectorManager {
    global_ctx: ArcGlobalCtx,
    data: Arc<ConnectorManagerData>,
    // Holds the background reconnect routine; aborted when the manager drops.
    tasks: JoinSet<()>,
}
impl ManualConnectorManager {
    /// Build the manager and spawn the background routine that watches
    /// connection events and reconnects dead connectors.
    pub fn new(global_ctx: ArcGlobalCtx, peer_manager: Arc<PeerManager>) -> Self {
        let connectors = Arc::new(DashMap::new());
        let tasks = JoinSet::new();
        // Subscribe before spawning the routine so no conn event is missed.
        let event_subscriber = global_ctx.subscribe();
        let mut ret = Self {
            global_ctx: global_ctx.clone(),
            data: Arc::new(ConnectorManagerData {
                connectors,
                reconnecting: DashSet::new(),
                peer_manager,
                alive_conn_urls: Arc::new(Mutex::new(BTreeSet::new())),
                removed_conn_urls: Arc::new(DashSet::new()),
                net_ns: global_ctx.net_ns.clone(),
                global_ctx,
            }),
            tasks,
        };
        ret.tasks
            .spawn(Self::conn_mgr_routine(ret.data.clone(), event_subscriber));
        ret
    }
    /// Register a connector to be managed, keyed by its remote URL.
    pub fn add_connector<T>(&self, connector: T)
    where
        T: TunnelConnector + Send + Sync + 'static,
    {
        log::info!("add_connector: {}", connector.remote_url());
        self.data
            .connectors
            .insert(connector.remote_url().into(), Box::new(connector));
    }
    /// Parse `url`, build the matching connector and register it.
    pub async fn add_connector_by_url(&self, url: &str) -> Result<(), Error> {
        self.add_connector(create_connector_by_url(url, &self.global_ctx).await?);
        Ok(())
    }
    /// Mark the connector identified by `url` for removal. Removal is
    /// deferred: the background routine applies it on its next pass
    /// (see `handle_remove_connector`).
    pub async fn remove_connector(&self, url: &str) -> Result<(), Error> {
        log::info!("remove_connector: {}", url);
        if !self.list_connectors().await.iter().any(|x| x.url == url) {
            return Err(Error::NotFound);
        }
        self.data.removed_conn_urls.insert(url.into());
        Ok(())
    }
    /// Snapshot all managed connectors with their current status:
    /// connected, disconnected, or (currently) reconnecting.
    pub async fn list_connectors(&self) -> Vec<Connector> {
        let conn_urls: BTreeSet<String> = self
            .data
            .connectors
            .iter()
            .map(|x| x.key().clone().into())
            .collect();
        let dead_urls: BTreeSet<String> = Self::collect_dead_conns(self.data.clone())
            .await
            .into_iter()
            .collect();
        let mut ret = Vec::new();
        for conn_url in conn_urls {
            let mut status = ConnectorStatus::Connected;
            if dead_urls.contains(&conn_url) {
                status = ConnectorStatus::Disconnected;
            }
            // insert(0, ..) reverses the BTreeSet order in the output.
            ret.insert(
                0,
                Connector {
                    url: conn_url,
                    status: status.into(),
                },
            );
        }
        // Reconnecting entries live outside `connectors`, append them too.
        let reconnecting_urls: BTreeSet<String> = self
            .data
            .reconnecting
            .iter()
            .map(|x| x.clone().into())
            .collect();
        for conn_url in reconnecting_urls {
            ret.insert(
                0,
                Connector {
                    url: conn_url,
                    status: ConnectorStatus::Connecting.into(),
                },
            );
        }
        ret
    }
    /// Background loop: tracks conn add/remove events, periodically collects
    /// dead connectors and spawns reconnect tasks for them.
    async fn conn_mgr_routine(
        data: Arc<ConnectorManagerData>,
        mut event_recv: Receiver<GlobalCtxEvent>,
    ) {
        log::warn!("conn_mgr_routine started");
        let mut reconn_interval = tokio::time::interval(std::time::Duration::from_millis(
            use_global_var!(MANUAL_CONNECTOR_RECONNECT_INTERVAL_MS),
        ));
        let mut reconn_tasks = JoinSet::new();
        let (reconn_result_send, mut reconn_result_recv) = mpsc::channel(100);
        loop {
            tokio::select! {
                event = event_recv.recv() => {
                    if let Ok(event) = event {
                        Self::handle_event(&event, data.clone()).await;
                    } else {
                        // The broadcast sender is gone; the manager cannot
                        // function without events, so fail loudly.
                        log::warn!("event_recv closed");
                        panic!("event_recv closed");
                    }
                }
                _ = reconn_interval.tick() => {
                    let dead_urls = Self::collect_dead_conns(data.clone()).await;
                    if dead_urls.is_empty() {
                        continue;
                    }
                    for dead_url in dead_urls {
                        let data_clone = data.clone();
                        let sender = reconn_result_send.clone();
                        // Move the connector out of the map and into the
                        // `reconnecting` set for the duration of the attempt;
                        // `conn_reconnect` puts it back when done.
                        let (_, connector) = data.connectors.remove(&dead_url).unwrap();
                        let insert_succ = data.reconnecting.insert(dead_url.clone());
                        assert!(insert_succ);
                        reconn_tasks.spawn(async move {
                            sender.send(Self::conn_reconnect(data_clone.clone(), dead_url, connector).await).await.unwrap();
                        });
                    }
                    log::info!("reconn_interval tick, done");
                }
                ret = reconn_result_recv.recv() => {
                    // One result per spawned task; reap the matching handle.
                    log::warn!("reconn_tasks done, out: {:?}", ret);
                    let _ = reconn_tasks.join_next().await.unwrap();
                }
            }
        }
    }
    /// Keep `alive_conn_urls` in sync with peer-connection lifecycle events.
    async fn handle_event(event: &GlobalCtxEvent, data: Arc<ConnectorManagerData>) {
        match event {
            GlobalCtxEvent::PeerConnAdded(conn_info) => {
                let addr = conn_info.tunnel.as_ref().unwrap().remote_addr.clone();
                data.alive_conn_urls.lock().await.insert(addr);
                log::warn!("peer conn added: {:?}", conn_info);
            }
            GlobalCtxEvent::PeerConnRemoved(conn_info) => {
                let addr = conn_info.tunnel.as_ref().unwrap().remote_addr.clone();
                data.alive_conn_urls.lock().await.remove(&addr);
                log::warn!("peer conn removed: {:?}", conn_info);
            }
            _ => {}
        }
    }
    /// Apply pending user removals. URLs that are mid-reconnect cannot be
    /// removed yet; they are re-queued and retried on the next pass.
    fn handle_remove_connector(data: Arc<ConnectorManagerData>) {
        let remove_later = DashSet::new();
        for it in data.removed_conn_urls.iter() {
            let url = it.key();
            if let Some(_) = data.connectors.remove(url) {
                log::warn!("connector: {}, removed", url);
                continue;
            } else if data.reconnecting.contains(url) {
                log::warn!("connector: {}, reconnecting, remove later.", url);
                remove_later.insert(url.clone());
                continue;
            } else {
                log::warn!("connector: {}, not found", url);
            }
        }
        data.removed_conn_urls.clear();
        for it in remove_later.iter() {
            data.removed_conn_urls.insert(it.key().clone());
        }
    }
    /// Return managed URLs that have no alive tunnel (set difference of all
    /// registered URLs minus currently-alive ones), after applying removals.
    async fn collect_dead_conns(data: Arc<ConnectorManagerData>) -> BTreeSet<String> {
        Self::handle_remove_connector(data.clone());
        let curr_alive = data.alive_conn_urls.lock().await.clone();
        let all_urls: BTreeSet<String> = data
            .connectors
            .iter()
            .map(|x| x.key().clone().into())
            .collect();
        &all_urls - &curr_alive
    }
    /// Attempt one reconnect of `connector` to `dead_url`. The connector is
    /// always returned to `data.connectors` afterwards, success or not, and
    /// the URL is cleared from `reconnecting`.
    async fn conn_reconnect(
        data: Arc<ConnectorManagerData>,
        dead_url: String,
        connector: Box<dyn TunnelConnector + Send + Sync>,
    ) -> Result<ReconnResult, Error> {
        // Wrapped in Option so the connector can be taken back out even if
        // the timed task below is cancelled by the timeout.
        let connector = Arc::new(Mutex::new(Some(connector)));
        let net_ns = data.net_ns.clone();
        log::info!("reconnect: {}", dead_url);
        let connector_clone = connector.clone();
        let data_clone = data.clone();
        let url_clone = dead_url.clone();
        let ip_collector = data.global_ctx.get_ip_collector();
        let reconn_task = async move {
            let mut locked = connector_clone.lock().await;
            let conn = locked.as_mut().unwrap();
            // TODO: should support set v6 here, use url in connector array
            set_bind_addr_for_peer_connector(conn, true, &ip_collector).await;
            data_clone
                .global_ctx
                .issue_event(GlobalCtxEvent::Connecting(conn.remote_url().clone()));
            let _g = net_ns.guard();
            log::info!("reconnect try connect... conn: {:?}", conn);
            let tunnel = conn.connect().await?;
            log::info!("reconnect get tunnel succ: {:?}", tunnel);
            assert_eq!(
                url_clone,
                tunnel.info().unwrap().remote_addr,
                "info: {:?}",
                tunnel.info()
            );
            let (peer_id, conn_id) = data_clone.peer_manager.add_client_tunnel(tunnel).await?;
            log::info!("reconnect succ: {} {} {}", peer_id, conn_id, url_clone);
            Ok(ReconnResult {
                dead_url: url_clone,
                peer_id,
                conn_id,
            })
        };
        // NOTE(review): 1 second for bind + connect + handshake looks tight
        // for WAN links — confirm this timeout is intentional.
        let ret = timeout(std::time::Duration::from_secs(1), reconn_task).await;
        log::info!("reconnect: {} done, ret: {:?}", dead_url, ret);
        if ret.is_err() || ret.as_ref().unwrap().is_err() {
            data.global_ctx.issue_event(GlobalCtxEvent::ConnectError(
                dead_url.clone(),
                format!("{:?}", ret),
            ));
        }
        // Restore the connector so the routine can retry on the next tick.
        let conn = connector.lock().await.take().unwrap();
        data.reconnecting.remove(&dead_url).unwrap();
        data.connectors.insert(dead_url.clone(), conn);
        ret?
    }
}
// tonic service adapter exposing ManualConnectorManager over the manage RPC.
pub struct ConnectorManagerRpcService(pub Arc<ManualConnectorManager>);
#[tonic::async_trait]
impl ConnectorManageRpc for ConnectorManagerRpcService {
    /// List all managed connectors with their live status.
    async fn list_connector(
        &self,
        _request: tonic::Request<ListConnectorRequest>,
    ) -> Result<tonic::Response<easytier_rpc::ListConnectorResponse>, tonic::Status> {
        let mut ret = easytier_rpc::ListConnectorResponse::default();
        let connectors = self.0.list_connectors().await;
        ret.connectors = connectors;
        Ok(tonic::Response::new(ret))
    }
    /// Add or remove a connector depending on the requested action.
    async fn manage_connector(
        &self,
        request: tonic::Request<ManageConnectorRequest>,
    ) -> Result<tonic::Response<easytier_rpc::ManageConnectorResponse>, tonic::Status> {
        let req = request.into_inner();
        let url = url::Url::parse(&req.url)
            .map_err(|_| tonic::Status::invalid_argument("invalid url"))?;
        if req.action == easytier_rpc::ConnectorManageAction::Remove as i32 {
            // NOTE(review): this passes `url.path()` (empty for "tcp://host:port")
            // to remove_connector, while add uses the full URL string — verify
            // removal actually matches the stored connector keys.
            self.0.remove_connector(url.path()).await.map_err(|e| {
                tonic::Status::invalid_argument(format!("remove connector failed: {:?}", e))
            })?;
            return Ok(tonic::Response::new(
                easytier_rpc::ManageConnectorResponse::default(),
            ));
        } else {
            self.0
                .add_connector_by_url(url.as_str())
                .await
                .map_err(|e| {
                    tonic::Status::invalid_argument(format!("add connector failed: {:?}", e))
                })?;
        }
        Ok(tonic::Response::new(
            easytier_rpc::ManageConnectorResponse::default(),
        ))
    }
}
#[cfg(test)]
mod tests {
    use crate::{
        peers::tests::create_mock_peer_manager,
        set_global_var,
        tunnels::{Tunnel, TunnelError},
    };
    use super::*;
    // Exercises the reconnect routine with a connector that always fails
    // slowly; mainly verifies the routine does not deadlock or panic while
    // an address stays in the "reconnecting" state.
    #[tokio::test]
    async fn test_reconnect_with_connecting_addr() {
        set_global_var!(MANUAL_CONNECTOR_RECONNECT_INTERVAL_MS, 1);
        let peer_mgr = create_mock_peer_manager().await;
        let mgr = ManualConnectorManager::new(peer_mgr.get_global_ctx(), peer_mgr);
        // Connector that takes 10ms and then always fails to connect.
        struct MockConnector {}
        #[async_trait::async_trait]
        impl TunnelConnector for MockConnector {
            fn remote_url(&self) -> url::Url {
                url::Url::parse("tcp://aa.com").unwrap()
            }
            async fn connect(&mut self) -> Result<Box<dyn Tunnel>, TunnelError> {
                tokio::time::sleep(std::time::Duration::from_millis(10)).await;
                Err(TunnelError::CommonError("fake error".into()))
            }
        }
        mgr.add_connector(MockConnector {});
        tokio::time::sleep(std::time::Duration::from_secs(5)).await;
    }
}
+99
View File
@@ -0,0 +1,99 @@
use std::{
net::{SocketAddr, SocketAddrV4, SocketAddrV6},
sync::Arc,
};
use crate::{
common::{error::Error, global_ctx::ArcGlobalCtx, network::IPCollector},
tunnels::{
ring_tunnel::RingTunnelConnector,
tcp_tunnel::TcpTunnelConnector,
udp_tunnel::UdpTunnelConnector,
wireguard::{WgConfig, WgTunnelConnector},
TunnelConnector,
},
};
pub mod direct;
pub mod manual;
pub mod udp_hole_punch;
async fn set_bind_addr_for_peer_connector(
connector: &mut impl TunnelConnector,
is_ipv4: bool,
ip_collector: &Arc<IPCollector>,
) {
let ips = ip_collector.collect_ip_addrs().await;
if is_ipv4 {
let mut bind_addrs = vec![];
for ipv4 in ips.interface_ipv4s {
let socket_addr = SocketAddrV4::new(ipv4.parse().unwrap(), 0).into();
bind_addrs.push(socket_addr);
}
connector.set_bind_addrs(bind_addrs);
} else {
let mut bind_addrs = vec![];
for ipv6 in ips.interface_ipv6s {
let socket_addr = SocketAddrV6::new(ipv6.parse().unwrap(), 0, 0, 0).into();
bind_addrs.push(socket_addr);
}
connector.set_bind_addrs(bind_addrs);
}
let _ = connector;
}
pub async fn create_connector_by_url(
url: &str,
global_ctx: &ArcGlobalCtx,
) -> Result<Box<dyn TunnelConnector + Send + Sync + 'static>, Error> {
let url = url::Url::parse(url).map_err(|_| Error::InvalidUrl(url.to_owned()))?;
match url.scheme() {
"tcp" => {
let dst_addr =
crate::tunnels::check_scheme_and_get_socket_addr::<SocketAddr>(&url, "tcp")?;
let mut connector = TcpTunnelConnector::new(url);
set_bind_addr_for_peer_connector(
&mut connector,
dst_addr.is_ipv4(),
&global_ctx.get_ip_collector(),
)
.await;
return Ok(Box::new(connector));
}
"udp" => {
let dst_addr =
crate::tunnels::check_scheme_and_get_socket_addr::<SocketAddr>(&url, "udp")?;
let mut connector = UdpTunnelConnector::new(url);
set_bind_addr_for_peer_connector(
&mut connector,
dst_addr.is_ipv4(),
&global_ctx.get_ip_collector(),
)
.await;
return Ok(Box::new(connector));
}
"ring" => {
crate::tunnels::check_scheme_and_get_socket_addr::<uuid::Uuid>(&url, "ring")?;
let connector = RingTunnelConnector::new(url);
return Ok(Box::new(connector));
}
"wg" => {
let dst_addr =
crate::tunnels::check_scheme_and_get_socket_addr::<SocketAddr>(&url, "wg")?;
let nid = global_ctx.get_network_identity();
let wg_config =
WgConfig::new_from_network_identity(&nid.network_name, &nid.network_secret);
let mut connector = WgTunnelConnector::new(url, wg_config);
set_bind_addr_for_peer_connector(
&mut connector,
dst_addr.is_ipv4(),
&global_ctx.get_ip_collector(),
)
.await;
return Ok(Box::new(connector));
}
_ => {
return Err(Error::InvalidUrl(url.into()));
}
}
}
+545
View File
@@ -0,0 +1,545 @@
use std::{net::SocketAddr, sync::Arc};
use anyhow::Context;
use crossbeam::atomic::AtomicCell;
use rand::{seq::SliceRandom, Rng, SeedableRng};
use tokio::{net::UdpSocket, sync::Mutex, task::JoinSet};
use tracing::Instrument;
use crate::{
common::{
constants, error::Error, global_ctx::ArcGlobalCtx, join_joinset_background,
rkyv_util::encode_to_bytes, stun::StunInfoCollectorTrait, PeerId,
},
peers::peer_manager::PeerManager,
rpc::NatType,
tunnels::{
common::setup_sokcet2,
udp_tunnel::{UdpPacket, UdpTunnelConnector, UdpTunnelListener},
Tunnel, TunnelConnCounter, TunnelListener,
},
};
use super::direct::PeerManagerForDirectConnector;
#[tarpc::service]
// RPC spoken between the punching client and the punched server: the client
// reports its publicly mapped address; the server replies with the mapped
// address of a listener the client should dial.
pub trait UdpHolePunchService {
    async fn try_punch_hole(local_mapped_addr: SocketAddr) -> Option<SocketAddr>;
}
#[derive(Debug)]
// One UDP listener kept around for hole punching, together with its
// STUN-discovered public mapping and liveness bookkeeping.
struct UdpHolePunchListener {
    socket: Arc<UdpSocket>,
    tasks: JoinSet<()>,
    // Flipped to false when the accept loop exits.
    running: Arc<AtomicCell<bool>>,
    // Public (NAT-mapped) address of `socket`, as reported by STUN.
    mapped_addr: SocketAddr,
    conn_counter: Arc<Box<dyn TunnelConnCounter>>,
    listen_time: std::time::Instant,
    last_select_time: AtomicCell<std::time::Instant>,
    last_connected_time: Arc<AtomicCell<std::time::Instant>>,
}
impl UdpHolePunchListener {
    /// Ask the OS for a currently free UDP port by binding and immediately
    /// dropping a socket. NOTE(review): the port can be taken by another
    /// process between this probe and the real bind below — benign race,
    /// `new()` simply fails in that case.
    async fn get_avail_port() -> Result<u16, Error> {
        let socket = UdpSocket::bind("0.0.0.0:0").await?;
        Ok(socket.local_addr()?.port())
    }
    /// Bind a UDP listener on a free port, resolve its public mapping via
    /// STUN, and spawn an accept loop that hands each incoming tunnel to the
    /// peer manager as a server-side connection.
    pub async fn new(peer_mgr: Arc<PeerManager>) -> Result<Self, Error> {
        let port = Self::get_avail_port().await?;
        let listen_url = format!("udp://0.0.0.0:{}", port);
        let gctx = peer_mgr.get_global_ctx();
        let stun_info_collect = gctx.get_stun_info_collector();
        let mapped_addr = stun_info_collect.get_udp_port_mapping(port).await?;
        let mut listener = UdpTunnelListener::new(listen_url.parse().unwrap());
        {
            // Bind inside the configured network namespace.
            let _g = peer_mgr.get_global_ctx().net_ns.guard();
            listener.listen().await?;
        }
        let socket = listener.get_socket().unwrap();
        let running = Arc::new(AtomicCell::new(true));
        let running_clone = running.clone();
        let last_connected_time = Arc::new(AtomicCell::new(std::time::Instant::now()));
        let last_connected_time_clone = last_connected_time.clone();
        let conn_counter = listener.get_conn_counter();
        let mut tasks = JoinSet::new();
        tasks.spawn(async move {
            while let Ok(conn) = listener.accept().await {
                last_connected_time_clone.store(std::time::Instant::now());
                tracing::warn!(?conn, "udp hole punching listener got peer connection");
                let peer_mgr = peer_mgr.clone();
                tokio::spawn(async move {
                    if let Err(e) = peer_mgr.add_tunnel_as_server(conn).await {
                        tracing::error!(
                            ?e,
                            "failed to add tunnel as server in hole punch listener"
                        );
                    }
                });
            }
            // Accept loop ended; mark the listener dead.
            running_clone.store(false);
        });
        tracing::warn!(?mapped_addr, ?socket, "udp hole punching listener started");
        Ok(Self {
            tasks,
            socket,
            running,
            mapped_addr,
            conn_counter,
            listen_time: std::time::Instant::now(),
            last_select_time: AtomicCell::new(std::time::Instant::now()),
            last_connected_time,
        })
    }
    /// Hand out the underlying socket, recording the selection time.
    pub async fn get_socket(&self) -> Arc<UdpSocket> {
        self.last_select_time.store(std::time::Instant::now());
        self.socket.clone()
    }
}
#[derive(Debug)]
// State shared by the hole-punch client loop and the RPC server side.
struct UdpHolePunchConnectorData {
    global_ctx: ArcGlobalCtx,
    peer_mgr: Arc<PeerManager>,
    // Pool of punch listeners; pruned and refilled on demand by the server.
    listeners: Arc<Mutex<Vec<UdpHolePunchListener>>>,
}
#[derive(Clone)]
// Server side of the hole-punch RPC; owns background punch-packet tasks.
struct UdpHolePunchRpcServer {
    data: Arc<UdpHolePunchConnectorData>,
    // std Mutex is fine here: only locked briefly to spawn, never across await.
    tasks: Arc<std::sync::Mutex<JoinSet<()>>>,
}
#[tarpc::server]
impl UdpHolePunchService for UdpHolePunchRpcServer {
    /// Server half of the punch handshake: pick a listener, and — when our
    /// own NAT is (port-)restricted — proactively spray packets at the
    /// client's mapped address so our NAT opens a mapping towards it.
    /// Returns the listener's publicly mapped address for the client to dial.
    async fn try_punch_hole(
        self,
        _: tarpc::context::Context,
        local_mapped_addr: SocketAddr,
    ) -> Option<SocketAddr> {
        let (socket, mapped_addr) = self.select_listener().await?;
        tracing::warn!(?local_mapped_addr, ?mapped_addr, "start hole punching");
        let my_udp_nat_type = self
            .data
            .global_ctx
            .get_stun_info_collector()
            .get_stun_info()
            .udp_nat_type;
        // if we are restricted, we need to send hole punching resp to client
        if my_udp_nat_type == NatType::PortRestricted as i32
            || my_udp_nat_type == NatType::Restricted as i32
        {
            // send punch msg to local_mapped_addr for 3 seconds, 3.3 packet per second
            self.tasks.lock().unwrap().spawn(async move {
                for _ in 0..10 {
                    tracing::info!(?local_mapped_addr, "sending hole punching packet");
                    // generate a 128 bytes vec with random data
                    let mut rng = rand::rngs::StdRng::from_entropy();
                    let mut buf = vec![0u8; 128];
                    rng.fill(&mut buf[..]);
                    let udp_packet = UdpPacket::new_hole_punch_packet(buf);
                    let udp_packet_bytes = encode_to_bytes::<_, 256>(&udp_packet);
                    // Send errors are ignored: punching is best-effort.
                    let _ = socket
                        .send_to(udp_packet_bytes.as_ref(), local_mapped_addr)
                        .await;
                    tokio::time::sleep(std::time::Duration::from_millis(300)).await;
                }
            });
        }
        Some(mapped_addr)
    }
}
impl UdpHolePunchRpcServer {
    /// Build the RPC server and start a background reaper for its punch tasks.
    pub fn new(data: Arc<UdpHolePunchConnectorData>) -> Self {
        let tasks = Arc::new(std::sync::Mutex::new(JoinSet::new()));
        // Reap finished punch tasks so the JoinSet does not grow unboundedly.
        join_joinset_background(tasks.clone(), "UdpHolePunchRpcServer".to_owned());
        Self { data, tasks }
    }

    /// Pick (or lazily create) a listener to punch from, returning its socket
    /// and publicly mapped address.
    ///
    /// The listener pool lock is held for the whole operation: the previous
    /// version re-acquired it between the prune, the length check/push and
    /// the final pick, so two concurrent RPCs could interleave and select
    /// the wrong "last" listener. Holding a `tokio::sync::Mutex` guard across
    /// the `await` below is safe (it is an async-aware mutex).
    async fn select_listener(&self) -> Option<(Arc<UdpSocket>, SocketAddr)> {
        let mut locked = self.data.listeners.lock().await;

        // remove listener that not have connection in for 20 seconds
        locked.retain(|listener| {
            listener.last_connected_time.load().elapsed().as_secs() < 20
                && listener.conn_counter.get() > 0
        });

        let mut use_last = false;
        if locked.len() < 4 {
            tracing::warn!("creating new udp hole punching listener");
            locked.push(
                UdpHolePunchListener::new(self.data.peer_mgr.clone())
                    .await
                    .ok()?,
            );
            use_last = true;
        }

        let listener = if use_last {
            locked.last()?
        } else {
            locked.choose(&mut rand::rngs::StdRng::from_entropy())?
        };
        Some((listener.get_socket().await, listener.mapped_addr))
    }
}
// Establishes direct UDP connections through NATs via STUN-assisted hole
// punching; runs a client loop and serves the punch RPC for remote peers.
pub struct UdpHolePunchConnector {
    data: Arc<UdpHolePunchConnectorData>,
    // Holds the client main loop; aborted when the connector drops.
    tasks: JoinSet<()>,
}
// Currently support:
// Symmetric -> Full Cone
// Any Type of Full Cone -> Any Type of Full Cone
// if same level of full cone, node with smaller peer_id will be the initiator
// if different level of full cone, node with more strict level will be the initiator
impl UdpHolePunchConnector {
    /// Create a connector with an empty listener pool; nothing runs until
    /// `run`/`run_as_client`/`run_as_server` is called.
    pub fn new(global_ctx: ArcGlobalCtx, peer_mgr: Arc<PeerManager>) -> Self {
        Self {
            data: Arc::new(UdpHolePunchConnectorData {
                global_ctx,
                peer_mgr,
                listeners: Arc::new(Mutex::new(Vec::new())),
            }),
            tasks: JoinSet::new(),
        }
    }
    /// Spawn the background loop that initiates hole punching towards peers.
    pub async fn run_as_client(&mut self) -> Result<(), Error> {
        let data = self.data.clone();
        self.tasks.spawn(async move {
            Self::main_loop(data).await;
        });
        Ok(())
    }
    /// Register the punch RPC service so remote peers can punch towards us.
    pub async fn run_as_server(&mut self) -> Result<(), Error> {
        self.data.peer_mgr.get_peer_rpc_mgr().run_service(
            constants::UDP_HOLE_PUNCH_CONNECTOR_SERVICE_ID,
            UdpHolePunchRpcServer::new(self.data.clone()).serve(),
        );
        Ok(())
    }
    /// Start both roles.
    pub async fn run(&mut self) -> Result<(), Error> {
        self.run_as_client().await?;
        self.run_as_server().await?;
        Ok(())
    }
    /// Decide which peers we should initiate hole punching towards, based on
    /// our NAT type vs. each peer's advertised NAT type (see the decision
    /// rules in the comments below).
    async fn collect_peer_to_connect(data: Arc<UdpHolePunchConnectorData>) -> Vec<PeerId> {
        let mut peers_to_connect = Vec::new();
        // do not do anything if:
        // 1. our stun test has not finished
        // 2. our nat type is OpenInternet or NoPat, which means we can wait other peers to connect us
        let my_nat_type = data
            .global_ctx
            .get_stun_info_collector()
            .get_stun_info()
            .udp_nat_type;
        let my_nat_type = NatType::try_from(my_nat_type).unwrap();
        if my_nat_type == NatType::Unknown
            || my_nat_type == NatType::OpenInternet
            || my_nat_type == NatType::NoPat
        {
            return peers_to_connect;
        }
        // collect peer list from peer manager and do some filter:
        // 1. peers without direct conns;
        // 2. peers is full cone (any restricted type);
        for route in data.peer_mgr.list_routes().await.iter() {
            let Some(peer_stun_info) = route.stun_info.as_ref() else {
                continue;
            };
            let Ok(peer_nat_type) = NatType::try_from(peer_stun_info.udp_nat_type) else {
                continue;
            };
            let peer_id: PeerId = route.peer_id;
            let conns = data.peer_mgr.list_peer_conns(peer_id).await;
            if conns.is_some() && conns.unwrap().len() > 0 {
                continue;
            }
            // if peer is symmetric ignore it because we cannot connect to it
            // if peer is open internet or no pat, direct connector will connecto to it
            if peer_nat_type == NatType::Unknown
                || peer_nat_type == NatType::OpenInternet
                || peer_nat_type == NatType::NoPat
                || peer_nat_type == NatType::Symmetric
                || peer_nat_type == NatType::SymUdpFirewall
            {
                continue;
            }
            // if we are symmetric, we can only connect to full cone
            // TODO: can also connect to restricted full cone, with some extra work
            if (my_nat_type == NatType::Symmetric || my_nat_type == NatType::SymUdpFirewall)
                && peer_nat_type != NatType::FullCone
            {
                continue;
            }
            // if we have smae level of full cone, node with smaller peer_id will be the initiator
            if my_nat_type == peer_nat_type {
                if data.peer_mgr.my_peer_id() > peer_id {
                    continue;
                }
            } else {
                // if we have different level of full cone
                // we will be the initiator if we have more strict level
                if my_nat_type < peer_nat_type {
                    continue;
                }
            }
            tracing::info!(
                ?peer_id,
                ?peer_nat_type,
                ?my_nat_type,
                ?data.global_ctx.id,
                "found peer to do hole punching"
            );
            peers_to_connect.push(peer_id);
        }
        peers_to_connect
    }
    /// Client half of the punch handshake: discover our own mapped address,
    /// exchange it with the server over RPC, then connect to the server's
    /// mapped address from the very same local port.
    #[tracing::instrument]
    async fn do_hole_punching(
        data: Arc<UdpHolePunchConnectorData>,
        dst_peer_id: PeerId,
    ) -> Result<Box<dyn Tunnel>, anyhow::Error> {
        tracing::info!(?dst_peer_id, "start hole punching");
        // client: choose a local udp port, and get the pubic mapped port from stun server
        let socket = {
            let _g = data.global_ctx.net_ns.guard();
            UdpSocket::bind("0.0.0.0:0").await.with_context(|| "")?
        };
        let local_socket_addr = socket.local_addr()?;
        let local_port = socket.local_addr()?.port();
        drop(socket); // drop the socket to release the port
        let local_mapped_addr = data
            .global_ctx
            .get_stun_info_collector()
            .get_udp_port_mapping(local_port)
            .await
            .with_context(|| "failed to get udp port mapping")?;
        // client -> server: tell server the mapped port, server will return the mapped address of listening port.
        let Some(remote_mapped_addr) = data
            .peer_mgr
            .get_peer_rpc_mgr()
            .do_client_rpc_scoped(
                constants::UDP_HOLE_PUNCH_CONNECTOR_SERVICE_ID,
                dst_peer_id,
                |c| async {
                    let client =
                        UdpHolePunchServiceClient::new(tarpc::client::Config::default(), c).spawn();
                    let remote_mapped_addr = client
                        .try_punch_hole(tarpc::context::current(), local_mapped_addr)
                        .await;
                    tracing::info!(?remote_mapped_addr, ?dst_peer_id, "got remote mapped addr");
                    remote_mapped_addr
                },
            )
            .await?
        else {
            return Err(anyhow::anyhow!("failed to get remote mapped addr"));
        };
        // server: will send some punching resps, total 10 packets.
        // client: use the socket to create UdpTunnel with UdpTunnelConnector
        // NOTICE: UdpTunnelConnector will ignore the punching resp packet sent by remote.
        let connector = UdpTunnelConnector::new(
            format!(
                "udp://{}:{}",
                remote_mapped_addr.ip(),
                remote_mapped_addr.port()
            )
            .to_string()
            .parse()
            .unwrap(),
        );
        // Re-bind the same local port we measured with STUN so the NAT
        // mapping created above is reused for the actual connection.
        let _g = data.global_ctx.net_ns.guard();
        let socket2_socket = socket2::Socket::new(
            socket2::Domain::for_address(local_socket_addr),
            socket2::Type::DGRAM,
            Some(socket2::Protocol::UDP),
        )?;
        setup_sokcet2(&socket2_socket, &local_socket_addr)?;
        let socket = UdpSocket::from_std(socket2_socket.into())?;
        Ok(connector
            .try_connect_with_socket(socket)
            .await
            .with_context(|| "UdpTunnelConnector failed to connect remote")?)
    }
    /// Periodically pick punchable peers and run the punch handshake against
    /// each of them concurrently; successful tunnels are registered with the
    /// peer manager as client connections.
    async fn main_loop(data: Arc<UdpHolePunchConnectorData>) {
        loop {
            let peers_to_connect = Self::collect_peer_to_connect(data.clone()).await;
            tracing::trace!(?peers_to_connect, "peers to connect");
            if peers_to_connect.len() == 0 {
                tokio::time::sleep(std::time::Duration::from_secs(5)).await;
                continue;
            }
            let mut tasks: JoinSet<Result<(), anyhow::Error>> = JoinSet::new();
            for peer_id in peers_to_connect {
                let data = data.clone();
                tasks.spawn(
                    async move {
                        let tunnel = Self::do_hole_punching(data.clone(), peer_id)
                            .await
                            .with_context(|| "failed to do hole punching")?;
                        let _ =
                            data.peer_mgr
                                .add_client_tunnel(tunnel)
                                .await
                                .with_context(|| {
                                    "failed to add tunnel as client in hole punch connector"
                                })?;
                        Ok(())
                    }
                    .instrument(tracing::info_span!("doing hole punching client", ?peer_id)),
                );
            }
            while let Some(res) = tasks.join_next().await {
                if let Err(e) = res {
                    tracing::error!(?e, "failed to join hole punching job");
                    continue;
                }
                match res.unwrap() {
                    Err(e) => {
                        tracing::error!(?e, "failed to do hole punching job");
                    }
                    Ok(_) => {
                        tracing::info!("hole punching job succeed");
                    }
                }
            }
            tokio::time::sleep(std::time::Duration::from_secs(10)).await;
        }
    }
}
#[cfg(test)]
pub mod tests {
    use std::sync::Arc;
    use crate::rpc::{NatType, StunInfo};
    use crate::{
        common::{error::Error, stun::StunInfoCollectorTrait},
        connector::udp_hole_punch::UdpHolePunchConnector,
        peers::{
            peer_manager::PeerManager,
            tests::{
                connect_peer_manager, create_mock_peer_manager, wait_route_appear,
                wait_route_appear_with_cost,
            },
        },
    };
    // Fake STUN collector reporting a fixed NAT type and mapping every port
    // to 127.0.0.1, so punch tests run without a real STUN server.
    struct MockStunInfoCollector {
        udp_nat_type: NatType,
    }
    #[async_trait::async_trait]
    impl StunInfoCollectorTrait for MockStunInfoCollector {
        fn get_stun_info(&self) -> StunInfo {
            StunInfo {
                udp_nat_type: self.udp_nat_type as i32,
                tcp_nat_type: NatType::Unknown as i32,
                last_update_time: std::time::Instant::now().elapsed().as_secs() as i64,
            }
        }
        async fn get_udp_port_mapping(&self, port: u16) -> Result<std::net::SocketAddr, Error> {
            Ok(format!("127.0.0.1:{}", port).parse().unwrap())
        }
    }
    // Swap a peer manager's real STUN collector for the mock above.
    pub fn replace_stun_info_collector(peer_mgr: Arc<PeerManager>, udp_nat_type: NatType) {
        let collector = Box::new(MockStunInfoCollector { udp_nat_type });
        peer_mgr
            .get_global_ctx()
            .replace_stun_info_collector(collector);
    }
    pub async fn create_mock_peer_manager_with_mock_stun(
        udp_nat_type: NatType,
    ) -> Arc<PeerManager> {
        let p_a = create_mock_peer_manager().await;
        replace_stun_info_collector(p_a.clone(), udp_nat_type);
        p_a
    }
    // A (PortRestricted) -- B (Symmetric relay) -- C (PortRestricted):
    // A and C should punch a direct path (route cost 1) through B.
    #[tokio::test]
    async fn hole_punching() {
        let p_a = create_mock_peer_manager_with_mock_stun(NatType::PortRestricted).await;
        let p_b = create_mock_peer_manager_with_mock_stun(NatType::Symmetric).await;
        let p_c = create_mock_peer_manager_with_mock_stun(NatType::PortRestricted).await;
        connect_peer_manager(p_a.clone(), p_b.clone()).await;
        connect_peer_manager(p_b.clone(), p_c.clone()).await;
        wait_route_appear(p_a.clone(), p_c.clone()).await.unwrap();
        println!("{:?}", p_a.list_routes().await);
        let mut hole_punching_a = UdpHolePunchConnector::new(p_a.get_global_ctx(), p_a.clone());
        let mut hole_punching_c = UdpHolePunchConnector::new(p_c.get_global_ctx(), p_c.clone());
        hole_punching_a.run().await.unwrap();
        hole_punching_c.run().await.unwrap();
        wait_route_appear_with_cost(p_a.clone(), p_c.my_peer_id(), Some(1))
            .await
            .unwrap();
        println!("{:?}", p_a.list_routes().await);
    }
}
+490
View File
@@ -0,0 +1,490 @@
#![allow(dead_code)]
use std::{net::SocketAddr, vec};
use clap::{command, Args, Parser, Subcommand};
use rpc::vpn_portal_rpc_client::VpnPortalRpcClient;
mod arch;
mod common;
mod rpc;
mod tunnels;
use crate::{
common::stun::{StunInfoCollector, UdpNatTypeDetector},
rpc::{
connector_manage_rpc_client::ConnectorManageRpcClient,
peer_center_rpc_client::PeerCenterRpcClient, peer_manage_rpc_client::PeerManageRpcClient,
*,
},
};
use humansize::format_size;
use tabled::settings::Style;
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Cli {
    /// the rpc portal address of the easytier instance to manage
    // Fixed: the previous doc comment said "the instance name", which clap
    // would have shown as the help text for `--rpc-portal`.
    #[arg(short = 'p', long, default_value = "127.0.0.1:15888")]
    rpc_portal: SocketAddr,

    #[command(subcommand)]
    sub_command: SubCommand,
}
// Top-level CLI subcommands, one per management RPC surface of the daemon.
// Plain `//` comments are used here (not `///`) because clap derive turns doc
// comments into user-visible help text.
#[derive(Subcommand, Debug)]
enum SubCommand {
    Peer(PeerArgs),           // list/add/remove peers
    Connector(ConnectorArgs), // list/add/remove connectors
    Stun,                     // run a local UDP NAT type detection
    Route,                    // print the route table
    PeerCenter,               // print the global peer map
    VpnPortal,                // print vpn portal info (config, clients)
}
// Arguments for the `peer` subcommand. `//` comments only: clap derive turns
// doc comments into CLI help text.
#[derive(Args, Debug)]
struct PeerArgs {
    #[arg(short, long)]
    ipv4: Option<String>,
    #[arg(short, long)]
    peers: Vec<String>,
    // Defaults to listing peers when no subcommand is given (see `main`).
    #[command(subcommand)]
    sub_command: Option<PeerSubCommand>,
}
// Arguments for `peer list`.
#[derive(Args, Debug)]
struct PeerListArgs {
    // When set, dumps the raw PeerRoutePair debug output instead of a table.
    #[arg(short, long)]
    verbose: bool,
}
// Actions under the `peer` subcommand.
#[derive(Subcommand, Debug)]
enum PeerSubCommand {
    Add,    // currently a stub (prints "add peer")
    Remove, // currently a stub (prints "remove peer")
    List(PeerListArgs),
}
// Arguments for the `connector` subcommand. Mirrors `PeerArgs`.
#[derive(Args, Debug)]
struct ConnectorArgs {
    #[arg(short, long)]
    ipv4: Option<String>,
    #[arg(short, long)]
    peers: Vec<String>,
    // Defaults to listing connectors when no subcommand is given (see `main`).
    #[command(subcommand)]
    sub_command: Option<ConnectorSubCommand>,
}
// Actions under the `connector` subcommand.
#[derive(Subcommand, Debug)]
enum ConnectorSubCommand {
    Add,    // currently a stub (prints "add connector")
    Remove, // currently a stub (prints "remove connector")
    List,
}
/// CLI-level error type: everything this tool does goes through tonic, so
/// only transport (connect) and RPC (status) failures are modeled.
#[derive(thiserror::Error, Debug)]
enum Error {
    #[error("tonic transport error")]
    TonicTransportError(#[from] tonic::transport::Error),
    #[error("tonic rpc error")]
    TonicRpcError(#[from] tonic::Status),
}
/// A route entry joined with the matching peer info (matched by `peer_id`).
/// `peer` is `None` when the route points at a peer we have no direct
/// connection info for.
#[derive(Debug)]
struct PeerRoutePair {
    route: Route,
    peer: Option<PeerInfo>,
}
impl PeerRoutePair {
    /// Minimum latency across all connections to this peer, in milliseconds.
    /// Returns `None` if the peer is unknown or no connection carries stats.
    fn get_latency_ms(&self) -> Option<f64> {
        let p = self.peer.as_ref()?;
        let min_us = p
            .conns
            .iter()
            .filter_map(|conn| conn.stats.as_ref())
            .map(|stats| stats.latency_us)
            .min()?;
        // Convert us -> ms directly. The previous `ret as u32` intermediate
        // cast silently truncated latencies that exceed u32::MAX microseconds.
        Some(min_us as f64 / 1000.0)
    }

    /// Total bytes received over all connections; `None` when unknown or zero.
    fn get_rx_bytes(&self) -> Option<u64> {
        let p = self.peer.as_ref()?;
        let total: u64 = p
            .conns
            .iter()
            .filter_map(|conn| conn.stats.as_ref())
            .map(|stats| stats.rx_bytes)
            .sum();
        if total == 0 {
            None
        } else {
            Some(total)
        }
    }

    /// Total bytes sent over all connections; `None` when unknown or zero.
    fn get_tx_bytes(&self) -> Option<u64> {
        let p = self.peer.as_ref()?;
        let total: u64 = p
            .conns
            .iter()
            .filter_map(|conn| conn.stats.as_ref())
            .map(|stats| stats.tx_bytes)
            .sum();
        if total == 0 {
            None
        } else {
            Some(total)
        }
    }

    /// Sum of per-connection loss rates; `None` when unknown or zero.
    fn get_loss_rate(&self) -> Option<f64> {
        let p = self.peer.as_ref()?;
        let total: f64 = p.conns.iter().map(|conn| conn.loss_rate as f64).sum();
        if total == 0.0 {
            None
        } else {
            Some(total)
        }
    }

    /// Distinct tunnel protocol names used by this peer's connections, in
    /// first-seen order; `None` when unknown or empty.
    fn get_conn_protos(&self) -> Option<Vec<String>> {
        let p = self.peer.as_ref()?;
        let mut protos = vec![];
        for conn in p.conns.iter() {
            let Some(tunnel_info) = &conn.tunnel else {
                continue;
            };
            // insert if not exists (keep order, deduplicate)
            if !protos.contains(&tunnel_info.tunnel_type) {
                protos.push(tunnel_info.tunnel_type.clone());
            }
        }
        if protos.is_empty() {
            None
        } else {
            Some(protos)
        }
    }

    /// Human-readable UDP NAT type reported by the route's STUN info, or
    /// "Unknown" when absent. (`self: &Self` replaced with idiomatic `&self`.)
    fn get_udp_nat_type(&self) -> String {
        let nat = self
            .route
            .stun_info
            .as_ref()
            .map(|r| NatType::try_from(r.udp_nat_type).unwrap())
            .unwrap_or(NatType::Unknown);
        format!("{:?}", nat)
    }
}
/// Thin wrapper around the management RPC clients. `addr` is the
/// `http://ip:port` base URI of the daemon's rpc portal (built in `main`).
struct CommandHandler {
    addr: String,
}
impl CommandHandler {
    /// Connect a fresh peer-management RPC client to the daemon.
    async fn get_peer_manager_client(
        &self,
    ) -> Result<PeerManageRpcClient<tonic::transport::Channel>, Error> {
        Ok(PeerManageRpcClient::connect(self.addr.clone()).await?)
    }

    /// Connect a fresh connector-management RPC client to the daemon.
    async fn get_connector_manager_client(
        &self,
    ) -> Result<ConnectorManageRpcClient<tonic::transport::Channel>, Error> {
        Ok(ConnectorManageRpcClient::connect(self.addr.clone()).await?)
    }

    /// Connect a fresh peer-center RPC client to the daemon.
    async fn get_peer_center_client(
        &self,
    ) -> Result<PeerCenterRpcClient<tonic::transport::Channel>, Error> {
        Ok(PeerCenterRpcClient::connect(self.addr.clone()).await?)
    }

    /// Connect a fresh vpn-portal RPC client to the daemon.
    async fn get_vpn_portal_client(
        &self,
    ) -> Result<VpnPortalRpcClient<tonic::transport::Channel>, Error> {
        Ok(VpnPortalRpcClient::connect(self.addr.clone()).await?)
    }

    /// Fetch the daemon's current peer list.
    async fn list_peers(&self) -> Result<ListPeerResponse, Error> {
        let mut client = self.get_peer_manager_client().await?;
        let request = tonic::Request::new(ListPeerRequest::default());
        let response = client.list_peer(request).await?;
        Ok(response.into_inner())
    }

    /// Fetch the daemon's current route table.
    async fn list_routes(&self) -> Result<ListRouteResponse, Error> {
        let mut client = self.get_peer_manager_client().await?;
        let request = tonic::Request::new(ListRouteRequest::default());
        let response = client.list_route(request).await?;
        Ok(response.into_inner())
    }

    /// Join each route with its peer info (matched by `peer_id`); routes with
    /// no matching peer get `peer: None`.
    async fn list_peer_route_pair(&self) -> Result<Vec<PeerRoutePair>, Error> {
        let mut peers = self.list_peers().await?.peer_infos;
        let mut routes = self.list_routes().await?.routes;
        let mut pairs: Vec<PeerRoutePair> = vec![];
        for route in routes.iter_mut() {
            let peer = peers.iter_mut().find(|peer| peer.peer_id == route.peer_id);
            pairs.push(PeerRoutePair {
                route: route.clone(),
                peer: peer.cloned(),
            });
        }
        Ok(pairs)
    }

    // Stub: peer add is not implemented yet.
    #[allow(dead_code)]
    fn handle_peer_add(&self, _args: PeerArgs) {
        println!("add peer");
    }

    // Stub: peer remove is not implemented yet.
    #[allow(dead_code)]
    fn handle_peer_remove(&self, _args: PeerArgs) {
        println!("remove peer");
    }

    /// Print the peer list as a table (latency, traffic, protocol, NAT type).
    async fn handle_peer_list(&self, _args: &PeerArgs) -> Result<(), Error> {
        // One table row per peer; all columns pre-formatted as strings.
        #[derive(tabled::Tabled)]
        struct PeerTableItem {
            ipv4: String,
            hostname: String,
            cost: String,
            lat_ms: String,
            loss_rate: String,
            rx_bytes: String,
            tx_bytes: String,
            tunnel_proto: String,
            nat_type: String,
            id: String,
        }

        // cost == 1 means a direct connection; anything else is relayed.
        fn cost_to_str(cost: i32) -> String {
            if cost == 1 {
                "p2p".to_string()
            } else {
                format!("relay({})", cost)
            }
        }

        fn float_to_str(f: f64, precision: usize) -> String {
            format!("{:.1$}", f, precision)
        }

        impl From<PeerRoutePair> for PeerTableItem {
            fn from(p: PeerRoutePair) -> Self {
                PeerTableItem {
                    ipv4: p.route.ipv4_addr.clone(),
                    hostname: p.route.hostname.clone(),
                    cost: cost_to_str(p.route.cost),
                    lat_ms: float_to_str(p.get_latency_ms().unwrap_or(0.0), 3),
                    loss_rate: float_to_str(p.get_loss_rate().unwrap_or(0.0), 3),
                    rx_bytes: format_size(p.get_rx_bytes().unwrap_or(0), humansize::DECIMAL),
                    tx_bytes: format_size(p.get_tx_bytes().unwrap_or(0), humansize::DECIMAL),
                    tunnel_proto: p.get_conn_protos().unwrap_or(vec![]).join(",").to_string(),
                    nat_type: p.get_udp_nat_type(),
                    id: p.route.peer_id.to_string(),
                }
            }
        }

        let mut items: Vec<PeerTableItem> = vec![];
        let peer_routes = self.list_peer_route_pair().await?;
        for p in peer_routes {
            items.push(p.into());
        }

        println!(
            "{}",
            tabled::Table::new(items).with(Style::modern()).to_string()
        );

        Ok(())
    }

    /// Print the route table; direct routes show "DIRECT" instead of a next hop.
    async fn handle_route_list(&self) -> Result<(), Error> {
        #[derive(tabled::Tabled)]
        struct RouteTableItem {
            ipv4: String,
            hostname: String,
            proxy_cidrs: String,
            next_hop_ipv4: String,
            next_hop_hostname: String,
            next_hop_lat: f64,
            cost: i32,
        }

        let mut items: Vec<RouteTableItem> = vec![];
        let peer_routes = self.list_peer_route_pair().await?;
        for p in peer_routes.iter() {
            // Routes whose next hop is not itself in the route table are
            // skipped silently.
            let Some(next_hop_pair) = peer_routes
                .iter()
                .find(|pair| pair.route.peer_id == p.route.next_hop_peer_id)
            else {
                continue;
            };

            if p.route.cost == 1 {
                items.push(RouteTableItem {
                    ipv4: p.route.ipv4_addr.clone(),
                    hostname: p.route.hostname.clone(),
                    proxy_cidrs: p.route.proxy_cidrs.clone().join(",").to_string(),
                    next_hop_ipv4: "DIRECT".to_string(),
                    next_hop_hostname: "".to_string(),
                    next_hop_lat: next_hop_pair.get_latency_ms().unwrap_or(0.0),
                    cost: p.route.cost,
                });
            } else {
                items.push(RouteTableItem {
                    ipv4: p.route.ipv4_addr.clone(),
                    hostname: p.route.hostname.clone(),
                    proxy_cidrs: p.route.proxy_cidrs.clone().join(",").to_string(),
                    next_hop_ipv4: next_hop_pair.route.ipv4_addr.clone(),
                    next_hop_hostname: next_hop_pair.route.hostname.clone(),
                    next_hop_lat: next_hop_pair.get_latency_ms().unwrap_or(0.0),
                    cost: p.route.cost,
                });
            }
        }

        println!(
            "{}",
            tabled::Table::new(items).with(Style::modern()).to_string()
        );

        Ok(())
    }

    /// Dump the connector list as debug output (no table yet).
    async fn handle_connector_list(&self) -> Result<(), Error> {
        let mut client = self.get_connector_manager_client().await?;
        let request = tonic::Request::new(ListConnectorRequest::default());
        let response = client.list_connector(request).await?;
        println!("response: {:#?}", response.into_inner());
        Ok(())
    }
}
#[tokio::main]
#[tracing::instrument]
async fn main() -> Result<(), Error> {
    let cli = Cli::parse();
    // All RPC clients talk to the daemon's rpc portal over plain http.
    let handler = CommandHandler {
        addr: format!("http://{}:{}", cli.rpc_portal.ip(), cli.rpc_portal.port()),
    };

    match cli.sub_command {
        SubCommand::Peer(peer_args) => match &peer_args.sub_command {
            Some(PeerSubCommand::Add) => {
                // stub — not implemented yet
                println!("add peer");
            }
            Some(PeerSubCommand::Remove) => {
                // stub — not implemented yet
                println!("remove peer");
            }
            Some(PeerSubCommand::List(arg)) => {
                if arg.verbose {
                    // verbose: raw debug dump instead of the table view
                    println!("{:#?}", handler.list_peer_route_pair().await?);
                } else {
                    handler.handle_peer_list(&peer_args).await?;
                }
            }
            None => {
                // `peer` without a subcommand defaults to listing
                handler.handle_peer_list(&peer_args).await?;
            }
        },
        SubCommand::Connector(conn_args) => match conn_args.sub_command {
            Some(ConnectorSubCommand::Add) => {
                // stub — not implemented yet
                println!("add connector");
            }
            Some(ConnectorSubCommand::Remove) => {
                // stub — not implemented yet
                println!("remove connector");
            }
            Some(ConnectorSubCommand::List) => {
                handler.handle_connector_list().await?;
            }
            None => {
                // `connector` without a subcommand defaults to listing
                handler.handle_connector_list().await?;
            }
        },
        SubCommand::Route => {
            handler.handle_route_list().await?;
        }
        SubCommand::Stun => {
            // Runs the NAT detection locally in this process, not via RPC.
            let stun = UdpNatTypeDetector::new(StunInfoCollector::get_default_servers());
            println!("udp type: {:?}", stun.get_udp_nat_type(0).await);
        }
        SubCommand::PeerCenter => {
            let mut peer_center_client = handler.get_peer_center_client().await?;
            let resp = peer_center_client
                .get_global_peer_map(GetGlobalPeerMapRequest::default())
                .await?
                .into_inner();

            #[derive(tabled::Tabled)]
            struct PeerCenterTableItem {
                node_id: String,
                direct_peers: String,
            }

            let mut table_rows = vec![];
            for (k, v) in resp.global_peer_map.iter() {
                let node_id = k;
                // One line per direct peer: "<peer_id>:<latency level>"
                let direct_peers = v
                    .direct_peers
                    .iter()
                    .map(|(k, v)| {
                        format!(
                            "{}:{:?}",
                            k,
                            LatencyLevel::try_from(v.latency_level).unwrap()
                        )
                    })
                    .collect::<Vec<_>>();
                table_rows.push(PeerCenterTableItem {
                    node_id: node_id.to_string(),
                    direct_peers: direct_peers.join("\n"),
                });
            }

            println!(
                "{}",
                tabled::Table::new(table_rows)
                    .with(Style::modern())
                    .to_string()
            );
        }
        SubCommand::VpnPortal => {
            let mut vpn_portal_client = handler.get_vpn_portal_client().await?;
            let resp = vpn_portal_client
                .get_vpn_portal_info(GetVpnPortalInfoRequest::default())
                .await?
                .into_inner()
                .vpn_portal_info
                .unwrap_or_default();
            println!("portal_name: {}\n", resp.vpn_type);
            println!("client_config:{}", resp.client_config);
            println!("connected_clients:\n{:#?}", resp.connected_clients);
        }
    }

    Ok(())
}
+428
View File
@@ -0,0 +1,428 @@
#![allow(dead_code)]
#[cfg(test)]
mod tests;
use std::{backtrace, io::Write as _, net::SocketAddr};
use anyhow::Context;
use clap::Parser;
mod arch;
mod common;
mod connector;
mod gateway;
mod instance;
mod peer_center;
mod peers;
mod rpc;
mod tunnels;
mod vpn_portal;
use common::{
config::{ConsoleLoggerConfig, FileLoggerConfig, NetworkIdentity, PeerConfig, VpnPortalConfig},
get_logger_timer_rfc3339,
};
use instance::instance::Instance;
use tracing::level_filters::LevelFilter;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilter, Layer};
use crate::common::{
config::{ConfigLoader, TomlConfigLoader},
global_ctx::GlobalCtxEvent,
};
// Command-line options for the easytier daemon. Plain `//` comments are used
// here because clap derive turns `///` doc comments into user-visible help
// text; the `help = ...` attributes already carry that.
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Cli {
    #[arg(
        long,
        help = "network name to identify this vpn network",
        default_value = "default"
    )]
    network_name: String,
    #[arg(
        long,
        help = "network secret to verify this node belongs to the vpn network",
        default_value = ""
    )]
    network_secret: String,
    #[arg(short, long, help = "ipv4 address of this vpn node")]
    ipv4: Option<String>,
    #[arg(short, long, help = "peers to connect initially")]
    peers: Vec<String>,
    // NOTE(review): the external node is appended to `peers` during config
    // conversion (see `From<Cli> for TomlConfigLoader`).
    #[arg(short, long, help = "use a public shared node to discover peers")]
    external_node: Option<String>,
    #[arg(
        short = 'n',
        long,
        help = "export local networks to other peers in the vpn"
    )]
    proxy_networks: Vec<String>,
    #[arg(
        short,
        long,
        default_value = "127.0.0.1:15888",
        help = "rpc portal address to listen for management"
    )]
    rpc_portal: SocketAddr,
    #[arg(short, long, help = "listeners to accept connections, pass '' to avoid listening.",
        default_values_t = ["tcp://0.0.0.0:11010".to_string(),
        "udp://0.0.0.0:11010".to_string(),
        "wg://0.0.0.0:11011".to_string()])]
    listeners: Vec<String>,
    /// specify the linux network namespace, default is the root namespace
    #[arg(long)]
    net_ns: Option<String>,
    #[arg(long, help = "console log level",
        value_parser = clap::builder::PossibleValuesParser::new(["trace", "debug", "info", "warn", "error", "off"]))]
    console_log_level: Option<String>,
    #[arg(long, help = "file log level",
        value_parser = clap::builder::PossibleValuesParser::new(["trace", "debug", "info", "warn", "error", "off"]))]
    file_log_level: Option<String>,
    #[arg(long, help = "directory to store log files")]
    file_log_dir: Option<String>,
    #[arg(
        short = 'm',
        long,
        default_value = "default",
        help = "instance name to identify this vpn node in same machine"
    )]
    instance_name: String,
    // NOTE(review): not consumed by `From<Cli> for TomlConfigLoader` in this
    // file — verify where instance_id is read.
    #[arg(
        short = 'd',
        long,
        help = "instance uuid to identify this vpn node in whole vpn network example: 123e4567-e89b-12d3-a456-426614174000"
    )]
    instance_id: Option<String>,
    #[arg(
        long,
        help = "url that defines the vpn portal, allow other vpn clients to connect.
example: wg://0.0.0.0:11010/10.14.14.0/24, means the vpn portal is a wireguard server listening on vpn.example.com:11010,
and the vpn client is in network of 10.14.14.0/24"
    )]
    vpn_portal: Option<String>,
    #[arg(long, help = "default protocol to use when connecting to peers")]
    default_protocol: Option<String>,
}
impl From<Cli> for TomlConfigLoader {
    /// Convert parsed command-line arguments into a TOML-backed config loader.
    ///
    /// Invalid user input (addresses, URIs, CIDRs) aborts the process with a
    /// contextualized panic, which is acceptable during startup.
    fn from(cli: Cli) -> Self {
        let cfg = TomlConfigLoader::default();
        cfg.set_inst_name(cli.instance_name.clone());
        cfg.set_network_identity(NetworkIdentity {
            network_name: cli.network_name.clone(),
            network_secret: cli.network_secret.clone(),
        });
        cfg.set_netns(cli.net_ns.clone());

        // Idiom fix: `if let Some(...)` replaces the previous
        // `is_some()` + `unwrap()` pattern throughout this conversion.
        if let Some(ipv4) = &cli.ipv4 {
            cfg.set_ipv4(
                ipv4.parse()
                    .with_context(|| format!("failed to parse ipv4 address: {}", ipv4))
                    .unwrap(),
            )
        }

        cfg.set_peers(
            cli.peers
                .iter()
                .map(|s| PeerConfig {
                    uri: s
                        .parse()
                        .with_context(|| format!("failed to parse peer uri: {}", s))
                        .unwrap(),
                })
                .collect(),
        );

        cfg.set_listeners(
            cli.listeners
                .iter()
                .filter_map(|s| {
                    // An empty string disables the corresponding listener.
                    if s.is_empty() {
                        return None;
                    }
                    Some(
                        s.parse()
                            .with_context(|| format!("failed to parse listener uri: {}", s))
                            .unwrap(),
                    )
                })
                .collect(),
        );

        for n in cli.proxy_networks.iter() {
            cfg.add_proxy_cidr(
                n.parse()
                    .with_context(|| format!("failed to parse proxy network: {}", n))
                    .unwrap(),
            );
        }

        cfg.set_rpc_portal(cli.rpc_portal);

        // The external node is simply an additional bootstrap peer.
        if let Some(external_node) = &cli.external_node {
            let mut old_peers = cfg.get_peers();
            old_peers.push(PeerConfig {
                uri: external_node
                    .parse()
                    .with_context(|| {
                        format!("failed to parse external node uri: {}", external_node)
                    })
                    .unwrap(),
            });
            cfg.set_peers(old_peers);
        }

        if let Some(console_log_level) = &cli.console_log_level {
            cfg.set_console_logger_config(ConsoleLoggerConfig {
                level: Some(console_log_level.clone()),
            });
        }

        if cli.file_log_dir.is_some() || cli.file_log_level.is_some() {
            cfg.set_file_logger_config(FileLoggerConfig {
                level: cli.file_log_level.clone(),
                dir: cli.file_log_dir.clone(),
                file: Some(format!("easytier-{}", cli.instance_name)),
            });
        }

        if let Some(vpn_portal) = &cli.vpn_portal {
            let url: url::Url = vpn_portal
                .parse()
                .with_context(|| format!("failed to parse vpn portal url: {}", vpn_portal))
                .unwrap();
            cfg.set_vpn_portal_config(VpnPortalConfig {
                // The URL path is "/<cidr>": strip the leading slash.
                client_cidr: url.path()[1..]
                    .parse()
                    .with_context(|| {
                        format!("failed to parse vpn portal client cidr: {}", url.path())
                    })
                    .unwrap(),
                wireguard_listen: format!("{}:{}", url.host_str().unwrap(), url.port().unwrap())
                    .parse()
                    .with_context(|| {
                        format!(
                            "failed to parse vpn portal wireguard listen address: {}",
                            url.host_str().unwrap()
                        )
                    })
                    .unwrap(),
            });
        }

        if let Some(default_protocol) = &cli.default_protocol {
            let mut f = cfg.get_flags();
            f.default_protocol = default_protocol.clone();
            cfg.set_flags(f);
        }

        cfg
    }
}
/// Install the global tracing subscriber: an optional daily-rotating file
/// layer plus a pretty console layer on stderr, each with its own level
/// filter taken from the loaded config (defaulting to OFF when unset).
fn init_logger(config: impl ConfigLoader) {
    // --- rolling file layer (only built when a file log level is set) ---
    let file_config = config.get_file_logger_config();
    let file_level = match file_config.level {
        Some(lvl) => lvl.parse().unwrap(),
        None => LevelFilter::OFF,
    };

    let file_layer = if file_level == LevelFilter::OFF {
        None
    } else {
        let file_filter = EnvFilter::builder()
            .with_default_directive(file_level.into())
            .from_env()
            .unwrap();
        // Rotate daily, keep at most 5 files; file name and directory come
        // from the config with sensible fallbacks.
        let file_appender = tracing_appender::rolling::Builder::new()
            .rotation(tracing_appender::rolling::Rotation::DAILY)
            .max_log_files(5)
            .filename_prefix(file_config.file.unwrap_or_else(|| "easytier".to_string()))
            .build(file_config.dir.unwrap_or_else(|| "./".to_string()))
            .expect("failed to initialize rolling file appender");
        Some(
            tracing_subscriber::fmt::layer()
                // no ANSI color codes in log files
                .with_ansi(false)
                .with_writer(file_appender)
                .with_timer(get_logger_timer_rfc3339())
                .with_filter(file_filter),
        )
    };

    // --- console layer (stderr) ---
    let console_config = config.get_console_logger_config();
    let console_level = match console_config.level {
        Some(lvl) => lvl.parse().unwrap(),
        None => LevelFilter::OFF,
    };
    let console_filter = EnvFilter::builder()
        .with_default_directive(console_level.into())
        .from_env()
        .unwrap();
    let console_layer = tracing_subscriber::fmt::layer()
        .pretty()
        .with_timer(get_logger_timer_rfc3339())
        .with_writer(std::io::stderr)
        .with_filter(console_filter);

    tracing_subscriber::Registry::default()
        .with(console_layer)
        .with(file_layer)
        .init();
}
/// Print a user-visible instance event prefixed with a local timestamp.
fn print_event(msg: String) {
    let now = chrono::Local::now();
    let stamp = now.format("%Y-%m-%d %H:%M:%S");
    println!("{}: {}", stamp, msg);
}
/// One-line human-readable summary of a peer connection for event logs.
fn peer_conn_info_to_string(p: crate::rpc::PeerConnInfo) -> String {
    use std::fmt::Write as _;
    let mut out = String::new();
    // Writing into a String cannot fail; the result is safely ignored.
    let _ = write!(
        out,
        "my_peer_id: {}, dst_peer_id: {}, tunnel_info: {:?}",
        p.my_peer_id, p.peer_id, p.tunnel
    );
    out
}
/// Install a panic hook that captures a backtrace, dumps it to
/// `easytier-panic.log` (best effort) and terminates the process.
fn setup_panic_handler() {
    std::panic::set_hook(Box::new(|info| {
        let backtrace = backtrace::Backtrace::force_capture();
        println!("panic occurred: {:?}", info);
        // Best effort: I/O failures while writing the crash log are ignored.
        if let Ok(mut f) = std::fs::File::create("easytier-panic.log") {
            let _ = f.write_all(format!("{:?}\n{:#?}", info, backtrace).as_bytes());
        }
        std::process::exit(1);
    }));
}
#[tokio::main(flavor = "current_thread")]
#[tracing::instrument]
pub async fn main() {
    setup_panic_handler();

    let cli = Cli::parse();
    tracing::info!(cli = ?cli, "cli args parsed");

    let cfg: TomlConfigLoader = cli.into();
    init_logger(&cfg);

    let mut inst = Instance::new(cfg.clone());

    // Subscribe to global events before `run()` starts the instance, then
    // echo each event to stdout from a background task.
    let mut events = inst.get_global_ctx().subscribe();
    tokio::spawn(async move {
        while let Ok(e) = events.recv().await {
            match e {
                GlobalCtxEvent::PeerAdded(p) => {
                    print_event(format!("new peer added. peer_id: {}", p));
                }

                GlobalCtxEvent::PeerRemoved(p) => {
                    print_event(format!("peer removed. peer_id: {}", p));
                }

                GlobalCtxEvent::PeerConnAdded(p) => {
                    print_event(format!(
                        "new peer connection added. conn_info: {}",
                        peer_conn_info_to_string(p)
                    ));
                }

                GlobalCtxEvent::PeerConnRemoved(p) => {
                    print_event(format!(
                        "peer connection removed. conn_info: {}",
                        peer_conn_info_to_string(p)
                    ));
                }

                GlobalCtxEvent::ListenerAdded(p) => {
                    // "ring" listeners are skipped — presumably in-process
                    // only and not useful to show; confirm ring tunnel docs.
                    if p.scheme() == "ring" {
                        continue;
                    }
                    print_event(format!("new listener added. listener: {}", p));
                }

                GlobalCtxEvent::ConnectionAccepted(local, remote) => {
                    print_event(format!(
                        "new connection accepted. local: {}, remote: {}",
                        local, remote
                    ));
                }

                GlobalCtxEvent::ConnectionError(local, remote, err) => {
                    print_event(format!(
                        "connection error. local: {}, remote: {}, err: {}",
                        local, remote, err
                    ));
                }

                GlobalCtxEvent::TunDeviceReady(dev) => {
                    print_event(format!("tun device ready. dev: {}", dev));
                }

                GlobalCtxEvent::Connecting(dst) => {
                    print_event(format!("connecting to peer. dst: {}", dst));
                }

                GlobalCtxEvent::ConnectError(dst, err) => {
                    print_event(format!("connect to peer error. dst: {}, err: {}", dst, err));
                }

                GlobalCtxEvent::VpnPortalClientConnected(portal, client_addr) => {
                    print_event(format!(
                        "vpn portal client connected. portal: {}, client_addr: {}",
                        portal, client_addr
                    ));
                }

                GlobalCtxEvent::VpnPortalClientDisconnected(portal, client_addr) => {
                    print_event(format!(
                        "vpn portal client disconnected. portal: {}, client_addr: {}",
                        portal, client_addr
                    ));
                }
            }
        }
    });

    println!("Starting easytier with config:");
    println!("############### TOML ##############\n");
    println!("{}", cfg.dump());
    println!("-----------------------------------");

    inst.run().await.unwrap();
    inst.wait().await;
}
+293
View File
@@ -0,0 +1,293 @@
use std::{
mem::MaybeUninit,
net::{IpAddr, Ipv4Addr, SocketAddrV4},
sync::Arc,
thread,
};
use pnet::packet::{
icmp::{self, IcmpTypes},
ip::IpNextHeaderProtocols,
ipv4::{self, Ipv4Packet, MutableIpv4Packet},
Packet,
};
use socket2::Socket;
use tokio::{
sync::{mpsc::UnboundedSender, Mutex},
task::JoinSet,
};
use tokio_util::bytes::Bytes;
use tracing::Instrument;
use crate::{
common::{error::Error, global_ctx::ArcGlobalCtx, PeerId},
peers::{packet, peer_manager::PeerManager, PeerPacketFilter},
};
use super::CidrSet;
/// NAT-table key identifying one in-flight ICMP echo: the (destination, id,
/// sequence) triple of the request is matched against incoming replies.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
struct IcmpNatKey {
    dst_ip: std::net::IpAddr,
    icmp_id: u16,
    icmp_seq: u16,
}
/// NAT-table value: where the echo request came from, so the reply can be
/// routed back. `start_time` drives expiry in the table cleaner.
#[derive(Debug)]
struct IcmpNatEntry {
    src_peer_id: PeerId,
    my_peer_id: PeerId,
    src_ip: IpAddr,
    start_time: std::time::Instant,
}
impl IcmpNatEntry {
    /// Record the origin of an ICMP echo request, timestamped with "now".
    /// Infallible today; the `Result` return is kept for the callers' sake.
    fn new(src_peer_id: PeerId, my_peer_id: PeerId, src_ip: IpAddr) -> Result<Self, Error> {
        let entry = Self {
            src_peer_id,
            my_peer_id,
            src_ip,
            start_time: std::time::Instant::now(),
        };
        Ok(entry)
    }
}
// Shared, concurrent map from echo identity to its origin entry.
type IcmpNatTable = Arc<dashmap::DashMap<IcmpNatKey, IcmpNatEntry>>;
// Channel endpoints for announcing newly created NAT keys.
type NewPacketSender = tokio::sync::mpsc::UnboundedSender<IcmpNatKey>;
type NewPacketReceiver = tokio::sync::mpsc::UnboundedReceiver<IcmpNatKey>;
/// Proxies ICMP echo requests from peers to hosts in the proxied CIDRs via a
/// local raw ICMPv4 socket, and routes the replies back to the origin peer.
#[derive(Debug)]
pub struct IcmpProxy {
    global_ctx: ArcGlobalCtx,
    peer_manager: Arc<PeerManager>,
    // CIDRs whose destinations this proxy is responsible for.
    cidr_set: CidrSet,
    // Raw ICMPv4 socket used both to send requests and receive replies.
    socket: socket2::Socket,
    nat_table: IcmpNatTable,
    tasks: Mutex<JoinSet<()>>,
}
/// Receive one raw packet, returning its length and source IP. Falls back to
/// 0.0.0.0 when the source cannot be expressed as a socket address.
fn socket_recv(socket: &Socket, buf: &mut [MaybeUninit<u8>]) -> Result<(usize, IpAddr), Error> {
    let (size, addr) = socket.recv_from(buf)?;
    let ip = addr
        .as_socket()
        .map(|sa| sa.ip())
        .unwrap_or(IpAddr::V4(Ipv4Addr::UNSPECIFIED));
    Ok((size, ip))
}
/// Blocking receive loop run on a dedicated OS thread: reads raw ICMP echo
/// replies, matches them against the NAT table and forwards each matched
/// reply back to the peer that originated the request. Exits when the peer
/// channel is closed.
fn socket_recv_loop(
    socket: Socket,
    nat_table: IcmpNatTable,
    sender: UnboundedSender<packet::Packet>,
) {
    let mut buf = [0u8; 4096];
    // SAFETY: reinterpreting an initialized `&mut [u8]` as
    // `&mut [MaybeUninit<u8>]` is layout-compatible; `recv_from` only writes
    // into it. NOTE(review): the first 12 bytes of `buf` are skipped —
    // presumably reserved headroom; confirm why the offset is 12.
    let data: &mut [MaybeUninit<u8>] = unsafe { std::mem::transmute(&mut buf[12..]) };
    loop {
        let Ok((len, peer_ip)) = socket_recv(&socket, data) else {
            continue;
        };

        if !peer_ip.is_ipv4() {
            continue;
        }

        let Some(mut ipv4_packet) = MutableIpv4Packet::new(&mut buf[12..12 + len]) else {
            continue;
        };
        // Only echo replies are of interest; everything else is dropped.
        let Some(icmp_packet) = icmp::echo_reply::EchoReplyPacket::new(ipv4_packet.payload())
        else {
            continue;
        };
        if icmp_packet.get_icmp_type() != IcmpTypes::EchoReply {
            continue;
        }

        // Match the reply against a pending request; unmatched replies are
        // not ours and get dropped. The entry is consumed (removed) here.
        let key = IcmpNatKey {
            dst_ip: peer_ip,
            icmp_id: icmp_packet.get_identifier(),
            icmp_seq: icmp_packet.get_sequence_number(),
        };

        let Some((_, v)) = nat_table.remove(&key) else {
            continue;
        };

        // send packet back to the peer where this request origin.
        let IpAddr::V4(dest_ip) = v.src_ip else {
            continue;
        };

        // Rewrite the destination to the original requester and fix the
        // header checksum before re-encapsulating.
        ipv4_packet.set_destination(dest_ip);
        ipv4_packet.set_checksum(ipv4::checksum(&ipv4_packet.to_immutable()));

        let peer_packet = packet::Packet::new_data_packet(
            v.my_peer_id,
            v.src_peer_id,
            &ipv4_packet.to_immutable().packet(),
        );

        if let Err(e) = sender.send(peer_packet) {
            tracing::error!("send icmp packet to peer failed: {:?}, may exiting..", e);
            break;
        }
    }
}
#[async_trait::async_trait]
impl PeerPacketFilter for IcmpProxy {
    /// Intercept ICMP echo requests from peers whose destination lies in a
    /// proxied CIDR. Returns `Some(())` when the packet was consumed by this
    /// proxy and `None` to pass it on down the pipeline.
    async fn try_process_packet_from_peer(
        &self,
        packet: &packet::ArchivedPacket,
        _: &Bytes,
    ) -> Option<()> {
        // Inactive until this node has an ipv4 address assigned.
        let _ = self.global_ctx.get_ipv4()?;

        if packet.packet_type != packet::PacketType::Data {
            return None;
        };

        let ipv4 = Ipv4Packet::new(&packet.payload.as_bytes())?;
        if ipv4.get_version() != 4 || ipv4.get_next_level_protocol() != IpNextHeaderProtocols::Icmp
        {
            return None;
        }

        if !self.cidr_set.contains_v4(ipv4.get_destination()) {
            return None;
        }

        let icmp_packet = icmp::echo_request::EchoRequestPacket::new(&ipv4.payload())?;
        if icmp_packet.get_icmp_type() != IcmpTypes::EchoRequest {
            // drop it because we do not support other icmp types
            tracing::trace!("unsupported icmp type: {:?}", icmp_packet.get_icmp_type());
            return Some(());
        }

        let icmp_id = icmp_packet.get_identifier();
        let icmp_seq = icmp_packet.get_sequence_number();
        // (dst, id, seq) identifies this echo so its reply can be NAT-ed back.
        let key = IcmpNatKey {
            dst_ip: ipv4.get_destination().into(),
            icmp_id,
            icmp_seq,
        };

        let value = IcmpNatEntry::new(
            packet.from_peer.into(),
            packet.to_peer.into(),
            ipv4.get_source().into(),
        )
        .ok()?;

        // A colliding (dst, id, seq) replaces the older pending entry.
        if let Some(old) = self.nat_table.insert(key, value) {
            tracing::info!("icmp nat table entry replaced: {:?}", old);
        }

        // Re-emit the echo request from this node via the raw socket; the
        // reply is handled by `socket_recv_loop`.
        if let Err(e) = self.send_icmp_packet(ipv4.get_destination(), &icmp_packet) {
            tracing::error!("send icmp packet failed: {:?}", e);
        }

        Some(())
    }
}
impl IcmpProxy {
    /// Build the proxy: creates a raw ICMPv4 socket bound to all local
    /// addresses. The netns guard is taken first so the socket is created
    /// inside the configured network namespace — presumably; confirm `NetNS`
    /// guard semantics.
    ///
    /// # Errors
    /// Fails if the raw socket cannot be created or bound (raw sockets
    /// typically require elevated privileges).
    pub fn new(
        global_ctx: ArcGlobalCtx,
        peer_manager: Arc<PeerManager>,
    ) -> Result<Arc<Self>, Error> {
        let cidr_set = CidrSet::new(global_ctx.clone());

        let _g = global_ctx.net_ns.guard();
        let socket = socket2::Socket::new(
            socket2::Domain::IPV4,
            socket2::Type::RAW,
            Some(socket2::Protocol::ICMPV4),
        )?;

        socket.bind(&socket2::SockAddr::from(SocketAddrV4::new(
            std::net::Ipv4Addr::UNSPECIFIED,
            0,
        )))?;

        let ret = Self {
            global_ctx,
            peer_manager,
            cidr_set,
            socket,
            nat_table: Arc::new(dashmap::DashMap::new()),
            tasks: Mutex::new(JoinSet::new()),
        };
        Ok(Arc::new(ret))
    }

    /// Start the proxy: the receive/forward machinery plus the NAT cleaner.
    pub async fn start(self: &Arc<Self>) -> Result<(), Error> {
        self.start_icmp_proxy().await?;
        self.start_nat_table_cleaner().await?;
        Ok(())
    }

    /// Background task that drops NAT entries older than 20 seconds,
    /// checking once per second.
    async fn start_nat_table_cleaner(self: &Arc<Self>) -> Result<(), Error> {
        let nat_table = self.nat_table.clone();
        self.tasks.lock().await.spawn(
            async move {
                loop {
                    tokio::time::sleep(std::time::Duration::from_secs(1)).await;
                    nat_table.retain(|_, v| v.start_time.elapsed().as_secs() < 20);
                }
            }
            .instrument(tracing::info_span!("icmp proxy nat table cleaner")),
        );
        Ok(())
    }

    /// Wire everything together: a dedicated OS thread blocks on the raw
    /// socket (see `socket_recv_loop`), a tokio task forwards matched replies
    /// to peers, and this proxy registers itself in the peer packet pipeline.
    async fn start_icmp_proxy(self: &Arc<Self>) -> Result<(), Error> {
        let socket = self.socket.try_clone()?;
        let (sender, mut receiver) = tokio::sync::mpsc::unbounded_channel();
        let nat_table = self.nat_table.clone();
        // Blocking raw-socket reads must not run on the async runtime.
        thread::spawn(|| {
            socket_recv_loop(socket, nat_table, sender);
        });

        let peer_manager = self.peer_manager.clone();
        self.tasks.lock().await.spawn(
            async move {
                while let Some(msg) = receiver.recv().await {
                    let to_peer_id = msg.to_peer.into();
                    let ret = peer_manager.send_msg(msg.into(), to_peer_id).await;
                    if ret.is_err() {
                        tracing::error!("send icmp packet to peer failed: {:?}", ret);
                    }
                }
            }
            .instrument(tracing::info_span!("icmp proxy send loop")),
        );

        self.peer_manager
            .add_packet_process_pipeline(Box::new(self.clone()))
            .await;
        Ok(())
    }

    /// Send an echo request toward `dst_ip` via the raw socket (port is
    /// meaningless for ICMP, hence 0).
    fn send_icmp_packet(
        &self,
        dst_ip: Ipv4Addr,
        icmp_packet: &icmp::echo_request::EchoRequestPacket,
    ) -> Result<(), Error> {
        self.socket.send_to(
            icmp_packet.packet(),
            &SocketAddrV4::new(dst_ip.into(), 0).into(),
        )?;

        Ok(())
    }
}
+56
View File
@@ -0,0 +1,56 @@
use dashmap::DashSet;
use std::sync::Arc;
use tokio::task::JoinSet;
use crate::common::global_ctx::ArcGlobalCtx;
pub mod icmp_proxy;
pub mod tcp_proxy;
pub mod udp_proxy;
/// Concurrently readable set of proxy CIDRs that is kept in sync with the
/// global context by a background polling task (held in `tasks`).
#[derive(Debug)]
struct CidrSet {
    global_ctx: ArcGlobalCtx,
    cidr_set: Arc<DashSet<cidr::IpCidr>>,
    tasks: JoinSet<()>,
}
impl CidrSet {
    /// Create a set that mirrors the global context's proxy CIDR list and
    /// immediately start its background updater.
    pub fn new(global_ctx: ArcGlobalCtx) -> Self {
        let mut ret = Self {
            global_ctx,
            cidr_set: Arc::new(DashSet::new()),
            tasks: JoinSet::new(),
        };
        ret.run_cidr_updater();
        ret
    }

    /// Background task: poll the configured proxy CIDRs once a second and
    /// mirror any change into `cidr_set`.
    fn run_cidr_updater(&mut self) {
        let global_ctx = self.global_ctx.clone();
        let cidr_set = self.cidr_set.clone();
        self.tasks.spawn(async move {
            let mut last_cidrs = vec![];
            loop {
                let cidrs = global_ctx.get_proxy_cidrs();
                if cidrs != last_cidrs {
                    last_cidrs = cidrs.clone();
                    // NOTE: clear-then-insert leaves a brief window where the
                    // set is empty for concurrent readers; acceptable for a
                    // once-per-second config mirror.
                    cidr_set.clear();
                    for cidr in cidrs.iter() {
                        cidr_set.insert(cidr.clone());
                    }
                }
                tokio::time::sleep(std::time::Duration::from_secs(1)).await;
            }
        });
    }

    /// Whether `ip` falls inside any tracked proxy CIDR.
    /// (Idiom fix: trailing `return` replaced with tail expressions.)
    pub fn contains_v4(&self, ip: std::net::Ipv4Addr) -> bool {
        let ip = ip.into();
        self.cidr_set.iter().any(|cidr| cidr.contains(&ip))
    }

    /// True when no proxy CIDRs are currently configured.
    pub fn is_empty(&self) -> bool {
        self.cidr_set.is_empty()
    }
}
+407
View File
@@ -0,0 +1,407 @@
use crossbeam::atomic::AtomicCell;
use dashmap::DashMap;
use pnet::packet::ip::IpNextHeaderProtocols;
use pnet::packet::ipv4::{Ipv4Packet, MutableIpv4Packet};
use pnet::packet::tcp::{ipv4_checksum, MutableTcpPacket};
use std::net::{IpAddr, Ipv4Addr, SocketAddr, SocketAddrV4};
use std::sync::atomic::AtomicU16;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::io::copy_bidirectional;
use tokio::net::{TcpListener, TcpSocket, TcpStream};
use tokio::sync::Mutex;
use tokio::task::JoinSet;
use tokio_util::bytes::{Bytes, BytesMut};
use tracing::Instrument;
use crate::common::error::Result;
use crate::common::global_ctx::GlobalCtx;
use crate::common::join_joinset_background;
use crate::common::netns::NetNS;
use crate::peers::packet::{self, ArchivedPacket};
use crate::peers::peer_manager::PeerManager;
use crate::peers::{NicPacketFilter, PeerPacketFilter};
use super::CidrSet;
/// Lifecycle of one proxied TCP connection.
#[derive(Debug, Clone, Copy, PartialEq)]
enum NatDstEntryState {
    // receive syn packet but not start connecting to dst
    SynReceived,
    // connecting to dst
    ConnectingDst,
    // connected to dst
    Connected,
    // connection closed
    Closed,
}
/// NAT bookkeeping for one proxied TCP connection: the original source and
/// the real destination, plus the forwarding tasks and current state.
#[derive(Debug)]
pub struct NatDstEntry {
    id: uuid::Uuid,
    src: SocketAddr,
    dst: SocketAddr,
    start_time: Instant,
    tasks: Mutex<JoinSet<()>>,
    state: AtomicCell<NatDstEntryState>,
}
impl NatDstEntry {
    /// Track a brand-new proxied connection from `src` toward `dst`.
    /// Entries begin in `SynReceived` with a fresh random id and timestamp.
    pub fn new(src: SocketAddr, dst: SocketAddr) -> Self {
        let id = uuid::Uuid::new_v4();
        let now = Instant::now();
        Self {
            id,
            src,
            dst,
            start_time: now,
            tasks: Mutex::new(JoinSet::new()),
            state: AtomicCell::new(NatDstEntryState::SynReceived),
        }
    }
}
type ArcNatDstEntry = Arc<NatDstEntry>;

// connections that sent SYN but are not yet accepted by the local listener
type SynSockMap = Arc<DashMap<SocketAddr, ArcNatDstEntry>>;
// established proxied connections, keyed by entry id
type ConnSockMap = Arc<DashMap<uuid::Uuid, ArcNatDstEntry>>;
// peer src addr to nat entry, when respond tcp packet, should modify the tcp src addr to the nat entry's dst addr
type AddrConnSockMap = Arc<DashMap<SocketAddr, ArcNatDstEntry>>;
/// Transparent TCP proxy: rewrites peer TCP packets destined to a proxied
/// CIDR so they land on a local listener, and NATs the responses back.
#[derive(Debug)]
pub struct TcpProxy {
    global_ctx: Arc<GlobalCtx>,
    peer_manager: Arc<PeerManager>,
    // Port the local proxy listener is bound to.
    local_port: AtomicU16,
    tasks: Arc<std::sync::Mutex<JoinSet<()>>>,
    syn_map: SynSockMap,
    conn_map: ConnSockMap,
    addr_conn_map: AddrConnSockMap,
    cidr_set: CidrSet,
}
#[async_trait::async_trait]
impl PeerPacketFilter for TcpProxy {
    /// Intercept TCP packets from peers that target a proxied CIDR and
    /// redirect them to the local proxy listener by rewriting the destination
    /// ip/port (remembering the original destination on SYN).
    ///
    /// Returns `Some(())` when the packet was consumed, `None` to let it
    /// continue down the pipeline.
    async fn try_process_packet_from_peer(&self, packet: &ArchivedPacket, _: &Bytes) -> Option<()> {
        // Inactive until this node has an ipv4 address assigned.
        let ipv4_addr = self.global_ctx.get_ipv4()?;

        if packet.packet_type != packet::PacketType::Data {
            return None;
        };

        let payload_bytes = packet.payload.as_bytes();

        let ipv4 = Ipv4Packet::new(payload_bytes)?;
        if ipv4.get_version() != 4 || ipv4.get_next_level_protocol() != IpNextHeaderProtocols::Tcp {
            return None;
        }

        if !self.cidr_set.contains_v4(ipv4.get_destination()) {
            return None;
        }

        tracing::trace!(ipv4 = ?ipv4, cidr_set = ?self.cidr_set, "proxy tcp packet received");

        // Copy into a mutable buffer so headers can be rewritten in place.
        // (Fixed: removed a redundant intermediate `to_vec()` allocation.)
        let mut packet_buffer = BytesMut::with_capacity(payload_bytes.len());
        packet_buffer.extend_from_slice(payload_bytes);
        let (ip_buffer, tcp_buffer) =
            packet_buffer.split_at_mut(ipv4.get_header_length() as usize * 4);
        let mut ip_packet = MutableIpv4Packet::new(ip_buffer).unwrap();
        let mut tcp_packet = MutableTcpPacket::new(tcp_buffer).unwrap();

        let is_tcp_syn = tcp_packet.get_flags() & pnet::packet::tcp::TcpFlags::SYN != 0;
        if is_tcp_syn {
            // On SYN, remember (src -> real dst) so the connection can be
            // dialed to its true destination once the listener accepts it.
            let source_ip = ip_packet.get_source();
            let source_port = tcp_packet.get_source();
            let src = SocketAddr::V4(SocketAddrV4::new(source_ip, source_port));

            let dest_ip = ip_packet.get_destination();
            let dest_port = tcp_packet.get_destination();
            let dst = SocketAddr::V4(SocketAddrV4::new(dest_ip, dest_port));

            let old_val = self
                .syn_map
                .insert(src, Arc::new(NatDstEntry::new(src, dst)));
            tracing::trace!(src = ?src, dst = ?dst, old_entry = ?old_val, "tcp syn received");
        }

        // Redirect to the local proxy listener and recompute the checksums.
        ip_packet.set_destination(ipv4_addr);
        tcp_packet.set_destination(self.get_local_port());
        Self::update_ipv4_packet_checksum(&mut ip_packet, &mut tcp_packet);

        tracing::trace!(ip_packet = ?ip_packet, tcp_packet = ?tcp_packet, "tcp packet forwarded");

        // Inject the rewritten packet into the NIC channel so it reaches the
        // local listener — presumably via the TUN device; confirm nic channel
        // semantics.
        if let Err(e) = self
            .peer_manager
            .get_nic_channel()
            .send(packet_buffer.freeze())
            .await
        {
            tracing::error!("send to nic failed: {:?}", e);
        }

        Some(())
    }
}
#[async_trait::async_trait]
impl NicPacketFilter for TcpProxy {
    /// Rewrite TCP packets leaving the local NAT listener so that, from the
    /// peer's point of view, they appear to come from the original proxied
    /// destination. Packets that do not match pass through untouched.
    async fn try_process_packet_from_nic(&self, mut data: BytesMut) -> BytesMut {
        let Some(my_ipv4) = self.global_ctx.get_ipv4() else {
            return data;
        };
        let header_len = {
            let Some(ipv4) = &Ipv4Packet::new(&data[..]) else {
                return data;
            };
            if ipv4.get_version() != 4
                || ipv4.get_source() != my_ipv4
                || ipv4.get_next_level_protocol() != IpNextHeaderProtocols::Tcp
            {
                return data;
            }
            ipv4.get_header_length() as usize * 4
        };
        let (ip_buffer, tcp_buffer) = data.split_at_mut(header_len);
        let mut ip_packet = MutableIpv4Packet::new(ip_buffer).unwrap();
        let mut tcp_packet = MutableTcpPacket::new(tcp_buffer).unwrap();
        // Only packets emitted by our NAT listener port are of interest.
        if tcp_packet.get_source() != self.get_local_port() {
            return data;
        }
        let dst_addr = SocketAddr::V4(SocketAddrV4::new(
            ip_packet.get_destination(),
            tcp_packet.get_destination(),
        ));
        tracing::trace!(dst_addr = ?dst_addr, "tcp packet try find entry");
        // Look among established connections first, then pending SYNs.
        let entry = if let Some(entry) = self.addr_conn_map.get(&dst_addr) {
            entry
        } else {
            let Some(syn_entry) = self.syn_map.get(&dst_addr) else {
                return data;
            };
            syn_entry
        };
        let nat_entry = entry.clone();
        drop(entry);
        assert_eq!(nat_entry.src, dst_addr);
        // BUGFIX: message previously said "src ip" although the value being
        // checked is the entry's dst address.
        let IpAddr::V4(ip) = nat_entry.dst.ip() else {
            panic!("v4 nat entry dst ip is not v4");
        };
        // Masquerade as the original destination and fix checksums.
        ip_packet.set_source(ip);
        tcp_packet.set_source(nat_entry.dst.port());
        Self::update_ipv4_packet_checksum(&mut ip_packet, &mut tcp_packet);
        tracing::trace!(dst_addr = ?dst_addr, nat_entry = ?nat_entry, packet = ?ip_packet, "tcp packet after modified");
        data
    }
}
impl TcpProxy {
    /// Create a new TCP proxy. The proxy is inert until [`TcpProxy::start`]
    /// is called.
    pub fn new(global_ctx: Arc<GlobalCtx>, peer_manager: Arc<PeerManager>) -> Arc<Self> {
        Arc::new(Self {
            global_ctx: global_ctx.clone(),
            peer_manager,
            local_port: AtomicU16::new(0),
            tasks: Arc::new(std::sync::Mutex::new(JoinSet::new())),
            syn_map: Arc::new(DashMap::new()),
            conn_map: Arc::new(DashMap::new()),
            addr_conn_map: Arc::new(DashMap::new()),
            cidr_set: CidrSet::new(global_ctx),
        })
    }

    /// Recompute the TCP and IPv4 checksums after headers were rewritten.
    fn update_ipv4_packet_checksum(
        ipv4_packet: &mut MutableIpv4Packet,
        tcp_packet: &mut MutableTcpPacket,
    ) {
        // The TCP checksum covers a pseudo-header built from the IP src/dst,
        // so it must be recomputed whenever either address changes.
        tcp_packet.set_checksum(ipv4_checksum(
            &tcp_packet.to_immutable(),
            &ipv4_packet.get_source(),
            &ipv4_packet.get_destination(),
        ));
        ipv4_packet.set_checksum(pnet::packet::ipv4::checksum(&ipv4_packet.to_immutable()));
    }

    /// Start the proxy: spawn the SYN-map cleaner and the NAT listener, then
    /// register in both the peer and NIC packet-processing pipelines.
    pub async fn start(self: &Arc<Self>) -> Result<()> {
        self.run_syn_map_cleaner().await?;
        self.run_listener().await?;
        self.peer_manager
            .add_packet_process_pipeline(Box::new(self.clone()))
            .await;
        self.peer_manager
            .add_nic_packet_process_pipeline(Box::new(self.clone()))
            .await;
        join_joinset_background(self.tasks.clone(), "TcpProxy".to_owned());
        Ok(())
    }

    /// Every 10s, drop SYN entries older than 30s that never completed the
    /// handshake, so the map cannot grow without bound.
    async fn run_syn_map_cleaner(&self) -> Result<()> {
        let syn_map = self.syn_map.clone();
        let tasks = self.tasks.clone();
        let syn_map_cleaner_task = async move {
            loop {
                syn_map.retain(|_, entry| {
                    if entry.start_time.elapsed() > Duration::from_secs(30) {
                        tracing::warn!(entry = ?entry, "syn nat entry expired");
                        entry.state.store(NatDstEntryState::Closed);
                        false
                    } else {
                        true
                    }
                });
                tokio::time::sleep(Duration::from_secs(10)).await;
            }
        };
        tasks.lock().unwrap().spawn(syn_map_cleaner_task);
        Ok(())
    }

    /// Bind the local NAT listener on an ephemeral port and accept redirected
    /// connections, matching each one to its pending SYN entry.
    async fn run_listener(&self) -> Result<()> {
        // NOTE(review): the original comment said "bind on both v4 & v6",
        // but 0.0.0.0 is an IPv4-only wildcard — confirm intended scope.
        let listen_addr = SocketAddr::new(Ipv4Addr::UNSPECIFIED.into(), 0);
        let net_ns = self.global_ctx.net_ns.clone();
        let tcp_listener = net_ns
            .run_async(|| async { TcpListener::bind(&listen_addr).await })
            .await?;
        // Publish the bound port so the packet filters can rewrite to it.
        self.local_port.store(
            tcp_listener.local_addr()?.port(),
            std::sync::atomic::Ordering::Relaxed,
        );
        let tasks = self.tasks.clone();
        let syn_map = self.syn_map.clone();
        let conn_map = self.conn_map.clone();
        let addr_conn_map = self.addr_conn_map.clone();
        let accept_task = async move {
            tracing::info!(listener = ?tcp_listener, "tcp connection start accepting");
            let conn_map = conn_map.clone();
            while let Ok((tcp_stream, socket_addr)) = tcp_listener.accept().await {
                // Every accepted connection must match a SYN we redirected.
                let Some(entry) = syn_map.get(&socket_addr) else {
                    tracing::error!("tcp connection from unknown source: {:?}", socket_addr);
                    continue;
                };
                assert_eq!(entry.state.load(), NatDstEntryState::SynReceived);
                let entry_clone = entry.clone();
                drop(entry);
                // Promote the entry from the SYN map into the connection
                // maps; remove_if guards against a newer entry for same src.
                syn_map.remove_if(&socket_addr, |_, entry| entry.id == entry_clone.id);
                entry_clone.state.store(NatDstEntryState::ConnectingDst);
                let _ = addr_conn_map.insert(entry_clone.src, entry_clone.clone());
                let old_nat_val = conn_map.insert(entry_clone.id, entry_clone.clone());
                assert!(old_nat_val.is_none());
                tasks.lock().unwrap().spawn(Self::connect_to_nat_dst(
                    net_ns.clone(),
                    tcp_stream,
                    conn_map.clone(),
                    addr_conn_map.clone(),
                    entry_clone,
                ));
            }
            tracing::error!("nat tcp listener exited");
            panic!("nat tcp listener exited");
        };
        self.tasks
            .lock()
            .unwrap()
            .spawn(accept_task.instrument(tracing::info_span!("tcp_proxy_listener")));
        Ok(())
    }

    /// Remove `nat_entry` from both connection maps (id- and addr-keyed).
    fn remove_entry_from_all_conn_map(
        conn_map: ConnSockMap,
        addr_conn_map: AddrConnSockMap,
        nat_entry: ArcNatDstEntry,
    ) {
        conn_map.remove(&nat_entry.id);
        addr_conn_map.remove_if(&nat_entry.src, |_, entry| entry.id == nat_entry.id);
    }

    /// Dial the real destination of `nat_entry` (10s timeout) and, on
    /// success, start bidirectional forwarding between the two streams.
    async fn connect_to_nat_dst(
        net_ns: NetNS,
        src_tcp_stream: TcpStream,
        conn_map: ConnSockMap,
        addr_conn_map: AddrConnSockMap,
        nat_entry: ArcNatDstEntry,
    ) {
        if let Err(e) = src_tcp_stream.set_nodelay(true) {
            tracing::warn!("set_nodelay failed, ignore it: {:?}", e);
        }
        let _guard = net_ns.guard();
        let socket = TcpSocket::new_v4().unwrap();
        if let Err(e) = socket.set_nodelay(true) {
            tracing::warn!("set_nodelay failed, ignore it: {:?}", e);
        }
        // BUGFIX: connect with the socket configured above; previously a
        // brand-new `TcpSocket::new_v4()` was created here, discarding the
        // TCP_NODELAY setting just applied.
        let Ok(Ok(dst_tcp_stream)) = tokio::time::timeout(
            Duration::from_secs(10),
            socket.connect(nat_entry.dst),
        )
        .await
        else {
            tracing::error!("connect to dst failed: {:?}", nat_entry);
            nat_entry.state.store(NatDstEntryState::Closed);
            Self::remove_entry_from_all_conn_map(conn_map, addr_conn_map, nat_entry);
            return;
        };
        drop(_guard);
        assert_eq!(nat_entry.state.load(), NatDstEntryState::ConnectingDst);
        nat_entry.state.store(NatDstEntryState::Connected);
        Self::handle_nat_connection(
            src_tcp_stream,
            dst_tcp_stream,
            conn_map,
            addr_conn_map,
            nat_entry,
        )
        .await;
    }

    /// Spawn the bidirectional copy task for an established NAT connection;
    /// on exit the entry is marked Closed and removed from the maps.
    async fn handle_nat_connection(
        mut src_tcp_stream: TcpStream,
        mut dst_tcp_stream: TcpStream,
        conn_map: ConnSockMap,
        addr_conn_map: AddrConnSockMap,
        nat_entry: ArcNatDstEntry,
    ) {
        let nat_entry_clone = nat_entry.clone();
        nat_entry.tasks.lock().await.spawn(async move {
            let ret = copy_bidirectional(&mut src_tcp_stream, &mut dst_tcp_stream).await;
            tracing::trace!(nat_entry = ?nat_entry_clone, ret = ?ret, "nat tcp connection closed");
            nat_entry_clone.state.store(NatDstEntryState::Closed);
            Self::remove_entry_from_all_conn_map(conn_map, addr_conn_map, nat_entry_clone);
        });
    }

    /// Port the NAT listener is bound to (0 before `run_listener` succeeds).
    pub fn get_local_port(&self) -> u16 {
        self.local_port.load(std::sync::atomic::Ordering::Relaxed)
    }
}
+383
View File
@@ -0,0 +1,383 @@
use std::{
net::{SocketAddr, SocketAddrV4},
sync::{atomic::AtomicBool, Arc},
time::Duration,
};
use dashmap::DashMap;
use pnet::packet::{
ip::IpNextHeaderProtocols,
ipv4::{self, Ipv4Flags, Ipv4Packet, MutableIpv4Packet},
udp::{self, MutableUdpPacket},
Packet,
};
use tokio::{
net::UdpSocket,
sync::{
mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender},
Mutex,
},
task::{JoinHandle, JoinSet},
time::timeout,
};
use tokio_util::bytes::Bytes;
use tracing::Level;
use crate::{
common::{error::Error, global_ctx::ArcGlobalCtx, PeerId},
peers::{packet, peer_manager::PeerManager, PeerPacketFilter},
tunnels::common::setup_sokcet2,
};
use super::CidrSet;
/// Key of the UDP NAT table: the peer-side source socket address.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
struct UdpNatKey {
    src_socket: SocketAddr,
}
/// One UDP NAT mapping: a local socket that relays datagrams between a
/// peer-side flow and its real destination.
#[derive(Debug)]
struct UdpNatEntry {
    // Peer that originated this flow (responses are sent back to it).
    src_peer_id: PeerId,
    my_peer_id: PeerId,
    // Peer-side source address of the flow.
    src_socket: SocketAddr,
    // Local socket used to reach the real destination.
    socket: UdpSocket,
    // Response-forwarding task; spawned lazily on first packet.
    forward_task: Mutex<Option<JoinHandle<()>>>,
    // Set when the entry is retired; the forward task checks it.
    stopped: AtomicBool,
    // Creation time; entries older than 120s are cleaned up.
    start_time: std::time::Instant,
}
impl UdpNatEntry {
    /// Create an entry with a fresh local UDP socket bound to 0.0.0.0:0.
    #[tracing::instrument(err(level = Level::WARN))]
    fn new(src_peer_id: PeerId, my_peer_id: PeerId, src_socket: SocketAddr) -> Result<Self, Error> {
        // TODO: try use src port, so we will be ip restricted nat type
        let socket2_socket = socket2::Socket::new(
            socket2::Domain::IPV4,
            socket2::Type::DGRAM,
            Some(socket2::Protocol::UDP),
        )?;
        let dst_socket_addr = "0.0.0.0:0".parse().unwrap();
        setup_sokcet2(&socket2_socket, &dst_socket_addr)?;
        let socket = UdpSocket::from_std(socket2_socket.into())?;
        Ok(Self {
            src_peer_id,
            my_peer_id,
            src_socket,
            socket,
            forward_task: Mutex::new(None),
            stopped: AtomicBool::new(false),
            start_time: std::time::Instant::now(),
        })
    }

    /// Ask the forward task to terminate after its next receive.
    pub fn stop(&self) {
        self.stopped
            .store(true, std::sync::atomic::Ordering::Relaxed);
    }

    /// Wrap a received UDP payload into one or more IPv4 fragments addressed
    /// back to the flow's peer-side source, and hand them to `packet_sender`.
    ///
    /// `buf` must hold the payload at offset 28 (20-byte IPv4 + 8-byte UDP
    /// header area precedes it). `payload_mtu` must be a multiple of 8
    /// because IPv4 fragment offsets are expressed in 8-byte units.
    async fn compose_ipv4_packet(
        self: &Arc<Self>,
        packet_sender: &mut UnboundedSender<packet::Packet>,
        buf: &mut [u8],
        src_v4: &SocketAddrV4,
        payload_len: usize,
        payload_mtu: usize,
        ip_id: u16,
    ) -> Result<(), Error> {
        let SocketAddr::V4(nat_src_v4) = self.src_socket else {
            return Err(Error::Unknown);
        };
        assert_eq!(0, payload_mtu % 8);
        // udp payload is in buf[20 + 8..]
        let mut udp_packet = MutableUdpPacket::new(&mut buf[20..28 + payload_len]).unwrap();
        udp_packet.set_source(src_v4.port());
        udp_packet.set_destination(self.src_socket.port());
        udp_packet.set_length(payload_len as u16 + 8);
        udp_packet.set_checksum(udp::ipv4_checksum(
            &udp_packet.to_immutable(),
            src_v4.ip(),
            nat_src_v4.ip(),
        ));
        let payload_len = payload_len + 8; // include udp header
        let total_pieces = (payload_len + payload_mtu - 1) / payload_mtu;
        let mut buf_offset = 0;
        let mut fragment_offset = 0;
        let mut cur_piece = 0;
        while fragment_offset < payload_len {
            let next_fragment_offset = std::cmp::min(fragment_offset + payload_mtu, payload_len);
            let fragment_len = next_fragment_offset - fragment_offset;
            let mut ipv4_packet =
                MutableIpv4Packet::new(&mut buf[buf_offset..buf_offset + fragment_len + 20])
                    .unwrap();
            ipv4_packet.set_version(4);
            ipv4_packet.set_header_length(5);
            ipv4_packet.set_total_length((fragment_len + 20) as u16);
            ipv4_packet.set_identification(ip_id);
            if total_pieces > 1 {
                // All fragments except the last carry MoreFragments.
                if cur_piece != total_pieces - 1 {
                    ipv4_packet.set_flags(Ipv4Flags::MoreFragments);
                } else {
                    ipv4_packet.set_flags(0);
                }
                assert_eq!(0, fragment_offset % 8);
                ipv4_packet.set_fragment_offset(fragment_offset as u16 / 8);
            } else {
                ipv4_packet.set_flags(Ipv4Flags::DontFragment);
                ipv4_packet.set_fragment_offset(0);
            }
            ipv4_packet.set_ecn(0);
            ipv4_packet.set_dscp(0);
            ipv4_packet.set_ttl(32);
            ipv4_packet.set_source(src_v4.ip().clone());
            ipv4_packet.set_destination(nat_src_v4.ip().clone());
            ipv4_packet.set_next_level_protocol(IpNextHeaderProtocols::Udp);
            ipv4_packet.set_checksum(ipv4::checksum(&ipv4_packet.to_immutable()));
            tracing::trace!(?ipv4_packet, "udp nat packet response send");
            let peer_packet = packet::Packet::new_data_packet(
                self.my_peer_id,
                self.src_peer_id,
                &ipv4_packet.to_immutable().packet(),
            );
            if let Err(e) = packet_sender.send(peer_packet) {
                tracing::error!("send icmp packet to peer failed: {:?}, may exiting..", e);
                return Err(Error::AnyhowError(e.into()));
            }
            buf_offset += next_fragment_offset - fragment_offset;
            fragment_offset = next_fragment_offset;
            cur_piece += 1;
        }
        Ok(())
    }

    /// Receive responses from the real destination and forward them back to
    /// the originating peer as (possibly fragmented) IPv4/UDP packets.
    /// Exits on a 120s receive timeout, an I/O error, or `stop()`.
    async fn forward_task(self: Arc<Self>, mut packet_sender: UnboundedSender<packet::Packet>) {
        // Layout of `buf`: [20-byte IPv4 hdr][8-byte UDP hdr][payload...].
        let mut buf = [0u8; 8192];
        let mut ip_id = 1;
        loop {
            // BUGFIX: receive directly into a fresh reborrow of the payload
            // area. The previous code kept a `&mut` alias into `buf` created
            // with `std::mem::transmute`, which then coexisted with the
            // `&mut buf` passed to compose_ipv4_packet — two live mutable
            // aliases, i.e. undefined behavior. Reborrowing per iteration is
            // safe and costs nothing.
            let (len, src_socket) = match timeout(
                Duration::from_secs(120),
                self.socket.recv_from(&mut buf[20 + 8..]),
            )
            .await
            {
                Ok(Ok(x)) => x,
                Ok(Err(err)) => {
                    tracing::error!(?err, "udp nat recv failed");
                    break;
                }
                Err(err) => {
                    tracing::error!(?err, "udp nat recv timeout");
                    break;
                }
            };
            tracing::trace!(?len, ?src_socket, "udp nat packet response received");
            if self.stopped.load(std::sync::atomic::Ordering::Relaxed) {
                break;
            }
            // Only IPv4 responses can be wrapped back into IPv4 packets.
            let SocketAddr::V4(src_v4) = src_socket else {
                continue;
            };
            let Ok(_) = Self::compose_ipv4_packet(
                &self,
                &mut packet_sender,
                &mut buf,
                &src_v4,
                len,
                1200,
                ip_id,
            )
            .await
            else {
                break;
            };
            ip_id = ip_id.wrapping_add(1);
        }
        self.stop();
    }
}
/// Transparent UDP proxy: relays peer datagrams destined for proxied CIDRs
/// through per-flow NAT entries and forwards responses back to the peers.
#[derive(Debug)]
pub struct UdpProxy {
    global_ctx: ArcGlobalCtx,
    peer_manager: Arc<PeerManager>,
    // CIDRs this node proxies for.
    cidr_set: CidrSet,
    // Per-flow NAT entries keyed by peer-side source socket addr.
    nat_table: Arc<DashMap<UdpNatKey, Arc<UdpNatEntry>>>,
    // Channel from the NAT forward tasks back to the peer manager.
    sender: UnboundedSender<packet::Packet>,
    // Receiver half; taken exactly once in `start`.
    receiver: Mutex<Option<UnboundedReceiver<packet::Packet>>>,
    tasks: Mutex<JoinSet<()>>,
}
#[async_trait::async_trait]
impl PeerPacketFilter for UdpProxy {
async fn try_process_packet_from_peer(
&self,
packet: &packet::ArchivedPacket,
_: &Bytes,
) -> Option<()> {
if self.cidr_set.is_empty() {
return None;
}
let _ = self.global_ctx.get_ipv4()?;
if packet.packet_type != packet::PacketType::Data {
return None;
};
let ipv4 = Ipv4Packet::new(packet.payload.as_bytes())?;
if ipv4.get_version() != 4 || ipv4.get_next_level_protocol() != IpNextHeaderProtocols::Udp {
return None;
}
if !self.cidr_set.contains_v4(ipv4.get_destination()) {
return None;
}
let udp_packet = udp::UdpPacket::new(ipv4.payload())?;
tracing::trace!(
?packet,
?ipv4,
?udp_packet,
"udp nat packet request received"
);
let nat_key = UdpNatKey {
src_socket: SocketAddr::new(ipv4.get_source().into(), udp_packet.get_source()),
};
let nat_entry = self
.nat_table
.entry(nat_key)
.or_try_insert_with::<Error>(|| {
tracing::info!(?packet, ?ipv4, ?udp_packet, "udp nat table entry created");
let _g = self.global_ctx.net_ns.guard();
Ok(Arc::new(UdpNatEntry::new(
packet.from_peer.into(),
packet.to_peer.into(),
nat_key.src_socket,
)?))
})
.ok()?
.clone();
if nat_entry.forward_task.lock().await.is_none() {
nat_entry
.forward_task
.lock()
.await
.replace(tokio::spawn(UdpNatEntry::forward_task(
nat_entry.clone(),
self.sender.clone(),
)));
}
// TODO: should it be async.
let dst_socket =
SocketAddr::new(ipv4.get_destination().into(), udp_packet.get_destination());
let send_ret = {
let _g = self.global_ctx.net_ns.guard();
nat_entry
.socket
.send_to(udp_packet.payload(), dst_socket)
.await
};
if let Err(send_err) = send_ret {
tracing::error!(
?send_err,
?nat_key,
?nat_entry,
?send_err,
"udp nat send failed"
);
}
Some(())
}
}
impl UdpProxy {
    /// Create a UDP proxy; it does nothing until [`UdpProxy::start`] runs.
    pub fn new(
        global_ctx: ArcGlobalCtx,
        peer_manager: Arc<PeerManager>,
    ) -> Result<Arc<Self>, Error> {
        let cidr_set = CidrSet::new(global_ctx.clone());
        let (sender, receiver) = unbounded_channel();
        let ret = Self {
            global_ctx,
            peer_manager,
            cidr_set,
            nat_table: Arc::new(DashMap::new()),
            sender,
            receiver: Mutex::new(Some(receiver)),
            tasks: Mutex::new(JoinSet::new()),
        };
        Ok(Arc::new(ret))
    }
    /// Register the proxy in the peer packet pipeline and spawn two tasks:
    /// a NAT-table cleaner and the response-forwarding loop.
    pub async fn start(self: &Arc<Self>) -> Result<(), Error> {
        self.peer_manager
            .add_packet_process_pipeline(Box::new(self.clone()))
            .await;
        // clean up nat table: every 15s, stop and drop entries older than 120s
        let nat_table = self.nat_table.clone();
        self.tasks.lock().await.spawn(async move {
            loop {
                tokio::time::sleep(Duration::from_secs(15)).await;
                nat_table.retain(|_, v| {
                    if v.start_time.elapsed().as_secs() > 120 {
                        tracing::info!(?v, "udp nat table entry removed");
                        v.stop();
                        false
                    } else {
                        true
                    }
                });
            }
        });
        // forward packets to peer manager. The receiver can be taken only
        // once, so calling `start` twice would panic on this unwrap.
        let mut receiver = self.receiver.lock().await.take().unwrap();
        let peer_manager = self.peer_manager.clone();
        self.tasks.lock().await.spawn(async move {
            while let Some(msg) = receiver.recv().await {
                let to_peer_id: PeerId = msg.to_peer.into();
                tracing::trace!(?msg, ?to_peer_id, "udp nat packet response send");
                let ret = peer_manager.send_msg(msg.into(), to_peer_id).await;
                if ret.is_err() {
                    tracing::error!("send icmp packet to peer failed: {:?}", ret);
                }
            }
        });
        Ok(())
    }
}
impl Drop for UdpProxy {
    /// Signal every active NAT entry to stop its forwarding task when the
    /// proxy is torn down.
    fn drop(&mut self) {
        self.nat_table.iter().for_each(|entry| entry.stop());
    }
}
+481
View File
@@ -0,0 +1,481 @@
use std::borrow::BorrowMut;
use std::net::Ipv4Addr;
use std::sync::{Arc, Weak};
use anyhow::Context;
use futures::StreamExt;
use pnet::packet::ethernet::EthernetPacket;
use pnet::packet::ipv4::Ipv4Packet;
use tokio::{sync::Mutex, task::JoinSet};
use tokio_util::bytes::{Bytes, BytesMut};
use tonic::transport::Server;
use crate::common::config::ConfigLoader;
use crate::common::error::Error;
use crate::common::global_ctx::{ArcGlobalCtx, GlobalCtx, GlobalCtxEvent};
use crate::common::PeerId;
use crate::connector::direct::DirectConnectorManager;
use crate::connector::manual::{ConnectorManagerRpcService, ManualConnectorManager};
use crate::connector::udp_hole_punch::UdpHolePunchConnector;
use crate::gateway::icmp_proxy::IcmpProxy;
use crate::gateway::tcp_proxy::TcpProxy;
use crate::gateway::udp_proxy::UdpProxy;
use crate::peer_center::instance::PeerCenterInstance;
use crate::peers::peer_conn::PeerConnId;
use crate::peers::peer_manager::{PeerManager, RouteAlgoType};
use crate::peers::rpc_service::PeerManagerRpcService;
use crate::rpc::vpn_portal_rpc_server::VpnPortalRpc;
use crate::rpc::{GetVpnPortalInfoRequest, GetVpnPortalInfoResponse, VpnPortalInfo};
use crate::tunnels::SinkItem;
use crate::vpn_portal::{self, VpnPortal};
use tokio_stream::wrappers::ReceiverStream;
use super::listeners::ListenerManager;
use super::virtual_nic;
/// A running node instance: owns the virtual NIC, the peer manager and all
/// auxiliary services (connectors, proxies, peer center, vpn portal).
pub struct Instance {
    inst_name: String,
    id: uuid::Uuid,
    // tun device; created lazily in `prepare_tun_device`
    virtual_nic: Option<Arc<virtual_nic::VirtualNic>>,
    // packets from peers destined for the local tun device; taken once by
    // `do_forward_peers_to_nic`
    peer_packet_receiver: Option<ReceiverStream<SinkItem>>,
    tasks: JoinSet<()>,
    peer_manager: Arc<PeerManager>,
    listener_manager: Arc<Mutex<ListenerManager<PeerManager>>>,
    conn_manager: Arc<ManualConnectorManager>,
    direct_conn_manager: Arc<DirectConnectorManager>,
    udp_hole_puncher: Arc<Mutex<UdpHolePunchConnector>>,
    tcp_proxy: Arc<TcpProxy>,
    icmp_proxy: Arc<IcmpProxy>,
    udp_proxy: Arc<UdpProxy>,
    peer_center: Arc<PeerCenterInstance>,
    vpn_portal: Arc<Mutex<Box<dyn VpnPortal>>>,
    global_ctx: ArcGlobalCtx,
}
impl Instance {
    /// Build a full instance from `config`: global context, peer manager,
    /// connectors, proxies (tcp/icmp/udp), peer center and vpn portal.
    /// Nothing is started here; call [`Instance::run`] to bring it up.
    pub fn new(config: impl ConfigLoader + Send + Sync + 'static) -> Self {
        let global_ctx = Arc::new(GlobalCtx::new(config));
        log::info!(
            "[INIT] instance creating. config: {}",
            global_ctx.config.dump()
        );
        let (peer_packet_sender, peer_packet_receiver) = tokio::sync::mpsc::channel(100);
        let id = global_ctx.get_id();
        let peer_manager = Arc::new(PeerManager::new(
            RouteAlgoType::Ospf,
            global_ctx.clone(),
            peer_packet_sender.clone(),
        ));
        let listener_manager = Arc::new(Mutex::new(ListenerManager::new(
            global_ctx.clone(),
            peer_manager.clone(),
        )));
        let conn_manager = Arc::new(ManualConnectorManager::new(
            global_ctx.clone(),
            peer_manager.clone(),
        ));
        let mut direct_conn_manager =
            DirectConnectorManager::new(global_ctx.clone(), peer_manager.clone());
        direct_conn_manager.run();
        let udp_hole_puncher = UdpHolePunchConnector::new(global_ctx.clone(), peer_manager.clone());
        let arc_tcp_proxy = TcpProxy::new(global_ctx.clone(), peer_manager.clone());
        let arc_icmp_proxy = IcmpProxy::new(global_ctx.clone(), peer_manager.clone())
            .with_context(|| "create icmp proxy failed")
            .unwrap();
        let arc_udp_proxy = UdpProxy::new(global_ctx.clone(), peer_manager.clone())
            .with_context(|| "create udp proxy failed")
            .unwrap();
        let peer_center = Arc::new(PeerCenterInstance::new(peer_manager.clone()));
        let vpn_portal_inst = vpn_portal::wireguard::WireGuard::default();
        Instance {
            inst_name: global_ctx.inst_name.clone(),
            id,
            virtual_nic: None,
            peer_packet_receiver: Some(ReceiverStream::new(peer_packet_receiver)),
            tasks: JoinSet::new(),
            peer_manager,
            listener_manager,
            conn_manager,
            direct_conn_manager: Arc::new(direct_conn_manager),
            udp_hole_puncher: Arc::new(Mutex::new(udp_hole_puncher)),
            tcp_proxy: arc_tcp_proxy,
            icmp_proxy: arc_icmp_proxy,
            udp_proxy: arc_udp_proxy,
            peer_center,
            vpn_portal: Arc::new(Mutex::new(Box::new(vpn_portal_inst))),
            global_ctx,
        }
    }

    pub fn get_conn_manager(&self) -> Arc<ManualConnectorManager> {
        self.conn_manager.clone()
    }

    /// Forward one IPv4 packet read from the tun device to the peer owning
    /// its destination address. Non-IPv4 data is logged and dropped.
    async fn do_forward_nic_to_peers_ipv4(ret: BytesMut, mgr: &PeerManager) {
        if let Some(ipv4) = Ipv4Packet::new(&ret) {
            if ipv4.get_version() != 4 {
                tracing::info!("[USER_PACKET] not ipv4 packet: {:?}", ipv4);
                return;
            }
            let dst_ipv4 = ipv4.get_destination();
            tracing::trace!(
                ?ret,
                "[USER_PACKET] recv new packet from tun device and forward to peers."
            );
            let send_ret = mgr.send_msg_ipv4(ret, dst_ipv4).await;
            if send_ret.is_err() {
                tracing::trace!(?send_ret, "[USER_PACKET] send_msg_ipv4 failed")
            }
        } else {
            tracing::warn!(?ret, "[USER_PACKET] not ipv4 packet");
        }
    }

    /// Strip the 14-byte ethernet header and forward the inner IPv4 packet.
    /// Currently unused (see the commented call site below).
    async fn do_forward_nic_to_peers_ethernet(mut ret: BytesMut, mgr: &PeerManager) {
        if let Some(eth) = EthernetPacket::new(&ret) {
            log::warn!("begin to forward: {:?}, type: {}", eth, eth.get_ethertype());
            Self::do_forward_nic_to_peers_ipv4(ret.split_off(14), mgr).await;
        } else {
            log::warn!("not ipv4 packet: {:?}", ret);
        }
    }

    /// Spawn the task that reads packets from the tun device and routes them
    /// to peers. Panics if the virtual nic has not been created yet.
    fn do_forward_nic_to_peers(&mut self) -> Result<(), Error> {
        // read from nic and write to corresponding tunnel
        let nic = self.virtual_nic.as_ref().unwrap();
        let nic = nic.clone();
        let mgr = self.peer_manager.clone();
        self.tasks.spawn(async move {
            let mut stream = nic.pin_recv_stream();
            while let Some(ret) = stream.next().await {
                if ret.is_err() {
                    log::error!("read from nic failed: {:?}", ret);
                    break;
                }
                Self::do_forward_nic_to_peers_ipv4(ret.unwrap(), mgr.as_ref()).await;
                // Self::do_forward_nic_to_peers_ethernet(ret.into(), mgr.as_ref()).await;
            }
        });
        Ok(())
    }

    /// Spawn the task that writes packets arriving from peers to the tun
    /// device. `channel` is the one-shot receiver stream taken from `self`.
    fn do_forward_peers_to_nic(
        tasks: &mut JoinSet<()>,
        nic: Arc<virtual_nic::VirtualNic>,
        channel: Option<ReceiverStream<Bytes>>,
    ) {
        tasks.spawn(async move {
            let send = nic.pin_send_stream();
            let channel = channel.unwrap();
            let ret = channel
                .map(|packet| {
                    log::trace!(
                        "[USER_PACKET] forward packet from peers to nic. packet: {:?}",
                        packet
                    );
                    Ok(packet)
                })
                .forward(send)
                .await;
            if ret.is_err() {
                panic!("do_forward_tunnel_to_nic");
            }
        });
    }

    /// Register a connector for every peer URI listed in the config.
    async fn add_initial_peers(&mut self) -> Result<(), Error> {
        for peer in self.global_ctx.config.get_peers().iter() {
            self.get_conn_manager()
                .add_connector_by_url(peer.uri.as_str())
                .await?;
        }
        Ok(())
    }

    /// Create the tun device and wire it into both forwarding directions.
    async fn prepare_tun_device(&mut self) -> Result<(), Error> {
        let nic = virtual_nic::VirtualNic::new(self.get_global_ctx())
            .create_dev()
            .await?;
        self.global_ctx
            .issue_event(GlobalCtxEvent::TunDeviceReady(nic.ifname().to_string()));
        self.virtual_nic = Some(Arc::new(nic));
        self.do_forward_nic_to_peers().unwrap();
        Self::do_forward_peers_to_nic(
            self.tasks.borrow_mut(),
            self.virtual_nic.as_ref().unwrap().clone(),
            self.peer_packet_receiver.take(),
        );
        Ok(())
    }

    /// Bring the tun device up and assign `ipv4_addr/24` to it.
    async fn assign_ipv4_to_tun_device(&mut self, ipv4_addr: Ipv4Addr) -> Result<(), Error> {
        let nic = self.virtual_nic.as_ref().unwrap().clone();
        nic.link_up().await?;
        nic.remove_ip(None).await?;
        nic.add_ip(ipv4_addr, 24).await?;
        if cfg!(target_os = "macos") {
            // macOS does not add a subnet route when the address is assigned.
            nic.add_route(ipv4_addr, 24).await?;
        }
        Ok(())
    }

    /// Start every service of the instance in dependency order: tun device,
    /// listeners, peer manager, rpc server, proxies, hole puncher, peer
    /// center, configured peers and (optionally) the vpn portal.
    pub async fn run(&mut self) -> Result<(), Error> {
        self.prepare_tun_device().await?;
        if let Some(ipv4_addr) = self.global_ctx.get_ipv4() {
            self.assign_ipv4_to_tun_device(ipv4_addr).await?;
        }
        self.listener_manager
            .lock()
            .await
            .prepare_listeners()
            .await?;
        self.listener_manager.lock().await.run().await?;
        self.peer_manager.run().await?;
        self.run_rpc_server().unwrap();
        self.tcp_proxy.start().await.unwrap();
        self.icmp_proxy.start().await.unwrap();
        self.udp_proxy.start().await.unwrap();
        self.run_proxy_cidrs_route_updater();
        self.udp_hole_puncher.lock().await.run().await?;
        self.peer_center.init().await;
        self.add_initial_peers().await?;
        // Idiom fix: `is_some()` instead of `if let Some(_) = ...`.
        if self.global_ctx.get_vpn_portal_cidr().is_some() {
            self.vpn_portal
                .lock()
                .await
                .start(self.get_global_ctx(), self.get_peer_manager())
                .await?;
        }
        Ok(())
    }

    pub fn get_peer_manager(&self) -> Arc<PeerManager> {
        self.peer_manager.clone()
    }

    /// Close one connection to `peer_id` identified by `conn_id`.
    pub async fn close_peer_conn(
        &mut self,
        peer_id: PeerId,
        conn_id: &PeerConnId,
    ) -> Result<(), Error> {
        self.peer_manager
            .get_peer_map()
            .close_peer_conn(peer_id, conn_id)
            .await?;
        Ok(())
    }

    /// Block until every spawned task finishes; panics if one panicked.
    pub async fn wait(&mut self) {
        while let Some(ret) = self.tasks.join_next().await {
            log::info!("task finished: {:?}", ret);
            ret.unwrap();
        }
    }

    pub fn id(&self) -> uuid::Uuid {
        self.id
    }

    pub fn peer_id(&self) -> PeerId {
        self.peer_manager.my_peer_id()
    }

    /// Build the vpn-portal rpc service. Weak references keep the rpc server
    /// from extending the lifetime of the instance's components.
    fn get_vpn_portal_rpc_service(&self) -> impl VpnPortalRpc {
        struct VpnPortalRpcService {
            peer_mgr: Weak<PeerManager>,
            vpn_portal: Weak<Mutex<Box<dyn VpnPortal>>>,
        }
        #[tonic::async_trait]
        impl VpnPortalRpc for VpnPortalRpcService {
            async fn get_vpn_portal_info(
                &self,
                _request: tonic::Request<GetVpnPortalInfoRequest>,
            ) -> Result<tonic::Response<GetVpnPortalInfoResponse>, tonic::Status> {
                let Some(vpn_portal) = self.vpn_portal.upgrade() else {
                    return Err(tonic::Status::unavailable("vpn portal not available"));
                };
                let Some(peer_mgr) = self.peer_mgr.upgrade() else {
                    return Err(tonic::Status::unavailable("peer manager not available"));
                };
                let vpn_portal = vpn_portal.lock().await;
                let ret = GetVpnPortalInfoResponse {
                    vpn_portal_info: Some(VpnPortalInfo {
                        vpn_type: vpn_portal.name(),
                        client_config: vpn_portal.dump_client_config(peer_mgr).await,
                        connected_clients: vpn_portal.list_clients().await,
                    }),
                };
                Ok(tonic::Response::new(ret))
            }
        }
        VpnPortalRpcService {
            peer_mgr: Arc::downgrade(&self.peer_manager),
            vpn_portal: Arc::downgrade(&self.vpn_portal),
        }
    }

    /// Start the grpc management server if `rpc_portal` is configured.
    /// Runs inside the instance's network namespace.
    fn run_rpc_server(&mut self) -> Result<(), Box<dyn std::error::Error>> {
        let Some(addr) = self.global_ctx.config.get_rpc_portal() else {
            tracing::info!("rpc server not enabled, because rpc_portal is not set.");
            return Ok(());
        };
        let peer_mgr = self.peer_manager.clone();
        let conn_manager = self.conn_manager.clone();
        let net_ns = self.global_ctx.net_ns.clone();
        let peer_center = self.peer_center.clone();
        let vpn_portal_rpc = self.get_vpn_portal_rpc_service();
        self.tasks.spawn(async move {
            let _g = net_ns.guard();
            Server::builder()
                .add_service(
                    crate::rpc::peer_manage_rpc_server::PeerManageRpcServer::new(
                        PeerManagerRpcService::new(peer_mgr),
                    ),
                )
                .add_service(
                    crate::rpc::connector_manage_rpc_server::ConnectorManageRpcServer::new(
                        ConnectorManagerRpcService(conn_manager.clone()),
                    ),
                )
                .add_service(
                    crate::rpc::peer_center_rpc_server::PeerCenterRpcServer::new(
                        peer_center.get_rpc_service(),
                    ),
                )
                .add_service(crate::rpc::vpn_portal_rpc_server::VpnPortalRpcServer::new(
                    vpn_portal_rpc,
                ))
                .serve(addr)
                .await
                .with_context(|| format!("rpc server failed. addr: {}", addr))
                .unwrap();
        });
        Ok(())
    }

    /// Spawn a task that, once per second, diffs the CIDRs advertised by
    /// routes (plus the vpn portal client cidr) against the kernel routes
    /// currently installed on the tun device, adding and removing as needed.
    fn run_proxy_cidrs_route_updater(&mut self) {
        let peer_mgr = self.peer_manager.clone();
        let global_ctx = self.global_ctx.clone();
        let net_ns = self.global_ctx.net_ns.clone();
        let nic = self.virtual_nic.as_ref().unwrap().clone();
        self.tasks.spawn(async move {
            let mut cur_proxy_cidrs = vec![];
            loop {
                let mut proxy_cidrs = vec![];
                let routes = peer_mgr.list_routes().await;
                for r in routes {
                    for cidr in r.proxy_cidrs {
                        let Ok(cidr) = cidr.parse::<cidr::Ipv4Cidr>() else {
                            continue;
                        };
                        proxy_cidrs.push(cidr);
                    }
                }
                // add vpn portal cidr to proxy_cidrs
                if let Some(vpn_cfg) = global_ctx.config.get_vpn_portal_config() {
                    proxy_cidrs.push(vpn_cfg.client_cidr);
                }
                // if route is in cur_proxy_cidrs but not in proxy_cidrs, delete it.
                for cidr in cur_proxy_cidrs.iter() {
                    if proxy_cidrs.contains(cidr) {
                        continue;
                    }
                    let _g = net_ns.guard();
                    let ret = nic
                        .get_ifcfg()
                        .remove_ipv4_route(
                            nic.ifname(),
                            cidr.first_address(),
                            cidr.network_length(),
                        )
                        .await;
                    if ret.is_err() {
                        tracing::trace!(
                            cidr = ?cidr,
                            err = ?ret,
                            "remove route failed.",
                        );
                    }
                }
                // add newly appearing cidrs
                for cidr in proxy_cidrs.iter() {
                    if cur_proxy_cidrs.contains(cidr) {
                        continue;
                    }
                    let _g = net_ns.guard();
                    let ret = nic
                        .get_ifcfg()
                        .add_ipv4_route(nic.ifname(), cidr.first_address(), cidr.network_length())
                        .await;
                    if ret.is_err() {
                        tracing::trace!(
                            cidr = ?cidr,
                            err = ?ret,
                            "add route failed.",
                        );
                    }
                }
                cur_proxy_cidrs = proxy_cidrs;
                tokio::time::sleep(std::time::Duration::from_secs(1)).await;
            }
        });
    }

    pub fn get_global_ctx(&self) -> ArcGlobalCtx {
        self.global_ctx.clone()
    }
}
+202
View File
@@ -0,0 +1,202 @@
use std::{fmt::Debug, sync::Arc};
use anyhow::Context;
use async_trait::async_trait;
use tokio::{sync::Mutex, task::JoinSet};
use crate::{
common::{
error::Error,
global_ctx::{ArcGlobalCtx, GlobalCtxEvent},
netns::NetNS,
},
peers::peer_manager::PeerManager,
tunnels::{
ring_tunnel::RingTunnelListener,
tcp_tunnel::TcpTunnelListener,
udp_tunnel::UdpTunnelListener,
wireguard::{WgConfig, WgTunnelListener},
Tunnel, TunnelListener,
},
};
/// Abstraction over the component that takes ownership of freshly accepted
/// tunnels (implemented by `PeerManager`; mocked in the tests below).
#[async_trait]
pub trait TunnelHandlerForListener {
    async fn handle_tunnel(&self, tunnel: Box<dyn Tunnel>) -> Result<(), Error>;
}
#[async_trait]
impl TunnelHandlerForListener for PeerManager {
    /// Hand the accepted tunnel to the peer manager as a server-side
    /// connection.
    #[tracing::instrument]
    async fn handle_tunnel(&self, tunnel: Box<dyn Tunnel>) -> Result<(), Error> {
        self.add_tunnel_as_server(tunnel).await
    }
}
/// Owns all tunnel listeners and the accept-loop tasks that feed accepted
/// tunnels into the handler `H`.
pub struct ListenerManager<H> {
    global_ctx: ArcGlobalCtx,
    // Network namespace the listeners bind inside.
    net_ns: NetNS,
    listeners: Vec<Arc<Mutex<dyn TunnelListener>>>,
    // Receives every accepted tunnel (typically the peer manager).
    peer_manager: Arc<H>,
    // One accept-loop task per listener.
    tasks: JoinSet<()>,
}
impl<H: TunnelHandlerForListener + Send + Sync + 'static + Debug> ListenerManager<H> {
    /// Create a manager with no listeners registered.
    pub fn new(global_ctx: ArcGlobalCtx, peer_manager: Arc<H>) -> Self {
        Self {
            global_ctx: global_ctx.clone(),
            net_ns: global_ctx.net_ns.clone(),
            listeners: Vec::new(),
            peer_manager,
            tasks: JoinSet::new(),
        }
    }
    /// Register the built-in ring listener plus one listener per URI from
    /// the config. Unsupported schemes are logged and skipped.
    pub async fn prepare_listeners(&mut self) -> Result<(), Error> {
        self.add_listener(RingTunnelListener::new(
            format!("ring://{}", self.global_ctx.get_id())
                .parse()
                .unwrap(),
        ))
        .await?;
        for l in self.global_ctx.config.get_listener_uris().iter() {
            match l.scheme() {
                "tcp" => {
                    self.add_listener(TcpTunnelListener::new(l.clone())).await?;
                }
                "udp" => {
                    self.add_listener(UdpTunnelListener::new(l.clone())).await?;
                }
                "wg" => {
                    // wireguard listeners derive their keys from the network identity
                    let nid = self.global_ctx.get_network_identity();
                    let wg_config =
                        WgConfig::new_from_network_identity(&nid.network_name, &nid.network_secret);
                    self.add_listener(WgTunnelListener::new(l.clone(), wg_config))
                        .await?;
                }
                _ => {
                    log::warn!("unsupported listener uri: {}", l);
                }
            }
        }
        Ok(())
    }
    /// Register a listener; it is only bound and served by [`Self::run`].
    pub async fn add_listener<Listener>(&mut self, listener: Listener) -> Result<(), Error>
    where
        Listener: TunnelListener + 'static,
    {
        let listener = Arc::new(Mutex::new(listener));
        self.listeners.push(listener);
        Ok(())
    }
    /// Accept loop for one listener: report each accepted tunnel as an event
    /// and hand it off to the handler on its own task so a slow or failing
    /// handshake never blocks further accepts.
    #[tracing::instrument]
    async fn run_listener(
        listener: Arc<Mutex<dyn TunnelListener>>,
        peer_manager: Arc<H>,
        global_ctx: ArcGlobalCtx,
    ) {
        // Hold the listener lock for the lifetime of the accept loop.
        let mut l = listener.lock().await;
        global_ctx.add_running_listener(l.local_url());
        global_ctx.issue_event(GlobalCtxEvent::ListenerAdded(l.local_url()));
        while let Ok(ret) = l.accept().await {
            let tunnel_info = ret.info().unwrap();
            global_ctx.issue_event(GlobalCtxEvent::ConnectionAccepted(
                tunnel_info.local_addr.clone(),
                tunnel_info.remote_addr.clone(),
            ));
            tracing::info!(ret = ?ret, "conn accepted");
            let peer_manager = peer_manager.clone();
            let global_ctx = global_ctx.clone();
            tokio::spawn(async move {
                let server_ret = peer_manager.handle_tunnel(ret).await;
                if let Err(e) = &server_ret {
                    global_ctx.issue_event(GlobalCtxEvent::ConnectionError(
                        tunnel_info.local_addr,
                        tunnel_info.remote_addr,
                        e.to_string(),
                    ));
                    tracing::error!(error = ?e, "handle conn error");
                }
            });
        }
    }
    /// Bind every registered listener (inside the network namespace) and
    /// spawn its accept loop. Fails fast if any bind fails.
    pub async fn run(&mut self) -> Result<(), Error> {
        for listener in &self.listeners {
            // enter the netns so the bind happens in the right namespace
            let _guard = self.net_ns.guard();
            let addr = listener.lock().await.local_url();
            log::warn!("run listener: {:?}", listener);
            listener
                .lock()
                .await
                .listen()
                .await
                .with_context(|| format!("failed to add listener {}", addr))?;
            self.tasks.spawn(Self::run_listener(
                listener.clone(),
                self.peer_manager.clone(),
                self.global_ctx.clone(),
            ));
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use futures::{SinkExt, StreamExt};
    use tokio::time::timeout;
    use crate::{
        common::global_ctx::tests::get_mock_global_ctx,
        tunnels::{ring_tunnel::RingTunnelConnector, TunnelConnector},
    };
    use super::*;
    // Handler that sends a fixed payload then returns an error, to exercise
    // the error path of the accept loop.
    #[derive(Debug)]
    struct MockListenerHandler {}
    #[async_trait]
    impl TunnelHandlerForListener for MockListenerHandler {
        async fn handle_tunnel(&self, _tunnel: Box<dyn Tunnel>) -> Result<(), Error> {
            let data = "abc";
            _tunnel.pin_sink().send(data.into()).await.unwrap();
            Err(Error::Unknown)
        }
    }
    // A handler failure on one tunnel must not break the accept loop for
    // subsequent connections.
    #[tokio::test]
    async fn handle_error_in_accept() {
        let handler = Arc::new(MockListenerHandler {});
        let mut listener_mgr = ListenerManager::new(get_mock_global_ctx(), handler.clone());
        let ring_id = format!("ring://{}", uuid::Uuid::new_v4());
        listener_mgr
            .add_listener(RingTunnelListener::new(ring_id.parse().unwrap()))
            .await
            .unwrap();
        listener_mgr.run().await.unwrap();
        let connect_once = |ring_id| async move {
            let tunnel = RingTunnelConnector::new(ring_id).connect().await.unwrap();
            assert_eq!(tunnel.pin_stream().next().await.unwrap().unwrap(), "abc");
            tunnel
        };
        timeout(std::time::Duration::from_secs(1), async move {
            connect_once(ring_id.parse().unwrap()).await;
            // handle tunnel fail should not impact the second connect
            connect_once(ring_id.parse().unwrap()).await;
        })
        .await
        .unwrap();
    }
}
+4
View File
@@ -0,0 +1,4 @@
pub mod instance;
pub mod listeners;
pub mod tun_codec;
pub mod virtual_nic;
+179
View File
@@ -0,0 +1,179 @@
use std::io;
use byteorder::{NativeEndian, NetworkEndian, WriteBytesExt};
use tokio_util::bytes::{BufMut, Bytes, BytesMut};
use tokio_util::codec::{Decoder, Encoder};
/// A packet protocol IP version
#[derive(Debug, Clone, Copy, Default)]
enum PacketProtocol {
    #[default]
    IPv4,
    IPv6,
    // holds the unrecognized version nibble
    Other(u8),
}
// Note: the protocol in the packet information header is platform dependent.
impl PacketProtocol {
    /// Protocol value to place in the TUN packet-information header
    /// (Linux/Android use EtherType constants).
    #[cfg(any(target_os = "linux", target_os = "android"))]
    fn into_pi_field(self) -> Result<u16, io::Error> {
        use nix::libc;
        match self {
            PacketProtocol::IPv4 => Ok(libc::ETH_P_IP as u16),
            PacketProtocol::IPv6 => Ok(libc::ETH_P_IPV6 as u16),
            PacketProtocol::Other(_) => Err(io::Error::new(
                io::ErrorKind::Other,
                "neither an IPv4 nor IPv6 packet",
            )),
        }
    }
    /// Protocol value for the header on macOS/iOS (address-family constants).
    #[cfg(any(target_os = "macos", target_os = "ios"))]
    fn into_pi_field(self) -> Result<u16, io::Error> {
        use nix::libc;
        match self {
            PacketProtocol::IPv4 => Ok(libc::PF_INET as u16),
            PacketProtocol::IPv6 => Ok(libc::PF_INET6 as u16),
            PacketProtocol::Other(_) => Err(io::Error::new(
                io::ErrorKind::Other,
                "neither an IPv4 nor IPv6 packet",
            )),
        }
    }
    // Windows TUN devices are used without a PI header here, so this is
    // never expected to be called on that platform.
    #[cfg(target_os = "windows")]
    fn into_pi_field(self) -> Result<u16, io::Error> {
        unimplemented!()
    }
}
// Payload storage for a TUN packet: either a frozen `Bytes` (outgoing)
// or a mutable `BytesMut` (freshly decoded from the device).
#[derive(Debug)]
pub enum TunPacketBuffer {
    Bytes(Bytes),
    BytesMut(BytesMut),
}
impl From<TunPacketBuffer> for Bytes {
    /// Convert into immutable `Bytes`, freezing a `BytesMut` when needed.
    fn from(buf: TunPacketBuffer) -> Self {
        match buf {
            TunPacketBuffer::Bytes(bytes) => bytes,
            TunPacketBuffer::BytesMut(bytes) => bytes.freeze(),
        }
    }
}
impl AsRef<[u8]> for TunPacketBuffer {
    /// Borrow the raw packet bytes regardless of the storage variant.
    fn as_ref(&self) -> &[u8] {
        match self {
            TunPacketBuffer::Bytes(bytes) => bytes.as_ref(),
            TunPacketBuffer::BytesMut(bytes) => bytes.as_ref(),
        }
    }
}
/// A Tun Packet to be sent or received on the TUN interface.
/// Field 0 is the inferred IP protocol; field 1 is the payload.
#[derive(Debug)]
pub struct TunPacket(PacketProtocol, TunPacketBuffer);
/// Infer the protocol based on the first nibble in the packet buffer.
///
/// An empty buffer is classified as `Other(0)` instead of panicking on an
/// out-of-bounds read (the codec may hand us a zero-length payload).
fn infer_proto(buf: &[u8]) -> PacketProtocol {
    match buf.first().map_or(0, |b| b >> 4) {
        4 => PacketProtocol::IPv4,
        6 => PacketProtocol::IPv6,
        p => PacketProtocol::Other(p),
    }
}
impl TunPacket {
/// Create a new `TunPacket` based on a byte slice.
pub fn new(buffer: TunPacketBuffer) -> TunPacket {
let proto = infer_proto(buffer.as_ref());
TunPacket(proto, buffer)
}
/// Return this packet's bytes.
pub fn get_bytes(&self) -> &[u8] {
match &self.1 {
TunPacketBuffer::Bytes(bytes) => bytes.as_ref(),
TunPacketBuffer::BytesMut(bytes) => bytes.as_ref(),
}
}
pub fn into_bytes(self) -> Bytes {
match self.1 {
TunPacketBuffer::Bytes(bytes) => bytes,
TunPacketBuffer::BytesMut(bytes) => bytes.freeze(),
}
}
pub fn into_bytes_mut(self) -> BytesMut {
match self.1 {
TunPacketBuffer::Bytes(_) => panic!("cannot into_bytes_mut from bytes"),
TunPacketBuffer::BytesMut(bytes) => bytes,
}
}
}
/// A TunPacket Encoder/Decoder.
///
/// Field 0: whether the device prepends a 4-byte packet-information
/// header. Field 1: the device MTU, used to size read buffers.
pub struct TunPacketCodec(bool, i32);
impl TunPacketCodec {
    /// Create a new `TunPacketCodec` specifying whether the underlying
    /// tunnel Device has enabled the packet information header.
    pub fn new(pi: bool, mtu: i32) -> TunPacketCodec {
        Self(pi, mtu)
    }
}
impl Decoder for TunPacketCodec {
    type Item = TunPacket;
    type Error = io::Error;
    /// Split one packet out of `buf` (the TUN device delivers exactly one
    /// packet per read, so the whole buffer is taken).
    ///
    /// Returns `Ok(None)` on an empty buffer. When the packet-information
    /// header is enabled, a packet shorter than the 4-byte header yields
    /// an `InvalidData` error instead of panicking in `split_to`.
    fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
        if buf.is_empty() {
            return Ok(None);
        }
        let mut pkt = buf.split_to(buf.len());
        // reserve enough space for the next packet
        if self.0 {
            buf.reserve(self.1 as usize + 4);
        } else {
            buf.reserve(self.1 as usize);
        }
        // if the packet information is enabled we have to ignore the first 4 bytes
        if self.0 {
            if pkt.len() < 4 {
                return Err(io::Error::new(
                    io::ErrorKind::InvalidData,
                    "packet shorter than packet information header",
                ));
            }
            let _ = pkt.split_to(4);
        }
        let proto = infer_proto(pkt.as_ref());
        Ok(Some(TunPacket(proto, TunPacketBuffer::BytesMut(pkt))))
    }
}
impl Encoder<TunPacket> for TunPacketCodec {
    type Error = io::Error;
    /// Append `item` to `dst`, prepending the 4-byte packet-information
    /// header (flags + protocol) only when the device expects one.
    fn encode(&mut self, item: TunPacket, dst: &mut BytesMut) -> Result<(), Self::Error> {
        // reserve header space only when a PI header will be written
        let header_len = if self.0 { 4 } else { 0 };
        dst.reserve(item.get_bytes().len() + header_len);
        match item {
            TunPacket(proto, bytes) if self.0 => {
                // build the packet information header comprising of 2 u16
                // fields: flags and protocol.
                let mut buf = Vec::<u8>::with_capacity(4);
                // flags is always 0
                buf.write_u16::<NativeEndian>(0)?;
                // write the protocol as network byte order
                buf.write_u16::<NetworkEndian>(proto.into_pi_field()?)?;
                dst.put_slice(&buf);
                dst.put(Bytes::from(bytes));
            }
            TunPacket(_, bytes) => dst.put(Bytes::from(bytes)),
        }
        Ok(())
    }
}
+196
View File
@@ -0,0 +1,196 @@
use std::{net::Ipv4Addr, pin::Pin};
use crate::{
common::{
error::Result,
global_ctx::ArcGlobalCtx,
ifcfg::{IfConfiger, IfConfiguerTrait},
},
tunnels::{
codec::BytesCodec, common::FramedTunnel, DatagramSink, DatagramStream, Tunnel, TunnelError,
},
};
use futures::{SinkExt, StreamExt};
use tokio_util::{bytes::Bytes, codec::Framed};
use tun::Device;
use super::tun_codec::{TunPacket, TunPacketCodec};
// An L3 TUN virtual NIC exposed to the rest of the stack as a `Tunnel`.
pub struct VirtualNic {
    dev_name: String,
    queue_num: usize,
    global_ctx: ArcGlobalCtx,
    // set by create_dev(); `ifname()` unwraps it and panics before that
    ifname: Option<String>,
    tun: Option<Box<dyn Tunnel>>,
    ifcfg: Box<dyn IfConfiguerTrait + Send + Sync + 'static>,
}
impl VirtualNic {
    /// Create a not-yet-materialized NIC; call `create_dev` to actually
    /// allocate the TUN device.
    pub fn new(global_ctx: ArcGlobalCtx) -> Self {
        Self {
            dev_name: "".to_owned(),
            queue_num: 1,
            global_ctx,
            ifname: None,
            tun: None,
            ifcfg: Box::new(IfConfiger {}),
        }
    }
    /// Builder-style setter for the desired device name.
    /// NOTE(review): `dev_name` is stored but not applied in
    /// `create_dev_ret_err` below — confirm whether this is intentional.
    pub fn set_dev_name(mut self, dev_name: &str) -> Result<Self> {
        self.dev_name = dev_name.to_owned();
        Ok(self)
    }
    /// Builder-style setter for the queue count (only 1 is supported).
    pub fn set_queue_num(mut self, queue_num: usize) -> Result<Self> {
        self.queue_num = queue_num;
        Ok(self)
    }
    /// Create the TUN device inside the configured netns and wrap it in a
    /// framed `Tunnel`, storing the interface name and tunnel on success.
    async fn create_dev_ret_err(&mut self) -> Result<()> {
        let mut config = tun::Configuration::default();
        // macOS TUN always prepends a 4-byte packet-information header
        let has_packet_info = cfg!(target_os = "macos");
        config.layer(tun::Layer::L3);
        #[cfg(target_os = "linux")]
        {
            config.platform(|config| {
                // detect protocol by ourselves for cross platform
                config.packet_information(false);
            });
        }
        if self.queue_num != 1 {
            todo!("queue_num != 1")
        }
        config.queues(self.queue_num);
        config.up();
        let dev = {
            // device must be created inside the target network namespace
            let _g = self.global_ctx.net_ns.guard();
            tun::create_as_async(&config)?
        };
        let ifname = dev.get_ref().name()?;
        self.ifcfg.wait_interface_show(ifname.as_str()).await?;
        let ft: Box<dyn Tunnel> = if has_packet_info {
            // strip/add the PI header via TunPacketCodec
            let framed = Framed::new(dev, TunPacketCodec::new(true, 2500));
            let (sink, stream) = framed.split();
            let new_stream = stream.map(|item| match item {
                Ok(item) => Ok(item.into_bytes_mut()),
                Err(err) => {
                    println!("tun stream error: {:?}", err);
                    Err(TunnelError::TunError(err.to_string()))
                }
            });
            let new_sink = Box::pin(sink.with(|item: Bytes| async move {
                // NOTE(review): this dead `if false` branch appears to exist
                // only to pin the closure's error type to TunnelError for
                // inference — confirm before removing it.
                if false {
                    return Err(TunnelError::TunError("tun sink error".to_owned()));
                }
                Ok(TunPacket::new(super::tun_codec::TunPacketBuffer::Bytes(
                    item,
                )))
            }));
            Box::new(FramedTunnel::new(new_stream, new_sink, None))
        } else {
            // no PI header: raw length-delimited bytes are enough
            let framed = Framed::new(dev, BytesCodec::new(2500));
            let (sink, stream) = framed.split();
            Box::new(FramedTunnel::new(stream, sink, None))
        };
        self.ifname = Some(ifname.to_owned());
        self.tun = Some(ft);
        Ok(())
    }
    /// Consume self, create the device, and return the ready NIC.
    pub async fn create_dev(mut self) -> Result<Self> {
        self.create_dev_ret_err().await?;
        Ok(self)
    }
    /// Name of the created interface. Panics if called before `create_dev`.
    pub fn ifname(&self) -> &str {
        self.ifname.as_ref().unwrap().as_str()
    }
    /// Bring the interface up (inside the configured netns).
    pub async fn link_up(&self) -> Result<()> {
        let _g = self.global_ctx.net_ns.guard();
        self.ifcfg.set_link_status(self.ifname(), true).await?;
        Ok(())
    }
    /// Add an IPv4 route for `address/cidr` via this interface.
    pub async fn add_route(&self, address: Ipv4Addr, cidr: u8) -> Result<()> {
        let _g = self.global_ctx.net_ns.guard();
        self.ifcfg
            .add_ipv4_route(self.ifname(), address, cidr)
            .await?;
        Ok(())
    }
    /// Remove `ip` from the interface; `None` removes all addresses.
    pub async fn remove_ip(&self, ip: Option<Ipv4Addr>) -> Result<()> {
        let _g = self.global_ctx.net_ns.guard();
        self.ifcfg.remove_ip(self.ifname(), ip).await?;
        Ok(())
    }
    /// Assign `ip/cidr` to the interface.
    pub async fn add_ip(&self, ip: Ipv4Addr, cidr: i32) -> Result<()> {
        let _g = self.global_ctx.net_ns.guard();
        self.ifcfg
            .add_ipv4_ip(self.ifname(), ip, cidr as u8)
            .await?;
        Ok(())
    }
    /// Stream of packets read from the device. Panics before `create_dev`.
    pub fn pin_recv_stream(&self) -> Pin<Box<dyn DatagramStream>> {
        self.tun.as_ref().unwrap().pin_stream()
    }
    /// Sink for packets to write to the device. Panics before `create_dev`.
    pub fn pin_send_stream(&self) -> Pin<Box<dyn DatagramSink>> {
        self.tun.as_ref().unwrap().pin_sink()
    }
    /// Access the interface-configuration backend.
    pub fn get_ifcfg(&self) -> &dyn IfConfiguerTrait {
        self.ifcfg.as_ref()
    }
}
#[cfg(test)]
mod tests {
    use crate::common::{error::Error, global_ctx::tests::get_mock_global_ctx};
    use super::VirtualNic;
    // Creates a real TUN device, brings it up and assigns an address;
    // requires sufficient privileges (CAP_NET_ADMIN) to pass.
    async fn run_test_helper() -> Result<VirtualNic, Error> {
        let dev = VirtualNic::new(get_mock_global_ctx()).create_dev().await?;
        tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
        dev.link_up().await?;
        dev.remove_ip(None).await?;
        dev.add_ip("10.144.111.1".parse().unwrap(), 24).await?;
        Ok(dev)
    }
    #[tokio::test]
    async fn tun_test() {
        let _dev = run_test_helper().await.unwrap();
        // let mut stream = nic.pin_recv_stream();
        // while let Some(item) = stream.next().await {
        //     println!("item: {:?}", item);
        // }
        // let framed = dev.into_framed();
        // let (mut s, mut b) = framed.split();
        // loop {
        //     let tmp = b.next().await.unwrap().unwrap();
        //     let tmp = EthernetPacket::new(tmp.get_bytes());
        //     println!("ret: {:?}", tmp.unwrap());
        // }
    }
}
+382
View File
@@ -0,0 +1,382 @@
use std::{
collections::hash_map::DefaultHasher,
hash::{Hash, Hasher},
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
time::{Duration, SystemTime},
};
use crossbeam::atomic::AtomicCell;
use futures::Future;
use tokio::{
sync::{Mutex, RwLock},
task::JoinSet,
};
use tracing::Instrument;
use crate::{
common::PeerId,
peers::{peer_manager::PeerManager, rpc_service::PeerManagerRpcService},
rpc::{GetGlobalPeerMapRequest, GetGlobalPeerMapResponse},
};
use super::{
server::PeerCenterServer,
service::{GlobalPeerMap, PeerCenterService, PeerCenterServiceClient, PeerInfoForGlobalMap},
Digest, Error,
};
// Shared plumbing for talking to the elected center peer.
struct PeerCenterBase {
    peer_mgr: Arc<PeerManager>,
    tasks: Arc<Mutex<JoinSet<()>>>,
    // serializes rpc rounds issued by the periodic jobs
    lock: Arc<Mutex<()>>,
}
// rpc service id of the peer-center service on the peer rpc manager
static SERVICE_ID: u32 = 5;
// Context handed to each periodic job closure.
// NOTE(review): "Peridic" is a typo for "Periodic"; kept because the name
// is referenced by init_periodic_job's signature.
struct PeridicJobCtx<T> {
    peer_mgr: Arc<PeerManager>,
    // center peer the job is currently targeting
    center_peer: AtomicCell<PeerId>,
    job_ctx: T,
}
impl PeerCenterBase {
    /// Register the peer-center rpc server on the local rpc manager so
    /// this node can itself act as the center when elected.
    pub async fn init(&self) -> Result<(), Error> {
        self.peer_mgr.get_peer_rpc_mgr().run_service(
            SERVICE_ID,
            PeerCenterServer::new(self.peer_mgr.my_peer_id()).serve(),
        );
        Ok(())
    }
    /// Elect the center peer: the smallest peer id among all known routes
    /// and ourselves. Returns `None` while no route is known yet.
    async fn select_center_peer(peer_mgr: &Arc<PeerManager>) -> Option<PeerId> {
        let peers = peer_mgr.list_routes().await;
        if peers.is_empty() {
            return None;
        }
        // find peer with alphabetical smallest id.
        let mut min_peer = peer_mgr.my_peer_id();
        for peer in peers.iter() {
            let peer_id = peer.peer_id;
            if peer_id < min_peer {
                min_peer = peer_id;
            }
        }
        Some(min_peer)
    }
    /// Spawn a background task that repeatedly runs `job_fn` against the
    /// current center peer. The job returns the number of milliseconds to
    /// sleep before the next round; rpc failures back off for 3 seconds.
    async fn init_periodic_job<
        T: Send + Sync + 'static + Clone,
        Fut: Future<Output = Result<u32, tarpc::client::RpcError>> + Send + 'static,
    >(
        &self,
        job_ctx: T,
        job_fn: impl Fn(PeerCenterServiceClient, Arc<PeridicJobCtx<T>>) -> Fut + Send + Sync + 'static,
    ) {
        let my_peer_id = self.peer_mgr.my_peer_id();
        let peer_mgr = self.peer_mgr.clone();
        let lock = self.lock.clone();
        self.tasks.lock().await.spawn(
            async move {
                let ctx = Arc::new(PeridicJobCtx {
                    peer_mgr: peer_mgr.clone(),
                    center_peer: AtomicCell::new(PeerId::default()),
                    job_ctx,
                });
                loop {
                    let Some(center_peer) = Self::select_center_peer(&peer_mgr).await else {
                        tracing::trace!("no center peer found, sleep 1 second");
                        tokio::time::sleep(Duration::from_secs(1)).await;
                        continue;
                    };
                    // PeerId is Copy (it lives in an AtomicCell), no clone needed
                    ctx.center_peer.store(center_peer);
                    tracing::trace!(?center_peer, "run periodic job");
                    let rpc_mgr = peer_mgr.get_peer_rpc_mgr();
                    // only one periodic job talks to the center at a time
                    let _g = lock.lock().await;
                    let ret = rpc_mgr
                        .do_client_rpc_scoped(SERVICE_ID, center_peer, |c| async {
                            let client =
                                PeerCenterServiceClient::new(tarpc::client::Config::default(), c)
                                    .spawn();
                            job_fn(client, ctx.clone()).await
                        })
                        .await;
                    drop(_g);
                    let Ok(sleep_time_ms) = ret else {
                        tracing::error!("periodic job to center server rpc failed: {:?}", ret);
                        tokio::time::sleep(Duration::from_secs(3)).await;
                        continue;
                    };
                    if sleep_time_ms > 0 {
                        tokio::time::sleep(Duration::from_millis(sleep_time_ms as u64)).await;
                    }
                }
            }
            .instrument(tracing::info_span!("periodic_job", ?my_peer_id)),
        );
    }
    /// Create the base with empty task set and job lock.
    pub fn new(peer_mgr: Arc<PeerManager>) -> Self {
        PeerCenterBase {
            peer_mgr,
            tasks: Arc::new(Mutex::new(JoinSet::new())),
            lock: Arc::new(Mutex::new(())),
        }
    }
}
// gRPC-facing view over the locally cached global peer map.
pub struct PeerCenterInstanceService {
    global_peer_map: Arc<RwLock<GlobalPeerMap>>,
    global_peer_map_digest: Arc<RwLock<Digest>>,
}
#[tonic::async_trait]
impl crate::rpc::cli::peer_center_rpc_server::PeerCenterRpc for PeerCenterInstanceService {
    /// Return a snapshot of the locally cached global peer map.
    async fn get_global_peer_map(
        &self,
        _request: tonic::Request<GetGlobalPeerMapRequest>,
    ) -> Result<tonic::Response<GetGlobalPeerMapResponse>, tonic::Status> {
        let global_peer_map = self.global_peer_map.read().await.clone();
        Ok(tonic::Response::new(GetGlobalPeerMapResponse {
            // convert the BTreeMap into the response's map type
            global_peer_map: global_peer_map.map.into_iter().collect(),
        }))
    }
}
// Per-node peer-center participant: runs the periodic report/fetch jobs
// and caches the global peer map received from the center.
pub struct PeerCenterInstance {
    peer_mgr: Arc<PeerManager>,
    client: Arc<PeerCenterBase>,
    global_peer_map: Arc<RwLock<GlobalPeerMap>>,
    global_peer_map_digest: Arc<RwLock<Digest>>,
}
impl PeerCenterInstance {
    /// Build an instance with an empty cached map and digest.
    pub fn new(peer_mgr: Arc<PeerManager>) -> Self {
        PeerCenterInstance {
            peer_mgr: peer_mgr.clone(),
            client: Arc::new(PeerCenterBase::new(peer_mgr.clone())),
            global_peer_map: Arc::new(RwLock::new(GlobalPeerMap::new())),
            global_peer_map_digest: Arc::new(RwLock::new(Digest::default())),
        }
    }
    /// Register the rpc service and start both periodic jobs.
    pub async fn init(&self) {
        self.client.init().await.unwrap();
        self.init_get_global_info_job().await;
        self.init_report_peers_job().await;
    }
    /// Periodic job: pull the global peer map from the center when our
    /// cached digest is stale. Sleep hints (ms): 1000 on error result,
    /// 5000 when already up to date, 10000 after a successful refresh.
    async fn init_get_global_info_job(&self) {
        struct Ctx {
            global_peer_map: Arc<RwLock<GlobalPeerMap>>,
            global_peer_map_digest: Arc<RwLock<Digest>>,
        }
        let ctx = Arc::new(Ctx {
            global_peer_map: self.global_peer_map.clone(),
            global_peer_map_digest: self.global_peer_map_digest.clone(),
        });
        self.client
            .init_periodic_job(ctx, |client, ctx| async move {
                let mut rpc_ctx = tarpc::context::current();
                rpc_ctx.deadline = SystemTime::now() + Duration::from_secs(3);
                let ret = client
                    .get_global_peer_map(
                        rpc_ctx,
                        ctx.job_ctx.global_peer_map_digest.read().await.clone(),
                    )
                    .await?;
                let Ok(resp) = ret else {
                    tracing::error!(
                        "get global info from center server got error result: {:?}",
                        ret
                    );
                    return Ok(1000);
                };
                // None means the center's digest matched ours: nothing to do
                let Some(resp) = resp else {
                    return Ok(5000);
                };
                tracing::info!(
                    "get global info from center server: {:?}, digest: {:?}",
                    resp.global_peer_map,
                    resp.digest
                );
                *ctx.job_ctx.global_peer_map.write().await = resp.global_peer_map;
                *ctx.job_ctx.global_peer_map_digest.write().await = resp.digest;
                Ok(10000)
            })
            .await;
    }
    /// Periodic job: report our directly connected peers to the center.
    /// Sends the full peer list only when `need_send_peers` is set (after
    /// a digest mismatch) — otherwise just the digest.
    async fn init_report_peers_job(&self) {
        struct Ctx {
            service: PeerManagerRpcService,
            need_send_peers: AtomicBool,
            last_report_peers: Mutex<PeerInfoForGlobalMap>,
            last_center_peer: AtomicCell<PeerId>,
        }
        let ctx = Arc::new(Ctx {
            service: PeerManagerRpcService::new(self.peer_mgr.clone()),
            need_send_peers: AtomicBool::new(true),
            last_report_peers: Mutex::new(PeerInfoForGlobalMap::default()),
            last_center_peer: AtomicCell::new(PeerId::default()),
        });
        self.client
            .init_periodic_job(ctx, |client, ctx| async move {
                let my_node_id = ctx.peer_mgr.my_peer_id();
                // if peers are not same in next 10 seconds, report peers to center server
                let mut peers = PeerInfoForGlobalMap::default();
                for _ in 1..10 {
                    peers = ctx.job_ctx.service.list_peers().await.into();
                    if ctx.center_peer.load() != ctx.job_ctx.last_center_peer.load() {
                        // if center peer changed, report peers immediately
                        break;
                    }
                    if peers == *ctx.job_ctx.last_report_peers.lock().await {
                        return Ok(3000);
                    }
                    tokio::time::sleep(Duration::from_secs(2)).await;
                }
                *ctx.job_ctx.last_report_peers.lock().await = peers.clone();
                // digest covers the peer map so the center can detect drift
                let mut hasher = DefaultHasher::new();
                peers.hash(&mut hasher);
                let peers = if ctx.job_ctx.need_send_peers.load(Ordering::Relaxed) {
                    Some(peers)
                } else {
                    None
                };
                let mut rpc_ctx = tarpc::context::current();
                rpc_ctx.deadline = SystemTime::now() + Duration::from_secs(3);
                let ret = client
                    .report_peers(
                        rpc_ctx,
                        my_node_id.clone(),
                        peers,
                        hasher.finish() as Digest,
                    )
                    .await?;
                if matches!(ret.as_ref().err(), Some(Error::DigestMismatch)) {
                    // center lacks our map: resend the full list right away
                    ctx.job_ctx.need_send_peers.store(true, Ordering::Relaxed);
                    return Ok(0);
                } else if ret.is_err() {
                    tracing::error!("report peers to center server got error result: {:?}", ret);
                    return Ok(500);
                }
                ctx.job_ctx.last_center_peer.store(ctx.center_peer.load());
                ctx.job_ctx.need_send_peers.store(false, Ordering::Relaxed);
                Ok(3000)
            })
            .await;
    }
    /// Build the gRPC service view sharing this instance's cached map.
    pub fn get_rpc_service(&self) -> PeerCenterInstanceService {
        PeerCenterInstanceService {
            global_peer_map: self.global_peer_map.clone(),
            global_peer_map_digest: self.global_peer_map_digest.clone(),
        }
    }
}
#[cfg(test)]
mod tests {
    use std::ops::Deref;
    use crate::{
        peer_center::server::get_global_data,
        peers::tests::{connect_peer_manager, create_mock_peer_manager, wait_route_appear},
    };
    use super::*;
    // End-to-end: three linked peers must converge on one center whose
    // global map contains all three, and every instance must agree on
    // the same digest.
    #[tokio::test]
    async fn test_peer_center_instance() {
        let peer_mgr_a = create_mock_peer_manager().await;
        let peer_mgr_b = create_mock_peer_manager().await;
        let peer_mgr_c = create_mock_peer_manager().await;
        let peer_center_a = PeerCenterInstance::new(peer_mgr_a.clone());
        let peer_center_b = PeerCenterInstance::new(peer_mgr_b.clone());
        let peer_center_c = PeerCenterInstance::new(peer_mgr_c.clone());
        let peer_centers = vec![&peer_center_a, &peer_center_b, &peer_center_c];
        for pc in peer_centers.iter() {
            pc.init().await;
        }
        // chain topology: a - b - c
        connect_peer_manager(peer_mgr_a.clone(), peer_mgr_b.clone()).await;
        connect_peer_manager(peer_mgr_b.clone(), peer_mgr_c.clone()).await;
        wait_route_appear(peer_mgr_a.clone(), peer_mgr_c.clone())
            .await
            .unwrap();
        let center_peer = PeerCenterBase::select_center_peer(&peer_mgr_a)
            .await
            .unwrap();
        let center_data = get_global_data(center_peer);
        // wait center_data has 3 records for 10 seconds
        let now = std::time::Instant::now();
        loop {
            if center_data.read().await.global_peer_map.map.len() == 3 {
                println!(
                    "center data ready, {:#?}",
                    center_data.read().await.global_peer_map
                );
                break;
            }
            if now.elapsed().as_secs() > 60 {
                panic!("center data not ready");
            }
            tokio::time::sleep(Duration::from_millis(100)).await;
        }
        let mut digest = None;
        for pc in peer_centers.iter() {
            let rpc_service = pc.get_rpc_service();
            let now = std::time::Instant::now();
            while now.elapsed().as_secs() < 10 {
                if rpc_service.global_peer_map.read().await.map.len() == 3 {
                    break;
                }
                tokio::time::sleep(Duration::from_millis(100)).await;
            }
            assert_eq!(rpc_service.global_peer_map.read().await.map.len(), 3);
            println!("rpc service ready, {:#?}", rpc_service.global_peer_map);
            // all instances must converge to the same digest
            if digest.is_none() {
                digest = Some(rpc_service.global_peer_map_digest.read().await.clone());
            } else {
                let v = rpc_service.global_peer_map_digest.read().await;
                assert_eq!(digest.as_ref().unwrap(), v.deref());
            }
        }
        let global_digest = get_global_data(center_peer).read().await.digest.clone();
        assert_eq!(digest.as_ref().unwrap(), &global_digest);
    }
}
+20
View File
@@ -0,0 +1,20 @@
// peer_center is used to collect peer info into one peer node.
// the center node is selected with the following rules:
// 1. has smallest peer id
// 2. TODO: has allow_to_be_center peer feature
// peer center is not guaranteed to be stable and can be changed when peer enter or leave.
// it's used to reduce the cost of exchanging info between peers.
pub mod instance;
mod server;
mod service;
/// Errors produced by the peer-center rpc service.
#[derive(thiserror::Error, Debug, serde::Deserialize, serde::Serialize)]
pub enum Error {
    #[error("Digest not match, need provide full peer info to center server.")]
    DigestMismatch,
    #[error("Not center server")]
    NotCenterServer,
}
// Hash of the global peer map; lets clients/servers skip transfers when
// nothing changed.
pub type Digest = u64;
+152
View File
@@ -0,0 +1,152 @@
use std::{
hash::{Hash, Hasher},
sync::Arc,
};
use dashmap::DashMap;
use once_cell::sync::Lazy;
use tokio::{sync::RwLock, task::JoinSet};
use crate::common::PeerId;
use super::{
service::{GetGlobalPeerMapResponse, GlobalPeerMap, PeerCenterService, PeerInfoForGlobalMap},
Digest, Error,
};
// Per-center-node aggregated state: the merged global peer map plus the
// bookkeeping needed to expire stale reporters.
pub(crate) struct PeerCenterServerGlobalData {
    pub global_peer_map: GlobalPeerMap,
    pub digest: Digest,
    pub update_time: std::time::Instant,
    // last time each reporting peer was heard from; drives eviction
    pub peer_update_time: DashMap<PeerId, std::time::Instant>,
}
impl PeerCenterServerGlobalData {
    // Fresh, empty state stamped with the current time.
    fn new() -> Self {
        PeerCenterServerGlobalData {
            global_peer_map: GlobalPeerMap::new(),
            digest: Digest::default(),
            update_time: std::time::Instant::now(),
            peer_update_time: DashMap::new(),
        }
    }
}
// a global unique instance for PeerCenterServer
pub(crate) static GLOBAL_DATA: Lazy<DashMap<PeerId, Arc<RwLock<PeerCenterServerGlobalData>>>> =
    Lazy::new(DashMap::new);
// Fetch (or lazily create) the global-data slot for `node_id`.
pub(crate) fn get_global_data(node_id: PeerId) -> Arc<RwLock<PeerCenterServerGlobalData>> {
    GLOBAL_DATA
        .entry(node_id)
        .or_insert_with(|| Arc::new(RwLock::new(PeerCenterServerGlobalData::new())))
        .value()
        .clone()
}
// The rpc server each node runs so it can act as center when elected.
#[derive(Clone, Debug)]
pub struct PeerCenterServer {
    // every peer has its own server, so use per-struct dash map is ok.
    my_node_id: PeerId,
    // last digest received from each reporting peer
    digest_map: DashMap<PeerId, Digest>,
    // keeps the cleanup task alive; aborted when the last clone drops
    tasks: Arc<JoinSet<()>>,
}
impl PeerCenterServer {
    /// Create a per-node server and spawn a background task that evicts
    /// peers that stopped reporting.
    pub fn new(my_node_id: PeerId) -> Self {
        let mut cleaner = JoinSet::new();
        cleaner.spawn(async move {
            loop {
                tokio::time::sleep(std::time::Duration::from_secs(10)).await;
                PeerCenterServer::clean_outdated_peer(my_node_id).await;
            }
        });
        PeerCenterServer {
            my_node_id,
            digest_map: DashMap::new(),
            tasks: Arc::new(cleaner),
        }
    }
    /// Remove peers whose last report is older than 20 seconds from both
    /// the global map and the update-time table.
    async fn clean_outdated_peer(my_node_id: PeerId) {
        let data = get_global_data(my_node_id);
        let mut locked_data = data.write().await;
        let now = std::time::Instant::now();
        let expired: Vec<PeerId> = locked_data
            .peer_update_time
            .iter()
            .filter(|kv| now.duration_since(*kv.value()).as_secs() > 20)
            .map(|kv| *kv.key())
            .collect();
        for peer_id in expired {
            locked_data.global_peer_map.map.remove(&peer_id);
            locked_data.peer_update_time.remove(&peer_id);
        }
    }
}
#[tarpc::server]
impl PeerCenterService for PeerCenterServer {
    /// Merge one peer's directly-connected-peer report into the global map.
    ///
    /// `peers` may be `None` when the reporter believes the server already
    /// holds a copy matching `digest`; a mismatch is answered with
    /// `Error::DigestMismatch` so the reporter resends the full map.
    #[tracing::instrument()]
    async fn report_peers(
        self,
        _: tarpc::context::Context,
        my_peer_id: PeerId,
        peers: Option<PeerInfoForGlobalMap>,
        digest: Digest,
    ) -> Result<(), Error> {
        tracing::trace!("receive report_peers");
        let data = get_global_data(self.my_node_id);
        let mut locked_data = data.write().await;
        locked_data
            .peer_update_time
            .insert(my_peer_id, std::time::Instant::now());
        // Keep the dashmap read guard inside its own scope: holding it
        // across the `insert` below can deadlock on the same shard.
        if let Some(old_digest) = self.digest_map.get(&my_peer_id) {
            // if digest match, no need to update
            if *old_digest == digest {
                return Ok(());
            }
        }
        let Some(peers) = peers else {
            return Err(Error::DigestMismatch);
        };
        self.digest_map.insert(my_peer_id, digest);
        locked_data.global_peer_map.map.insert(my_peer_id, peers);
        // recompute the digest of the merged map
        let mut hasher = std::collections::hash_map::DefaultHasher::new();
        locked_data.global_peer_map.map.hash(&mut hasher);
        locked_data.digest = hasher.finish() as Digest;
        locked_data.update_time = std::time::Instant::now();
        Ok(())
    }
    /// Return the global peer map, or `None` when the caller's digest is
    /// already current. Digest check and snapshot share a single read lock
    /// so the returned map always matches the compared digest.
    async fn get_global_peer_map(
        self,
        _: tarpc::context::Context,
        digest: Digest,
    ) -> Result<Option<GetGlobalPeerMapResponse>, Error> {
        let data = get_global_data(self.my_node_id);
        let locked_data = data.read().await;
        if digest == locked_data.digest {
            return Ok(None);
        }
        Ok(Some(GetGlobalPeerMapResponse {
            global_peer_map: locked_data.global_peer_map.clone(),
            digest: locked_data.digest,
        }))
    }
}
+84
View File
@@ -0,0 +1,84 @@
use std::collections::BTreeMap;
use crate::{common::PeerId, rpc::DirectConnectedPeerInfo};
use super::{Digest, Error};
use crate::rpc::PeerInfo;
pub type LatencyLevel = crate::rpc::cli::LatencyLevel;
impl LatencyLevel {
    /// Bucket a latency measurement (in milliseconds) into a coarse level:
    /// <10 VeryLow, <50 Low, <100 Normal, <200 High, else VeryHigh.
    pub const fn from_latency_ms(lat_ms: u32) -> Self {
        match lat_ms {
            0..=9 => LatencyLevel::VeryLow,
            10..=49 => LatencyLevel::Low,
            50..=99 => LatencyLevel::Normal,
            100..=199 => LatencyLevel::High,
            _ => LatencyLevel::VeryHigh,
        }
    }
}
pub type PeerInfoForGlobalMap = crate::rpc::cli::PeerInfoForGlobalMap;
impl From<Vec<PeerInfo>> for PeerInfoForGlobalMap {
    /// Collapse a full peer list into the compact form the center stores:
    /// one latency level per directly connected peer.
    fn from(peers: Vec<PeerInfo>) -> Self {
        // BTreeMap keeps entries ordered so the digest hash is stable
        let direct_peers = peers
            .into_iter()
            .map(|peer| {
                // NOTE(review): assumes every conn carries stats — confirm upstream
                let min_lat = peer
                    .conns
                    .iter()
                    .map(|conn| conn.stats.as_ref().unwrap().latency_us)
                    .min()
                    .unwrap_or(0);
                let dp_info = DirectConnectedPeerInfo {
                    latency_level: LatencyLevel::from_latency_ms(min_lat as u32 / 1000) as i32,
                };
                (peer.peer_id, dp_info)
            })
            .collect();
        PeerInfoForGlobalMap {
            direct_peers,
        }
    }
}
// a global peer topology map, peers can use it to find optimal path to other peers
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct GlobalPeerMap {
    // per-peer view of its directly connected neighbors
    pub map: BTreeMap<PeerId, PeerInfoForGlobalMap>,
}
impl GlobalPeerMap {
    // Create an empty map.
    pub fn new() -> Self {
        GlobalPeerMap {
            map: BTreeMap::new(),
        }
    }
}
// Response carrying the whole map plus the digest it hashes to.
#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
pub struct GetGlobalPeerMapResponse {
    pub global_peer_map: GlobalPeerMap,
    pub digest: Digest,
}
/// RPC contract between a node and the elected center peer.
#[tarpc::service]
pub trait PeerCenterService {
    // report center server which peer is directly connected to me
    // digest is a hash of current peer map, if digest not match, we need to transfer the whole map
    async fn report_peers(
        my_peer_id: PeerId,
        peers: Option<PeerInfoForGlobalMap>,
        digest: Digest,
    ) -> Result<(), Error>;
    // returns None when `digest` already matches the server's copy
    async fn get_global_peer_map(digest: Digest)
        -> Result<Option<GetGlobalPeerMapResponse>, Error>;
}
@@ -0,0 +1,200 @@
use std::{
sync::Arc,
time::{Duration, SystemTime},
};
use dashmap::DashMap;
use tokio::{
sync::{mpsc, Mutex},
task::JoinSet,
};
use tokio_util::bytes::Bytes;
use crate::common::{
error::Error,
global_ctx::{ArcGlobalCtx, NetworkIdentity},
PeerId,
};
use super::{
foreign_network_manager::{ForeignNetworkServiceClient, FOREIGN_NETWORK_SERVICE_ID},
peer_conn::PeerConn,
peer_map::PeerMap,
peer_rpc::PeerRpcManager,
};
// Client side of foreign-network forwarding: tracks which directly
// connected gateway can reach each peer of our network through a
// foreign network.
pub struct ForeignNetworkClient {
    global_ctx: ArcGlobalCtx,
    peer_rpc: Arc<PeerRpcManager>,
    my_peer_id: PeerId,
    peer_map: Arc<PeerMap>,
    // dst peer id -> directly connected gateway peer id
    next_hop: Arc<DashMap<PeerId, PeerId>>,
    tasks: Mutex<JoinSet<()>>,
}
impl ForeignNetworkClient {
pub fn new(
global_ctx: ArcGlobalCtx,
packet_sender_to_mgr: mpsc::Sender<Bytes>,
peer_rpc: Arc<PeerRpcManager>,
my_peer_id: PeerId,
) -> Self {
let peer_map = Arc::new(PeerMap::new(
packet_sender_to_mgr,
global_ctx.clone(),
my_peer_id,
));
let next_hop = Arc::new(DashMap::new());
Self {
global_ctx,
peer_rpc,
my_peer_id,
peer_map,
next_hop,
tasks: Mutex::new(JoinSet::new()),
}
}
pub async fn add_new_peer_conn(&self, peer_conn: PeerConn) {
tracing::warn!(peer_conn = ?peer_conn.get_conn_info(), network = ?peer_conn.get_network_identity(), "add new peer conn in foreign network client");
self.peer_map.add_new_peer_conn(peer_conn).await
}
async fn collect_next_hop_in_foreign_network_task(
network_identity: NetworkIdentity,
peer_map: Arc<PeerMap>,
peer_rpc: Arc<PeerRpcManager>,
next_hop: Arc<DashMap<PeerId, PeerId>>,
) {
loop {
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
peer_map.clean_peer_without_conn().await;
let new_next_hop = Self::collect_next_hop_in_foreign_network(
network_identity.clone(),
peer_map.clone(),
peer_rpc.clone(),
)
.await;
next_hop.clear();
for (k, v) in new_next_hop.into_iter() {
next_hop.insert(k, v);
}
}
}
async fn collect_next_hop_in_foreign_network(
network_identity: NetworkIdentity,
peer_map: Arc<PeerMap>,
peer_rpc: Arc<PeerRpcManager>,
) -> DashMap<PeerId, PeerId> {
let peers = peer_map.list_peers().await;
let mut tasks = JoinSet::new();
if !peers.is_empty() {
tracing::warn!(?peers, my_peer_id = ?peer_rpc.my_peer_id(), "collect next hop in foreign network");
}
for peer in peers {
let peer_rpc = peer_rpc.clone();
let network_identity = network_identity.clone();
tasks.spawn(async move {
let Ok(Some(peers_in_foreign)) = peer_rpc
.do_client_rpc_scoped(FOREIGN_NETWORK_SERVICE_ID, peer, |c| async {
let c =
ForeignNetworkServiceClient::new(tarpc::client::Config::default(), c)
.spawn();
let mut rpc_ctx = tarpc::context::current();
rpc_ctx.deadline = SystemTime::now() + Duration::from_secs(2);
let ret = c.list_network_peers(rpc_ctx, network_identity).await;
ret
})
.await
else {
return (peer, vec![]);
};
(peer, peers_in_foreign)
});
}
let new_next_hop = DashMap::new();
while let Some(join_ret) = tasks.join_next().await {
let Ok((gateway, peer_ids)) = join_ret else {
tracing::error!(?join_ret, "collect next hop in foreign network failed");
continue;
};
for ret in peer_ids {
new_next_hop.insert(ret, gateway);
}
}
new_next_hop
}
pub fn has_next_hop(&self, peer_id: PeerId) -> bool {
self.get_next_hop(peer_id).is_some()
}
pub fn get_next_hop(&self, peer_id: PeerId) -> Option<PeerId> {
if self.peer_map.has_peer(peer_id) {
return Some(peer_id.clone());
}
self.next_hop.get(&peer_id).map(|v| v.clone())
}
pub async fn send_msg(&self, msg: Bytes, peer_id: PeerId) -> Result<(), Error> {
if let Some(next_hop) = self.get_next_hop(peer_id) {
let ret = self.peer_map.send_msg_directly(msg, next_hop).await;
if ret.is_err() {
tracing::error!(
?ret,
?peer_id,
?next_hop,
"foreign network client send msg failed"
);
}
return ret;
}
Err(Error::RouteError(Some("no next hop".to_string())))
}
pub fn list_foreign_peers(&self) -> Vec<PeerId> {
let mut peers = vec![];
for item in self.next_hop.iter() {
if item.key() != &self.my_peer_id {
peers.push(item.key().clone());
}
}
peers
}
pub async fn run(&self) {
self.tasks
.lock()
.await
.spawn(Self::collect_next_hop_in_foreign_network_task(
self.global_ctx.get_network_identity(),
self.peer_map.clone(),
self.peer_rpc.clone(),
self.next_hop.clone(),
));
}
pub fn get_next_hop_table(&self) -> DashMap<PeerId, PeerId> {
let next_hop = DashMap::new();
for item in self.next_hop.iter() {
next_hop.insert(item.key().clone(), item.value().clone());
}
next_hop
}
pub fn get_peer_map(&self) -> Arc<PeerMap> {
self.peer_map.clone()
}
}
@@ -0,0 +1,459 @@
/*
foreign_network_manager is used to forward packets of other networks. currently
it only forwards packets of peers that are directly connected to this node.
in the future, with the help of the peer center, we can forward packets of peers
that are connected to any node in the local network.
*/
use std::sync::Arc;
use dashmap::DashMap;
use tokio::{
sync::{
mpsc::{self, UnboundedReceiver, UnboundedSender},
Mutex,
},
task::JoinSet,
};
use tokio_util::bytes::Bytes;
use crate::common::{
error::Error,
global_ctx::{ArcGlobalCtx, GlobalCtxEvent, NetworkIdentity},
PeerId,
};
use super::{
packet::{self},
peer_conn::PeerConn,
peer_map::PeerMap,
peer_rpc::{PeerRpcManager, PeerRpcManagerTransport},
};
// State kept for one foreign network we are forwarding for.
struct ForeignNetworkEntry {
    network: NetworkIdentity,
    peer_map: Arc<PeerMap>,
}
impl ForeignNetworkEntry {
    // Build an entry with its own peer map feeding `packet_sender`.
    fn new(
        network: NetworkIdentity,
        packet_sender: mpsc::Sender<Bytes>,
        global_ctx: ArcGlobalCtx,
        my_peer_id: PeerId,
    ) -> Self {
        let peer_map = Arc::new(PeerMap::new(packet_sender, global_ctx, my_peer_id));
        Self { network, peer_map }
    }
}
// Indexes for all foreign networks being forwarded.
struct ForeignNetworkManagerData {
    // network name -> entry
    network_peer_maps: DashMap<String, Arc<ForeignNetworkEntry>>,
    // peer id -> name of the network it belongs to
    peer_network_map: DashMap<PeerId, String>,
}
impl ForeignNetworkManagerData {
    // Route `msg` to `dst_peer_id` through its network's peer map.
    async fn send_msg(&self, msg: Bytes, dst_peer_id: PeerId) -> Result<(), Error> {
        let network_name = self
            .peer_network_map
            .get(&dst_peer_id)
            .ok_or_else(|| Error::RouteError(Some("network not found".to_string())))?
            .clone();
        let entry = self
            .network_peer_maps
            .get(&network_name)
            .ok_or_else(|| Error::RouteError(Some("no peer in network".to_string())))?
            .clone();
        entry.peer_map.send_msg(msg, dst_peer_id).await
    }
    // Name of the network `peer_id` belongs to, if known.
    fn get_peer_network(&self, peer_id: PeerId) -> Option<String> {
        self.peer_network_map.get(&peer_id).map(|v| v.clone())
    }
    // Entry for `network_name`, if any peers of it are connected.
    fn get_network_entry(&self, network_name: &str) -> Option<Arc<ForeignNetworkEntry>> {
        self.network_peer_maps.get(network_name).map(|v| v.clone())
    }
    // Forget `peer_id` and drop networks whose peer map became empty.
    // NOTE(review): the peer is removed from peer_network_map only; the
    // entry's peer_map is expected to drop the peer via its own cleanup
    // before the retain below sees it empty — confirm that ordering.
    fn remove_peer(&self, peer_id: PeerId) {
        self.peer_network_map.remove(&peer_id);
        self.network_peer_maps.retain(|_, v| !v.peer_map.is_empty());
    }
    // Asynchronously prune connection-less peers in every network.
    fn clear_no_conn_peer(&self) {
        for item in self.network_peer_maps.iter() {
            let peer_map = item.value().peer_map.clone();
            tokio::spawn(async move {
                peer_map.clean_peer_without_conn().await;
            });
        }
    }
}
/// Transport that routes RPC packets of foreign-network peers through
/// `ForeignNetworkManagerData`.
struct RpcTransport {
    my_peer_id: PeerId,
    data: Arc<ForeignNetworkManagerData>,
    // inbound RPC packets, fed by the manager's packet-recv loop
    packet_recv: Mutex<UnboundedReceiver<Bytes>>,
}
#[async_trait::async_trait]
impl PeerRpcManagerTransport for RpcTransport {
fn my_peer_id(&self) -> PeerId {
self.my_peer_id
}
async fn send(&self, msg: Bytes, dst_peer_id: PeerId) -> Result<(), Error> {
self.data.send_msg(msg, dst_peer_id).await
}
async fn recv(&self) -> Result<Bytes, Error> {
if let Some(o) = self.packet_recv.lock().await.recv().await {
Ok(o)
} else {
Err(Error::Unknown)
}
}
}
pub const FOREIGN_NETWORK_SERVICE_ID: u32 = 1;
#[tarpc::service]
pub trait ForeignNetworkService {
    /// Lists the peers of the given foreign network that are connected to
    /// this node. NOTE: `network_identy` (sic) is part of the generated RPC
    /// interface and cannot be renamed without breaking callers.
    async fn list_network_peers(network_identy: NetworkIdentity) -> Option<Vec<PeerId>>;
}
#[tarpc::server]
impl ForeignNetworkService for Arc<ForeignNetworkManagerData> {
    /// Serves `list_network_peers`: `None` when the network is unknown.
    async fn list_network_peers(
        self,
        _: tarpc::context::Context,
        network_identy: NetworkIdentity,
    ) -> Option<Vec<PeerId>> {
        let entry = self.network_peer_maps.get(&network_identy.network_name)?;
        Some(entry.peer_map.list_peers().await)
    }
}
/// Forwards packets for peers that belong to other networks but are directly
/// connected to this node.
pub struct ForeignNetworkManager {
    my_peer_id: PeerId,
    global_ctx: ArcGlobalCtx,
    // packets addressed to this node are handed to the peer manager here
    packet_sender_to_mgr: mpsc::Sender<Bytes>,
    // sender cloned into every ForeignNetworkEntry's peer map
    packet_sender: mpsc::Sender<Bytes>,
    // taken exactly once by start_packet_recv
    packet_recv: Mutex<Option<mpsc::Receiver<Bytes>>>,
    data: Arc<ForeignNetworkManagerData>,
    rpc_mgr: Arc<PeerRpcManager>,
    // feeds inbound TaRpc packets to the RpcTransport
    rpc_transport_sender: UnboundedSender<Bytes>,
    tasks: Mutex<JoinSet<()>>,
}
impl ForeignNetworkManager {
pub fn new(
my_peer_id: PeerId,
global_ctx: ArcGlobalCtx,
packet_sender_to_mgr: mpsc::Sender<Bytes>,
) -> Self {
// recv packet from all foreign networks
let (packet_sender, packet_recv) = mpsc::channel(1000);
let data = Arc::new(ForeignNetworkManagerData {
network_peer_maps: DashMap::new(),
peer_network_map: DashMap::new(),
});
// handle rpc from foreign networks
let (rpc_transport_sender, peer_rpc_tspt_recv) = mpsc::unbounded_channel();
let rpc_mgr = Arc::new(PeerRpcManager::new(RpcTransport {
my_peer_id,
data: data.clone(),
packet_recv: Mutex::new(peer_rpc_tspt_recv),
}));
Self {
my_peer_id,
global_ctx,
packet_sender_to_mgr,
packet_sender,
packet_recv: Mutex::new(Some(packet_recv)),
data,
rpc_mgr,
rpc_transport_sender,
tasks: Mutex::new(JoinSet::new()),
}
}
pub async fn add_peer_conn(&self, peer_conn: PeerConn) -> Result<(), Error> {
tracing::info!(peer_conn = ?peer_conn.get_conn_info(), network = ?peer_conn.get_network_identity(), "add new peer conn in foreign network manager");
let entry = self
.data
.network_peer_maps
.entry(peer_conn.get_network_identity().network_name.clone())
.or_insert_with(|| {
Arc::new(ForeignNetworkEntry::new(
peer_conn.get_network_identity(),
self.packet_sender.clone(),
self.global_ctx.clone(),
self.my_peer_id,
))
})
.clone();
self.data.peer_network_map.insert(
peer_conn.get_peer_id(),
peer_conn.get_network_identity().network_name.clone(),
);
if entry.network.network_secret != peer_conn.get_network_identity().network_secret {
return Err(anyhow::anyhow!("network secret not match").into());
}
Ok(entry.peer_map.add_new_peer_conn(peer_conn).await)
}
async fn start_global_event_handler(&self) {
let data = self.data.clone();
let mut s = self.global_ctx.subscribe();
self.tasks.lock().await.spawn(async move {
while let Ok(e) = s.recv().await {
if let GlobalCtxEvent::PeerRemoved(peer_id) = &e {
tracing::info!(?e, "remove peer from foreign network manager");
data.remove_peer(*peer_id);
} else if let GlobalCtxEvent::PeerConnRemoved(..) = &e {
tracing::info!(?e, "clear no conn peer from foreign network manager");
data.clear_no_conn_peer();
}
}
});
}
async fn start_packet_recv(&self) {
let mut recv = self.packet_recv.lock().await.take().unwrap();
let sender_to_mgr = self.packet_sender_to_mgr.clone();
let my_node_id = self.my_peer_id;
let rpc_sender = self.rpc_transport_sender.clone();
let data = self.data.clone();
self.tasks.lock().await.spawn(async move {
while let Some(packet_bytes) = recv.recv().await {
let packet = packet::Packet::decode(&packet_bytes);
let from_peer_id = packet.from_peer.into();
let to_peer_id = packet.to_peer.into();
if to_peer_id == my_node_id {
if packet.packet_type == packet::PacketType::TaRpc {
rpc_sender.send(packet_bytes.clone()).unwrap();
continue;
}
if let Err(e) = sender_to_mgr.send(packet_bytes).await {
tracing::error!("send packet to mgr failed: {:?}", e);
}
} else {
let Some(from_network) = data.get_peer_network(from_peer_id) else {
continue;
};
let Some(to_network) = data.get_peer_network(to_peer_id) else {
continue;
};
if from_network != to_network {
continue;
}
if let Some(entry) = data.get_network_entry(&from_network) {
let ret = entry.peer_map.send_msg(packet_bytes, to_peer_id).await;
if ret.is_err() {
tracing::error!("forward packet to peer failed: {:?}", ret.err());
}
} else {
tracing::error!("foreign network not found: {}", from_network);
}
}
}
});
}
async fn register_peer_rpc_service(&self) {
self.rpc_mgr.run();
self.rpc_mgr
.run_service(FOREIGN_NETWORK_SERVICE_ID, self.data.clone().serve())
}
pub async fn run(&self) {
self.start_global_event_handler().await;
self.start_packet_recv().await;
self.register_peer_rpc_service().await;
}
pub async fn list_foreign_networks(&self) -> DashMap<String, Vec<PeerId>> {
let ret = DashMap::new();
for item in self.data.network_peer_maps.iter() {
let network_name = item.key().clone();
ret.insert(network_name, vec![]);
}
for mut n in ret.iter_mut() {
let network_name = n.key().clone();
let Some(item) = self
.data
.network_peer_maps
.get(&network_name)
.map(|v| v.clone())
else {
continue;
};
n.value_mut().extend(item.peer_map.list_peers().await);
}
ret
}
}
#[cfg(test)]
mod tests {
    use crate::{
        common::global_ctx::tests::get_mock_global_ctx_with_network,
        connector::udp_hole_punch::tests::{
            create_mock_peer_manager_with_mock_stun, replace_stun_info_collector,
        },
        peers::{
            peer_manager::{PeerManager, RouteAlgoType},
            tests::{connect_peer_manager, wait_route_appear},
        },
        rpc::NatType,
    };

    use super::*;

    /// Builds a running peer manager whose network name and secret are both
    /// set to `network`.
    async fn create_mock_peer_manager_for_foreign_network(network: &str) -> Arc<PeerManager> {
        let (s, _r) = tokio::sync::mpsc::channel(1000);
        let peer_mgr = Arc::new(PeerManager::new(
            RouteAlgoType::Ospf,
            get_mock_global_ctx_with_network(Some(NetworkIdentity {
                network_name: network.to_string(),
                network_secret: network.to_string(),
            })),
            s,
        ));
        replace_stun_info_collector(peer_mgr.clone(), NatType::Unknown);
        peer_mgr.run().await.unwrap();
        peer_mgr
    }

    /// End-to-end test: two "center" nodes relay two independent foreign
    /// networks; checks routing, peer listing, and cleanup on drop.
    #[tokio::test]
    async fn test_foreign_network_manager() {
        let pm_center = create_mock_peer_manager_with_mock_stun(crate::rpc::NatType::Unknown).await;
        let pm_center2 =
            create_mock_peer_manager_with_mock_stun(crate::rpc::NatType::Unknown).await;
        connect_peer_manager(pm_center.clone(), pm_center2.clone()).await;

        let pma_net1 = create_mock_peer_manager_for_foreign_network("net1").await;
        let pmb_net1 = create_mock_peer_manager_for_foreign_network("net1").await;
        connect_peer_manager(pma_net1.clone(), pm_center.clone()).await;
        connect_peer_manager(pmb_net1.clone(), pm_center.clone()).await;

        // wait (up to 10s) until the foreign network client learned a next hop
        let now = std::time::Instant::now();
        let mut succ = false;
        while now.elapsed().as_secs() < 10 {
            let table = pma_net1.get_foreign_network_client().get_next_hop_table();
            if table.len() >= 1 {
                succ = true;
                break;
            }
            tokio::time::sleep(std::time::Duration::from_secs(1)).await;
        }
        assert!(succ);

        // both net1 nodes should see only the center as their direct peer
        assert_eq!(
            vec![pm_center.my_peer_id()],
            pma_net1
                .get_foreign_network_client()
                .get_peer_map()
                .list_peers()
                .await
        );
        assert_eq!(
            vec![pm_center.my_peer_id()],
            pmb_net1
                .get_foreign_network_client()
                .get_peer_map()
                .list_peers()
                .await
        );

        wait_route_appear(pma_net1.clone(), pmb_net1.clone())
            .await
            .unwrap();
        assert_eq!(1, pma_net1.list_routes().await.len());
        assert_eq!(1, pmb_net1.list_routes().await.len());

        let pmc_net1 = create_mock_peer_manager_for_foreign_network("net1").await;
        connect_peer_manager(pmc_net1.clone(), pm_center.clone()).await;
        wait_route_appear(pma_net1.clone(), pmc_net1.clone())
            .await
            .unwrap();
        wait_route_appear(pmb_net1.clone(), pmc_net1.clone())
            .await
            .unwrap();
        assert_eq!(2, pmc_net1.list_routes().await.len());

        // a second, independent foreign network through the same center
        let pma_net2 = create_mock_peer_manager_for_foreign_network("net2").await;
        let pmb_net2 = create_mock_peer_manager_for_foreign_network("net2").await;
        connect_peer_manager(pma_net2.clone(), pm_center.clone()).await;
        connect_peer_manager(pmb_net2.clone(), pm_center.clone()).await;
        wait_route_appear(pma_net2.clone(), pmb_net2.clone())
            .await
            .unwrap();
        assert_eq!(1, pma_net2.list_routes().await.len());
        assert_eq!(1, pmb_net2.list_routes().await.len());

        // 5 foreign peers total (3 in net1, 2 in net2) across 2 networks
        assert_eq!(
            5,
            pm_center
                .get_foreign_network_manager()
                .data
                .peer_network_map
                .len()
        );
        assert_eq!(
            2,
            pm_center
                .get_foreign_network_manager()
                .data
                .network_peer_maps
                .len()
        );

        // dropping peers must shrink the maps; empty networks are removed
        drop(pmb_net2);
        tokio::time::sleep(std::time::Duration::from_secs(1)).await;
        assert_eq!(
            4,
            pm_center
                .get_foreign_network_manager()
                .data
                .peer_network_map
                .len()
        );
        drop(pma_net2);
        tokio::time::sleep(std::time::Duration::from_secs(1)).await;
        assert_eq!(
            3,
            pm_center
                .get_foreign_network_manager()
                .data
                .peer_network_map
                .len()
        );
        assert_eq!(
            1,
            pm_center
                .get_foreign_network_manager()
                .data
                .network_peer_maps
                .len()
        );
    }
}
+39
View File
@@ -0,0 +1,39 @@
pub mod packet;
pub mod peer;
pub mod peer_conn;
pub mod peer_manager;
pub mod peer_map;
pub mod peer_ospf_route;
pub mod peer_rip_route;
pub mod peer_rpc;
pub mod route_trait;
pub mod rpc_service;
pub mod foreign_network_client;
pub mod foreign_network_manager;
#[cfg(test)]
pub mod tests;
use tokio_util::bytes::{Bytes, BytesMut};
#[async_trait::async_trait]
#[auto_impl::auto_impl(Arc)]
pub trait PeerPacketFilter {
    /// Offers a packet received from a peer to this filter. Returning
    /// `Some(())` means the packet was consumed; the default consumes nothing.
    async fn try_process_packet_from_peer(
        &self,
        _packet: &packet::ArchivedPacket,
        _data: &Bytes,
    ) -> Option<()> {
        None
    }
}
#[async_trait::async_trait]
#[auto_impl::auto_impl(Arc)]
pub trait NicPacketFilter {
    /// Transforms a packet read from the NIC before it is sent to peers.
    async fn try_process_packet_from_nic(&self, data: BytesMut) -> BytesMut;
}
// Boxed, thread-safe trait objects used by the filter chains.
type BoxPeerPacketFilter = Box<dyn PeerPacketFilter + Send + Sync>;
type BoxNicPacketFilter = Box<dyn NicPacketFilter + Send + Sync>;
+254
View File
@@ -0,0 +1,254 @@
use std::fmt::Debug;
use rkyv::{Archive, Deserialize, Serialize};
use tokio_util::bytes::Bytes;
use crate::common::{
global_ctx::NetworkIdentity,
rkyv_util::{decode_from_bytes, encode_to_bytes, vec_to_string},
PeerId,
};
/// Magic value stamped into every handshake so peers can sanity-check each other.
const MAGIC: u32 = 0xd1e1a5e1;
/// Protocol version carried in the handshake.
const VERSION: u32 = 1;
/// rkyv-serializable wrapper around the 16 raw bytes of a `uuid::Uuid`.
#[derive(Archive, Deserialize, Serialize, PartialEq, Clone)]
#[archive(compare(PartialEq), check_bytes)]
// Derives can be passed through to the generated type:
#[archive_attr(derive(Debug))]
pub struct UUID(uuid::Bytes);
// Render the UUID in canonical hyphenated form rather than as raw bytes.
impl std::fmt::Debug for UUID {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", uuid::Uuid::from_bytes(self.0))
    }
}
impl From<uuid::Uuid> for UUID {
fn from(uuid: uuid::Uuid) -> Self {
UUID(*uuid.as_bytes())
}
}
impl From<UUID> for uuid::Uuid {
    /// Rebuilds a real `uuid::Uuid` from the stored bytes.
    fn from(uuid: UUID) -> Self {
        Self::from_bytes(uuid.0)
    }
}
impl ArchivedUUID {
    /// Converts the zero-copy archived form back into a `uuid::Uuid`.
    pub fn to_uuid(&self) -> uuid::Uuid {
        uuid::Uuid::from_bytes(self.0)
    }
}
impl From<&ArchivedUUID> for UUID {
    /// Copies the archived bytes into an owned `UUID`.
    fn from(uuid: &ArchivedUUID) -> Self {
        UUID(uuid.0)
    }
}
/// First packet exchanged on a new peer connection.
#[derive(serde::Serialize, serde::Deserialize, Debug)]
pub struct HandShake {
    // must equal MAGIC
    pub magic: u32,
    pub my_peer_id: PeerId,
    // must equal VERSION
    pub version: u32,
    pub features: Vec<String>,
    pub network_identity: NetworkIdentity,
}
/// Opaque routing-protocol payload, tagged with the route algorithm's id.
#[derive(serde::Serialize, serde::Deserialize, Debug)]
pub struct RoutePacket {
    pub route_id: u8,
    pub body: Vec<u8>,
}
/// Payload of every non-Data packet, postcard-encoded into `Packet::payload`.
#[derive(Debug, serde::Serialize, serde::Deserialize)]
pub enum CtrlPacketPayload {
    HandShake(HandShake),
    RoutePacket(RoutePacket),
    // Ping/Pong carry the sequence number used to match them up
    Ping(u32),
    Pong(u32),
    TaRpc(u32, u32, bool, Vec<u8>), // u32: service_id, u32: transact_id, bool: is_req, Vec<u8>: rpc body
}
impl CtrlPacketPayload {
    /// Decodes the control payload of an archived packet.
    ///
    /// Panics if the packet is a Data packet (asserted) or if the payload is
    /// not valid postcard — callers are trusted to pass well-formed packets.
    pub fn from_packet(p: &ArchivedPacket) -> CtrlPacketPayload {
        assert_ne!(p.packet_type, PacketType::Data);
        postcard::from_bytes(p.payload.as_bytes()).unwrap()
    }

    /// Same as `from_packet` but for the owned (non-archived) packet form.
    /// NOTE(review): unlike `from_packet`, this does not assert the packet
    /// type; it also panics on malformed payloads.
    pub fn from_packet2(p: &Packet) -> CtrlPacketPayload {
        postcard::from_bytes(p.payload.as_bytes()).unwrap()
    }
}
/// Discriminates the kinds of packets exchanged between peers.
#[repr(u8)]
#[derive(Archive, Deserialize, Serialize, Debug)]
#[archive(compare(PartialEq), check_bytes)]
// Derives can be passed through to the generated type:
#[archive_attr(derive(Debug))]
pub enum PacketType {
    Data = 1,
    HandShake = 2,
    RoutePacket = 3,
    Ping = 4,
    Pong = 5,
    TaRpc = 6,
}
/// Wire format of every peer-to-peer message (rkyv zero-copy archived).
#[derive(Archive, Deserialize, Serialize)]
#[archive(compare(PartialEq), check_bytes)]
// Derives can be passed through to the generated type:
pub struct Packet {
    pub from_peer: PeerId,
    pub to_peer: PeerId,
    pub packet_type: PacketType,
    // raw payload bytes smuggled in a String (see `vec_to_string` in new())
    pub payload: String,
}
// Manual Debug: show the payload as bytes, since it is not readable text.
impl std::fmt::Debug for Packet {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "Packet {{ from_peer: {}, to_peer: {}, packet_type: {:?}, payload: {:?} }}",
            self.from_peer,
            self.to_peer,
            self.packet_type,
            &self.payload.as_bytes()
        )
    }
}
// Mirror of the Debug impl above for the rkyv-archived form.
impl std::fmt::Debug for ArchivedPacket {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "Packet {{ from_peer: {}, to_peer: {}, packet_type: {:?}, payload: {:?} }}",
            self.from_peer,
            self.to_peer,
            self.packet_type,
            &self.payload.as_bytes()
        )
    }
}
impl Packet {
    /// Zero-copy view of a serialized packet.
    /// Panics if `v` is not a valid archived `Packet`.
    pub fn decode(v: &[u8]) -> &ArchivedPacket {
        decode_from_bytes::<Packet>(v).unwrap()
    }

    /// Builds a packet, storing the payload bytes in the `String` field via
    /// `vec_to_string`.
    pub fn new(
        from_peer: PeerId,
        to_peer: PeerId,
        packet_type: PacketType,
        payload: Vec<u8>,
    ) -> Self {
        Packet {
            from_peer,
            to_peer,
            packet_type,
            payload: vec_to_string(payload),
        }
    }
}
impl From<Packet> for Bytes {
    /// Serializes the packet with rkyv using a 4096-byte scratch space.
    fn from(val: Packet) -> Self {
        encode_to_bytes::<_, 4096>(&val)
    }
}
// Convenience constructors, one per control-packet kind. All encode their
// CtrlPacketPayload with postcard into the packet payload.
impl Packet {
    /// Handshake announcing `from_peer`'s identity; `to_peer` is 0 because
    /// the remote peer id is not known yet.
    pub fn new_handshake(from_peer: PeerId, network: &NetworkIdentity) -> Self {
        let handshake = CtrlPacketPayload::HandShake(HandShake {
            magic: MAGIC,
            my_peer_id: from_peer,
            version: VERSION,
            features: Vec::new(),
            // `.into()` removed: it was an identity conversion
            network_identity: network.clone(),
        });
        Packet::new(
            from_peer,
            0,
            PacketType::HandShake,
            postcard::to_allocvec(&handshake).unwrap(),
        )
    }

    /// Plain data packet carrying `data` verbatim.
    pub fn new_data_packet(from_peer: PeerId, to_peer: PeerId, data: &[u8]) -> Self {
        Packet::new(from_peer, to_peer, PacketType::Data, data.to_vec())
    }

    /// Routing-protocol packet tagged with `route_id`.
    pub fn new_route_packet(from_peer: PeerId, to_peer: PeerId, route_id: u8, data: &[u8]) -> Self {
        let route = CtrlPacketPayload::RoutePacket(RoutePacket {
            route_id,
            body: data.to_vec(),
        });
        Packet::new(
            from_peer,
            to_peer,
            PacketType::RoutePacket,
            postcard::to_allocvec(&route).unwrap(),
        )
    }

    /// Ping with sequence number `seq`; answered by a matching Pong.
    pub fn new_ping_packet(from_peer: PeerId, to_peer: PeerId, seq: u32) -> Self {
        let ping = CtrlPacketPayload::Ping(seq);
        Packet::new(
            from_peer,
            to_peer,
            PacketType::Ping,
            postcard::to_allocvec(&ping).unwrap(),
        )
    }

    /// Pong echoing the `seq` of the ping it answers.
    pub fn new_pong_packet(from_peer: PeerId, to_peer: PeerId, seq: u32) -> Self {
        let pong = CtrlPacketPayload::Pong(seq);
        Packet::new(
            from_peer,
            to_peer,
            PacketType::Pong,
            postcard::to_allocvec(&pong).unwrap(),
        )
    }

    /// Tarpc transport packet; see `CtrlPacketPayload::TaRpc` for the fields.
    pub fn new_tarpc_packet(
        from_peer: PeerId,
        to_peer: PeerId,
        service_id: u32,
        transact_id: u32,
        is_req: bool,
        body: Vec<u8>,
    ) -> Self {
        let ta_rpc = CtrlPacketPayload::TaRpc(service_id, transact_id, is_req, body);
        Packet::new(
            from_peer,
            to_peer,
            PacketType::TaRpc,
            postcard::to_allocvec(&ta_rpc).unwrap(),
        )
    }
}
#[cfg(test)]
mod tests {
    use crate::common::new_peer_id;

    use super::*;

    /// Smoke test: a data packet round-trips through encode + zero-copy decode.
    #[tokio::test]
    async fn serialize() {
        let a = "abcde";
        let out = Packet::new_data_packet(new_peer_id(), new_peer_id(), a.as_bytes());
        // let out = T::new(a.as_bytes());
        let out_bytes: Bytes = out.into();
        println!("out str: {:?}", a.as_bytes());
        println!("out bytes: {:?}", out_bytes);

        let archived = Packet::decode(&out_bytes[..]);
        println!("in packet: {:?}", archived);
    }
}
+213
View File
@@ -0,0 +1,213 @@
use std::sync::Arc;
use dashmap::DashMap;
use tokio::{
select,
sync::{mpsc, Mutex},
task::JoinHandle,
};
use tokio_util::bytes::Bytes;
use tracing::Instrument;
use super::peer_conn::{PeerConn, PeerConnId};
use crate::common::{
error::Error,
global_ctx::{ArcGlobalCtx, GlobalCtxEvent},
PeerId,
};
use crate::rpc::PeerConnInfo;
// A peer connection shared behind an async mutex, and the map of all
// connections of one peer keyed by connection id.
type ArcPeerConn = Arc<Mutex<PeerConn>>;
type ConnMap = Arc<DashMap<PeerConnId, ArcPeerConn>>;
/// A remote peer and the set of live connections to it.
pub struct Peer {
    pub peer_node_id: PeerId,
    conns: ConnMap,
    global_ctx: ArcGlobalCtx,
    // packets received on any connection are forwarded here
    packet_recv_chan: mpsc::Sender<Bytes>,
    // connections report their own close through this channel
    close_event_sender: mpsc::Sender<PeerConnId>,
    close_event_listener: JoinHandle<()>,
    // notified from Drop to shut the listener down
    shutdown_notifier: Arc<tokio::sync::Notify>,
}
impl Peer {
    /// Creates a peer and spawns its close-event listener, which removes a
    /// connection from the map (and emits `PeerConnRemoved`) whenever the
    /// connection reports itself closed.
    pub fn new(
        peer_node_id: PeerId,
        packet_recv_chan: mpsc::Sender<Bytes>,
        global_ctx: ArcGlobalCtx,
    ) -> Self {
        let conns: ConnMap = Arc::new(DashMap::new());
        let (close_event_sender, mut close_event_receiver) = mpsc::channel(10);
        let shutdown_notifier = Arc::new(tokio::sync::Notify::new());

        let conns_copy = conns.clone();
        let shutdown_notifier_copy = shutdown_notifier.clone();
        let global_ctx_copy = global_ctx.clone();
        let close_event_listener = tokio::spawn(
            async move {
                loop {
                    select! {
                        ret = close_event_receiver.recv() => {
                            if ret.is_none() {
                                // channel closed and drained: all senders gone
                                break;
                            }
                            let ret = ret.unwrap();
                            tracing::warn!(
                                ?peer_node_id,
                                ?ret,
                                "notified that peer conn is closed",
                            );
                            if let Some((_, conn)) = conns_copy.remove(&ret) {
                                global_ctx_copy.issue_event(GlobalCtxEvent::PeerConnRemoved(
                                    conn.lock().await.get_conn_info(),
                                ));
                            }
                        }

                        _ = shutdown_notifier_copy.notified() => {
                            // stop accepting new events; the recv arm drains
                            // what is already queued, then returns None
                            close_event_receiver.close();
                            tracing::warn!(?peer_node_id, "peer close event listener notified");
                        }
                    }
                }
                tracing::info!("peer {} close event listener exit", peer_node_id);
            }
            .instrument(tracing::info_span!(
                "peer_close_event_listener",
                ?peer_node_id,
            )),
        );

        Peer {
            peer_node_id,
            conns: conns.clone(),
            packet_recv_chan,
            global_ctx,

            close_event_sender,
            close_event_listener,

            shutdown_notifier,
        }
    }

    /// Wires up a new connection (close reporting, recv loop, pingpong),
    /// announces it, and stores it in the connection map.
    pub async fn add_peer_conn(&self, mut conn: PeerConn) {
        conn.set_close_event_sender(self.close_event_sender.clone());
        conn.start_recv_loop(self.packet_recv_chan.clone());
        conn.start_pingpong();
        self.global_ctx
            .issue_event(GlobalCtxEvent::PeerConnAdded(conn.get_conn_info()));
        self.conns
            .insert(conn.get_conn_id(), Arc::new(Mutex::new(conn)));
    }

    /// Sends `msg` over an arbitrary live connection to this peer.
    pub async fn send_msg(&self, msg: Bytes) -> Result<(), Error> {
        let Some(conn) = self.conns.iter().next() else {
            return Err(Error::PeerNoConnectionError(self.peer_node_id));
        };
        // clone the Arc and drop the DashMap guard before awaiting the lock
        let conn_clone = conn.clone();
        drop(conn);
        conn_clone.lock().await.send_msg(msg).await?;
        Ok(())
    }

    /// Asks the close-event listener to tear down `conn_id`.
    /// `Error::NotFound` when the connection is not (or no longer) known.
    pub async fn close_peer_conn(&self, conn_id: &PeerConnId) -> Result<(), Error> {
        let has_key = self.conns.contains_key(conn_id);
        if !has_key {
            return Err(Error::NotFound);
        }
        self.close_event_sender.send(conn_id.clone()).await.unwrap();
        Ok(())
    }

    /// Snapshot of all connection infos for this peer.
    pub async fn list_peer_conns(&self) -> Vec<PeerConnInfo> {
        let mut conns = vec![];
        for conn in self.conns.iter() {
            // do not lock here, otherwise it will cause dashmap deadlock
            conns.push(conn.clone());
        }

        let mut ret = Vec::new();
        for conn in conns {
            ret.push(conn.lock().await.get_conn_info());
        }
        ret
    }
}
// Log on drop and stop the close-event listener.
impl Drop for Peer {
    /// Wakes the close-event listener so it drains remaining events and exits.
    fn drop(&mut self) {
        self.shutdown_notifier.notify_one();
        tracing::info!("peer {} drop", self.peer_node_id);
    }
}
#[cfg(test)]
mod tests {
    use tokio::{sync::mpsc, time::timeout};

    use crate::{
        common::{global_ctx::tests::get_mock_global_ctx, new_peer_id},
        peers::peer_conn::PeerConn,
        tunnels::ring_tunnel::create_ring_tunnel_pair,
    };

    use super::Peer;

    /// Closing one side of a connection must remove the connection on BOTH
    /// peers (the remote side notices via its recv loop ending).
    #[tokio::test]
    async fn close_peer() {
        let (local_packet_send, _local_packet_recv) = mpsc::channel(10);
        let (remote_packet_send, _remote_packet_recv) = mpsc::channel(10);
        let global_ctx = get_mock_global_ctx();
        let local_peer = Peer::new(new_peer_id(), local_packet_send, global_ctx.clone());
        let remote_peer = Peer::new(new_peer_id(), remote_packet_send, global_ctx.clone());

        let (local_tunnel, remote_tunnel) = create_ring_tunnel_pair();
        let mut local_peer_conn =
            PeerConn::new(local_peer.peer_node_id, global_ctx.clone(), local_tunnel);
        let mut remote_peer_conn =
            PeerConn::new(remote_peer.peer_node_id, global_ctx.clone(), remote_tunnel);

        assert!(!local_peer_conn.handshake_done());
        assert!(!remote_peer_conn.handshake_done());

        let (a, b) = tokio::join!(
            local_peer_conn.do_handshake_as_client(),
            remote_peer_conn.do_handshake_as_server()
        );
        a.unwrap();
        b.unwrap();

        let local_conn_id = local_peer_conn.get_conn_id();

        local_peer.add_peer_conn(local_peer_conn).await;
        remote_peer.add_peer_conn(remote_peer_conn).await;

        assert_eq!(local_peer.list_peer_conns().await.len(), 1);
        assert_eq!(remote_peer.list_peer_conns().await.len(), 1);

        let close_handler =
            tokio::spawn(async move { local_peer.close_peer_conn(&local_conn_id).await });

        // wait for remote peer conn close
        timeout(std::time::Duration::from_secs(5), async {
            while (&remote_peer).list_peer_conns().await.len() != 0 {
                tokio::time::sleep(std::time::Duration::from_millis(100)).await;
            }
        })
        .await
        .unwrap();

        println!("wait for close handler");
        close_handler.await.unwrap().unwrap();
    }
}
+652
View File
@@ -0,0 +1,652 @@
use std::{
fmt::Debug,
pin::Pin,
sync::{
atomic::{AtomicU32, Ordering},
Arc,
},
};
use futures::{SinkExt, StreamExt};
use pnet::datalink::NetworkInterface;
use tokio::{
sync::{broadcast, mpsc, Mutex},
task::JoinSet,
time::{timeout, Duration},
};
use tokio_util::{bytes::Bytes, sync::PollSender};
use tracing::Instrument;
use crate::{
common::{
global_ctx::{ArcGlobalCtx, NetworkIdentity},
PeerId,
},
define_tunnel_filter_chain,
peers::packet::{ArchivedPacketType, CtrlPacketPayload, PacketType},
rpc::{PeerConnInfo, PeerConnStats},
tunnels::{
stats::{Throughput, WindowLatency},
tunnel_filter::StatsRecorderTunnelFilter,
DatagramSink, Tunnel, TunnelError,
},
};
use super::packet::{self, HandShake, Packet};
// Channel on which received packets are delivered, and the per-connection id.
pub type PacketRecvChan = mpsc::Sender<Bytes>;
pub type PeerConnId = uuid::Uuid;
// Reads the next frame from `$stream` (1s timeout), requires it to be a
// HandShake packet, matches its control payload against `$pattern`, and binds
// `$out_var` to `$value`. Any mismatch returns WaitRespError from the
// enclosing function.
macro_rules! wait_response {
    ($stream: ident, $out_var:ident, $pattern:pat_param => $value:expr) => {
        let rsp_vec = timeout(Duration::from_secs(1), $stream.next()).await;
        if rsp_vec.is_err() {
            return Err(TunnelError::WaitRespError(
                "wait handshake response timeout".to_owned(),
            ));
        }
        let rsp_vec = rsp_vec.unwrap().unwrap()?;

        let $out_var;
        let rsp_bytes = Packet::decode(&rsp_vec);
        if rsp_bytes.packet_type != PacketType::HandShake {
            tracing::error!("unexpected packet type: {:?}", rsp_bytes);
            return Err(TunnelError::WaitRespError(
                "unexpected packet type".to_owned(),
            ));
        }
        let resp_payload = CtrlPacketPayload::from_packet(&rsp_bytes);
        match &resp_payload {
            $pattern => $out_var = $value,
            _ => {
                tracing::error!(
                    "unexpected packet: {:?}, pattern: {:?}",
                    rsp_bytes,
                    stringify!($pattern)
                );
                return Err(TunnelError::WaitRespError("unexpected packet".to_owned()));
            }
        }
    };
}
/// What we learned about the remote peer from its handshake.
#[derive(Debug)]
pub struct PeerInfo {
    magic: u32,
    pub my_peer_id: PeerId,
    version: u32,
    pub features: Vec<String>,
    // not populated from the handshake (see From<&HandShake>)
    pub interfaces: Vec<NetworkInterface>,
    pub network_identity: NetworkIdentity,
}
// Unused impl lifetime parameter `<'a>` removed, along with the identity
// `.into()` conversions (u32 -> u32, PeerId -> PeerId).
impl From<&HandShake> for PeerInfo {
    /// Builds peer info from a received handshake; `interfaces` stays empty.
    fn from(hs: &HandShake) -> Self {
        PeerInfo {
            magic: hs.magic,
            my_peer_id: hs.my_peer_id,
            version: hs.version,
            features: hs.features.clone(),
            interfaces: Vec::new(),
            network_identity: hs.network_identity.clone(),
        }
    }
}
/// Periodically pings the remote peer over one connection, recording latency
/// and loss-rate statistics shared with the owning `PeerConn`.
struct PeerConnPinger {
    my_peer_id: PeerId,
    peer_id: PeerId,
    sink: Arc<Mutex<Pin<Box<dyn DatagramSink>>>>,
    // broadcast of inbound ctrl packets; each ping waiter subscribes to it
    ctrl_sender: broadcast::Sender<Bytes>,
    latency_stats: Arc<WindowLatency>,
    // loss rate in percent, shared with PeerConn::get_conn_info
    loss_rate_stats: Arc<AtomicU32>,
    tasks: JoinSet<Result<(), TunnelError>>,
}
// Only show the two peer ids; the other fields are not Debug-friendly.
impl Debug for PeerConnPinger {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("PeerConnPinger")
            .field("my_peer_id", &self.my_peer_id)
            .field("peer_id", &self.peer_id)
            .finish()
    }
}
impl PeerConnPinger {
    /// Creates a pinger; nothing runs until `pingpong` is called.
    pub fn new(
        my_peer_id: PeerId,
        peer_id: PeerId,
        sink: Pin<Box<dyn DatagramSink>>,
        ctrl_sender: broadcast::Sender<Bytes>,
        latency_stats: Arc<WindowLatency>,
        loss_rate_stats: Arc<AtomicU32>,
    ) -> Self {
        Self {
            my_peer_id,
            peer_id,
            sink: Arc::new(Mutex::new(sink)),
            tasks: JoinSet::new(),
            latency_stats,
            ctrl_sender,
            loss_rate_stats,
        }
    }

    /// Sends one ping with `seq` and waits (1s) for the matching pong.
    /// Returns the round-trip time in microseconds.
    async fn do_pingpong_once(
        my_node_id: PeerId,
        peer_id: PeerId,
        sink: Arc<Mutex<Pin<Box<dyn DatagramSink>>>>,
        receiver: &mut broadcast::Receiver<Bytes>,
        seq: u32,
    ) -> Result<u128, TunnelError> {
        // should add seq here. so latency can be calculated more accurately
        let req = packet::Packet::new_ping_packet(my_node_id, peer_id, seq).into();

        tracing::trace!("send ping packet: {:?}", req);

        sink.lock().await.send(req).await.map_err(|e| {
            tracing::warn!("send ping packet error: {:?}", e);
            TunnelError::CommonError("send ping packet error".to_owned())
        })?;

        let now = std::time::Instant::now();

        // wait until we get a pong packet in ctrl_resp_receiver
        let resp = timeout(Duration::from_secs(1), async {
            loop {
                match receiver.recv().await {
                    Ok(p) => {
                        let ctrl_payload =
                            packet::CtrlPacketPayload::from_packet(Packet::decode(&p));
                        // ignore pongs of other (older) pings
                        if let packet::CtrlPacketPayload::Pong(resp_seq) = ctrl_payload {
                            if resp_seq == seq {
                                break;
                            }
                        }
                    }
                    Err(e) => {
                        log::warn!("recv pong resp error: {:?}", e);
                        return Err(TunnelError::WaitRespError(
                            "recv pong resp error".to_owned(),
                        ));
                    }
                }
            }
            Ok(())
        })
        .await;

        tracing::trace!(?resp, "wait ping response done");

        if resp.is_err() {
            return Err(TunnelError::WaitRespError(
                "wait ping response timeout".to_owned(),
            ));
        }

        if resp.as_ref().unwrap().is_err() {
            return Err(resp.unwrap().err().unwrap());
        }

        Ok(now.elapsed().as_micros())
    }

    /// Main loop: spawns a ping roughly every second (at most 5 in flight),
    /// folds the results into latency and loss-rate windows, and returns when
    /// the loss rate is high enough to consider the connection dead.
    async fn pingpong(&mut self) {
        let sink = self.sink.clone();
        let my_node_id = self.my_peer_id;
        let peer_id = self.peer_id;
        let latency_stats = self.latency_stats.clone();

        let (ping_res_sender, mut ping_res_receiver) = tokio::sync::mpsc::channel(100);

        let stopped = Arc::new(AtomicU32::new(0));

        // generate a pingpong task every 200ms
        let mut pingpong_tasks = JoinSet::new();
        let ctrl_resp_sender = self.ctrl_sender.clone();
        let stopped_clone = stopped.clone();
        self.tasks.spawn(async move {
            let mut req_seq = 0;
            loop {
                let receiver = ctrl_resp_sender.subscribe();
                let ping_res_sender = ping_res_sender.clone();
                let sink = sink.clone();

                if stopped_clone.load(Ordering::Relaxed) != 0 {
                    return Ok(());
                }

                // bound the number of concurrently outstanding pings
                while pingpong_tasks.len() > 5 {
                    pingpong_tasks.join_next().await;
                }

                pingpong_tasks.spawn(async move {
                    let mut receiver = receiver.resubscribe();
                    let pingpong_once_ret = Self::do_pingpong_once(
                        my_node_id,
                        peer_id,
                        sink.clone(),
                        &mut receiver,
                        req_seq,
                    )
                    .await;

                    if let Err(e) = ping_res_sender.send(pingpong_once_ret).await {
                        tracing::info!(?e, "pingpong task send result error, exit..");
                    };
                });

                req_seq = req_seq.wrapping_add(1);
                tokio::time::sleep(Duration::from_millis(1000)).await;
            }
        });

        // one with 1% precision
        let loss_rate_stats_1 = WindowLatency::new(100);
        // one with 20% precision, so we can fast fail this conn.
        let loss_rate_stats_20 = WindowLatency::new(5);

        let mut counter: u64 = 0;

        while let Some(ret) = ping_res_receiver.recv().await {
            counter += 1;

            // loss windows record 0 for success, 1 for failure
            if let Ok(lat) = ret {
                latency_stats.record_latency(lat as u32);

                loss_rate_stats_1.record_latency(0);
                loss_rate_stats_20.record_latency(0);
            } else {
                loss_rate_stats_1.record_latency(1);
                loss_rate_stats_20.record_latency(1);
            }

            let loss_rate_20: f64 = loss_rate_stats_20.get_latency_us();
            let loss_rate_1: f64 = loss_rate_stats_1.get_latency_us();

            tracing::trace!(
                ?ret,
                ?self,
                ?loss_rate_1,
                ?loss_rate_20,
                "pingpong task recv pingpong_once result"
            );

            // fast-fail on the small window, slow-fail on the large one
            if (counter > 5 && loss_rate_20 > 0.74) || (counter > 150 && loss_rate_1 > 0.20) {
                tracing::warn!(
                    ?ret,
                    ?self,
                    ?loss_rate_1,
                    ?loss_rate_20,
                    "pingpong loss rate too high, closing"
                );
                break;
            }

            self.loss_rate_stats
                .store((loss_rate_1 * 100.0) as u32, Ordering::Relaxed);
        }

        // tell the sender loop to stop and refuse further results
        stopped.store(1, Ordering::Relaxed);
        ping_res_receiver.close();
    }
}
define_tunnel_filter_chain!(PeerConnTunnel, stats = StatsRecorderTunnelFilter);
/// One tunnel-backed connection to a peer, with handshake state, recv loop,
/// pingpong keepalive, and traffic statistics.
pub struct PeerConn {
    conn_id: PeerConnId,
    my_peer_id: PeerId,
    global_ctx: ArcGlobalCtx,
    sink: Pin<Box<dyn DatagramSink>>,
    tunnel: Box<dyn Tunnel>,
    tasks: JoinSet<Result<(), TunnelError>>,
    // Some(..) once the handshake completed
    info: Option<PeerInfo>,
    close_event_sender: Option<mpsc::Sender<PeerConnId>>,
    // inbound ctrl responses (pongs) broadcast to ping waiters
    ctrl_resp_sender: broadcast::Sender<Bytes>,
    latency_stats: Arc<WindowLatency>,
    throughput: Arc<Throughput>,
    // loss rate in percent, written by the pinger
    loss_rate_stats: Arc<AtomicU32>,
}
// Classification of packets flowing through a peer conn.
// NOTE(review): not referenced anywhere in the visible part of this file —
// confirm it is used elsewhere or remove it.
enum PeerConnPacketType {
    Data(Bytes),
    CtrlReq(Bytes),
    CtrlResp(Bytes),
}
impl PeerConn {
pub fn new(my_peer_id: PeerId, global_ctx: ArcGlobalCtx, tunnel: Box<dyn Tunnel>) -> Self {
let (ctrl_sender, _ctrl_receiver) = broadcast::channel(100);
let peer_conn_tunnel = PeerConnTunnel::new();
let tunnel = peer_conn_tunnel.wrap_tunnel(tunnel);
PeerConn {
conn_id: PeerConnId::new_v4(),
my_peer_id,
global_ctx,
sink: tunnel.pin_sink(),
tunnel: Box::new(tunnel),
tasks: JoinSet::new(),
info: None,
close_event_sender: None,
ctrl_resp_sender: ctrl_sender,
latency_stats: Arc::new(WindowLatency::new(15)),
throughput: peer_conn_tunnel.stats.get_throughput().clone(),
loss_rate_stats: Arc::new(AtomicU32::new(0)),
}
}
pub fn get_conn_id(&self) -> PeerConnId {
self.conn_id
}
#[tracing::instrument]
pub async fn do_handshake_as_server(&mut self) -> Result<(), TunnelError> {
let mut stream = self.tunnel.pin_stream();
let mut sink = self.tunnel.pin_sink();
tracing::info!("waiting for handshake request from client");
wait_response!(stream, hs_req, CtrlPacketPayload::HandShake(x) => x);
self.info = Some(PeerInfo::from(hs_req));
tracing::info!("handshake request: {:?}", hs_req);
let hs_req = self
.global_ctx
.net_ns
.run(|| packet::Packet::new_handshake(self.my_peer_id, &self.global_ctx.network));
sink.send(hs_req.into()).await?;
Ok(())
}
#[tracing::instrument]
pub async fn do_handshake_as_client(&mut self) -> Result<(), TunnelError> {
let mut stream = self.tunnel.pin_stream();
let mut sink = self.tunnel.pin_sink();
let hs_req = self
.global_ctx
.net_ns
.run(|| packet::Packet::new_handshake(self.my_peer_id, &self.global_ctx.network));
sink.send(hs_req.into()).await?;
tracing::info!("waiting for handshake request from server");
wait_response!(stream, hs_rsp, CtrlPacketPayload::HandShake(x) => x);
self.info = Some(PeerInfo::from(hs_rsp));
tracing::info!("handshake response: {:?}", hs_rsp);
Ok(())
}
pub fn handshake_done(&self) -> bool {
self.info.is_some()
}
pub fn start_pingpong(&mut self) {
let mut pingpong = PeerConnPinger::new(
self.my_peer_id,
self.get_peer_id(),
self.tunnel.pin_sink(),
self.ctrl_resp_sender.clone(),
self.latency_stats.clone(),
self.loss_rate_stats.clone(),
);
let close_event_sender = self.close_event_sender.clone().unwrap();
let conn_id = self.conn_id;
self.tasks.spawn(async move {
pingpong.pingpong().await;
tracing::warn!(?pingpong, "pingpong task exit");
if let Err(e) = close_event_sender.send(conn_id).await {
log::warn!("close event sender error: {:?}", e);
}
Ok(())
});
}
/// Spawn the receive loop for this connection: decode incoming packets,
/// answer `Ping` with `Pong` in place, forward `Pong` frames to the
/// ctrl-response channel, and hand every other packet to
/// `packet_recv_chan`.
///
/// Panics if `set_close_event_sender` was not called first (the `unwrap`).
/// On stream end or error the sink is closed and the connection id is sent
/// on the close-event channel.
pub fn start_recv_loop(&mut self, packet_recv_chan: PacketRecvChan) {
    let mut stream = self.tunnel.pin_stream();
    let mut sink = self.tunnel.pin_sink();
    let mut sender = PollSender::new(packet_recv_chan.clone());
    let close_event_sender = self.close_event_sender.clone().unwrap();
    let conn_id = self.conn_id;
    let ctrl_sender = self.ctrl_resp_sender.clone();
    let conn_info = self.get_conn_info();
    let conn_info_for_instrument = self.get_conn_info();
    self.tasks.spawn(
        async move {
            tracing::info!("start recving peer conn packet");
            let mut task_ret = Ok(());
            while let Some(ret) = stream.next().await {
                if ret.is_err() {
                    tracing::error!(error = ?ret, "peer conn recv error");
                    task_ret = Err(ret.err().unwrap());
                    break;
                }
                let buf = ret.unwrap();
                let p = Packet::decode(&buf);
                match p.packet_type {
                    ArchivedPacketType::Ping => {
                        // Keep-alive probe: reply directly from this task.
                        let CtrlPacketPayload::Ping(seq) = CtrlPacketPayload::from_packet(p)
                        else {
                            log::error!("unexpected packet: {:?}", p);
                            continue;
                        };
                        let pong = packet::Packet::new_pong_packet(
                            conn_info.my_peer_id,
                            conn_info.peer_id,
                            seq.into(),
                        );
                        if let Err(e) = sink.send(pong.into()).await {
                            tracing::error!(?e, "peer conn send req error");
                        }
                    }
                    ArchivedPacketType::Pong => {
                        // Raw buffer is routed back to the pinger via the
                        // ctrl-response channel.
                        if let Err(e) = ctrl_sender.send(buf.into()) {
                            tracing::error!(?e, "peer conn send ctrl resp error");
                        }
                    }
                    _ => {
                        // Everything else goes upstream; a closed receiver
                        // ends this loop.
                        if sender.send(buf.into()).await.is_err() {
                            break;
                        }
                    }
                }
            }
            tracing::info!("end recving peer conn packet");
            if let Err(close_ret) = sink.close().await {
                tracing::error!(error = ?close_ret, "peer conn sink close error, ignore it");
            }
            if let Err(e) = close_event_sender.send(conn_id).await {
                tracing::error!(error = ?e, "peer conn close event send error");
            }
            task_ret
        }
        .instrument(
            tracing::info_span!("peer conn recv loop", conn_info = ?conn_info_for_instrument),
        ),
    );
}
/// Send one raw message to the remote peer over this connection's sink.
pub async fn send_msg(&mut self, msg: Bytes) -> Result<(), TunnelError> {
    self.sink.send(msg).await
}
/// The remote peer's id, taken from the handshake response (the field is
/// named `my_peer_id` because it is "my id" from the remote's perspective).
///
/// Panics if called before the handshake finished (`info` is `None`).
pub fn get_peer_id(&self) -> PeerId {
    self.info.as_ref().unwrap().my_peer_id
}
/// Network identity the remote peer reported during the handshake.
///
/// Panics if called before the handshake finished (`info` is `None`).
pub fn get_network_identity(&self) -> NetworkIdentity {
    self.info.as_ref().unwrap().network_identity.clone()
}
/// Install the channel that `start_pingpong`/`start_recv_loop` use to
/// report connection closure; must be called before either of them.
pub fn set_close_event_sender(&mut self, sender: mpsc::Sender<PeerConnId>) {
    self.close_event_sender = Some(sender);
}
/// Snapshot of the latency and throughput counters for this connection.
pub fn get_stats(&self) -> PeerConnStats {
    PeerConnStats {
        latency_us: self.latency_stats.get_latency_us(),
        tx_bytes: self.throughput.tx_bytes(),
        rx_bytes: self.throughput.rx_bytes(),
        tx_packets: self.throughput.tx_packets(),
        rx_packets: self.throughput.rx_packets(),
    }
}
/// Full connection descriptor (ids, tunnel info, stats, loss rate).
///
/// Panics if the handshake has not completed (`get_peer_id` and the
/// `features` access unwrap `info`). `loss_rate_stats` appears to hold the
/// rate scaled by 100 as a u32 — hence the `/ 100.0` here (TODO confirm at
/// the writer side in `PeerConnPinger`).
pub fn get_conn_info(&self) -> PeerConnInfo {
    PeerConnInfo {
        conn_id: self.conn_id.to_string(),
        my_peer_id: self.my_peer_id,
        peer_id: self.get_peer_id(),
        features: self.info.as_ref().unwrap().features.clone(),
        tunnel: self.tunnel.info(),
        stats: Some(self.get_stats()),
        loss_rate: (f64::from(self.loss_rate_stats.load(Ordering::Relaxed)) / 100.0) as f32,
    }
}
}
impl Drop for PeerConn {
    /// `Drop` cannot be async, so closing the tunnel sink is deferred to a
    /// spawned task; this requires a live tokio runtime at drop time.
    fn drop(&mut self) {
        let mut sink = self.tunnel.pin_sink();
        tokio::spawn(async move {
            let ret = sink.close().await;
            tracing::info!(error = ?ret, "peer conn tunnel closed.");
        });
        log::info!("peer conn {:?} drop", self.conn_id);
    }
}
impl Debug for PeerConn {
    // Compact debug view: connection id, local peer id and handshake info
    // only (the tunnel and counters are intentionally omitted).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("PeerConn")
            .field("conn_id", &self.conn_id)
            .field("my_peer_id", &self.my_peer_id)
            .field("info", &self.info)
            .finish()
    }
}
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::*;
    use crate::common::global_ctx::tests::get_mock_global_ctx;
    use crate::common::new_peer_id;
    use crate::tunnels::tunnel_filter::tests::DropSendTunnelFilter;
    use crate::tunnels::tunnel_filter::{PacketRecorderTunnelFilter, TunnelWithFilter};

    /// Client/server handshake over an in-memory ring tunnel: exactly one
    /// packet is sent/received on each side, and both sides learn the
    /// other's peer id and agree on the (default) network identity.
    #[tokio::test]
    async fn peer_conn_handshake() {
        use crate::tunnels::ring_tunnel::create_ring_tunnel_pair;
        let (c, s) = create_ring_tunnel_pair();
        let c_recorder = Arc::new(PacketRecorderTunnelFilter::new());
        let s_recorder = Arc::new(PacketRecorderTunnelFilter::new());
        let c = TunnelWithFilter::new(c, c_recorder.clone());
        let s = TunnelWithFilter::new(s, s_recorder.clone());
        let c_peer_id = new_peer_id();
        let s_peer_id = new_peer_id();
        let mut c_peer = PeerConn::new(c_peer_id, get_mock_global_ctx(), Box::new(c));
        let mut s_peer = PeerConn::new(s_peer_id, get_mock_global_ctx(), Box::new(s));
        let (c_ret, s_ret) = tokio::join!(
            c_peer.do_handshake_as_client(),
            s_peer.do_handshake_as_server()
        );
        c_ret.unwrap();
        s_ret.unwrap();
        assert_eq!(c_recorder.sent.lock().unwrap().len(), 1);
        assert_eq!(c_recorder.received.lock().unwrap().len(), 1);
        assert_eq!(s_recorder.sent.lock().unwrap().len(), 1);
        assert_eq!(s_recorder.received.lock().unwrap().len(), 1);
        assert_eq!(c_peer.get_peer_id(), s_peer_id);
        assert_eq!(s_peer.get_peer_id(), c_peer_id);
        assert_eq!(c_peer.get_network_identity(), s_peer.get_network_identity());
        assert_eq!(c_peer.get_network_identity(), NetworkIdentity::default());
    }

    /// Run a ping/pong session while the client-side filter drops the send
    /// packets numbered [drop_start, drop_end); `conn_closed` states whether
    /// the connection is expected to be torn down afterwards.
    async fn peer_conn_pingpong_test_common(drop_start: u32, drop_end: u32, conn_closed: bool) {
        use crate::tunnels::ring_tunnel::create_ring_tunnel_pair;
        let (c, s) = create_ring_tunnel_pair();
        // Drop a configurable window of outgoing packets to exercise the
        // pinger's loss handling.
        let c_recorder = Arc::new(DropSendTunnelFilter::new(drop_start, drop_end));
        let c = TunnelWithFilter::new(c, c_recorder.clone());
        let c_peer_id = new_peer_id();
        let s_peer_id = new_peer_id();
        let mut c_peer = PeerConn::new(c_peer_id, get_mock_global_ctx(), Box::new(c));
        let mut s_peer = PeerConn::new(s_peer_id, get_mock_global_ctx(), Box::new(s));
        let (c_ret, s_ret) = tokio::join!(
            c_peer.do_handshake_as_client(),
            s_peer.do_handshake_as_server()
        );
        s_peer.set_close_event_sender(tokio::sync::mpsc::channel(1).0);
        s_peer.start_recv_loop(tokio::sync::mpsc::channel(200).0);
        assert!(c_ret.is_ok());
        assert!(s_ret.is_ok());
        let (close_send, mut close_recv) = tokio::sync::mpsc::channel(1);
        c_peer.set_close_event_sender(close_send);
        c_peer.start_pingpong();
        c_peer.start_recv_loop(tokio::sync::mpsc::channel(200).0);
        // Wait long enough for several ping rounds; with a small drop window
        // the connection must survive, with a large one it must be closed.
        tokio::time::sleep(Duration::from_secs(15)).await;
        if conn_closed {
            assert!(close_recv.try_recv().is_ok());
        } else {
            assert!(close_recv.try_recv().is_err());
        }
    }

    /// A short drop window (3..5) keeps the conn alive; a long one (5..12)
    /// must trigger the close event.
    #[tokio::test]
    async fn peer_conn_pingpong_timeout() {
        peer_conn_pingpong_test_common(3, 5, false).await;
        peer_conn_pingpong_test_common(5, 12, true).await;
    }
}
+641
View File
@@ -0,0 +1,641 @@
use std::{
fmt::Debug,
net::Ipv4Addr,
sync::{Arc, Weak},
};
use async_trait::async_trait;
use futures::StreamExt;
use tokio::{
sync::{
mpsc::{self, UnboundedReceiver, UnboundedSender},
Mutex, RwLock,
},
task::JoinSet,
};
use tokio_stream::wrappers::ReceiverStream;
use tokio_util::bytes::{Bytes, BytesMut};
use crate::{
common::{
error::Error, global_ctx::ArcGlobalCtx, rkyv_util::extract_bytes_from_archived_string,
PeerId,
},
peers::{
packet, peer_conn::PeerConn, peer_rpc::PeerRpcManagerTransport,
route_trait::RouteInterface, PeerPacketFilter,
},
tunnels::{SinkItem, Tunnel, TunnelConnector},
};
use super::{
foreign_network_client::ForeignNetworkClient,
foreign_network_manager::ForeignNetworkManager,
peer_conn::PeerConnId,
peer_map::PeerMap,
peer_ospf_route::PeerRoute,
peer_rip_route::BasicRoute,
peer_rpc::PeerRpcManager,
route_trait::{ArcRoute, Route},
BoxNicPacketFilter, BoxPeerPacketFilter,
};
/// Transport glue that lets `PeerRpcManager` send/receive RPC packets
/// through the peer map, with the foreign-network client as a fallback path.
struct RpcTransport {
    my_peer_id: PeerId,
    // Weak: the peer map is owned by `PeerManager`.
    peers: Weak<PeerMap>,
    // Filled in later by `run_foriegn_network`; None until then.
    foreign_peers: Mutex<Option<Weak<ForeignNetworkClient>>>,
    // RPC packets addressed to us, pushed by the packet pipeline.
    packet_recv: Mutex<UnboundedReceiver<Bytes>>,
    peer_rpc_tspt_sender: UnboundedSender<Bytes>,
}
#[async_trait::async_trait]
impl PeerRpcManagerTransport for RpcTransport {
fn my_peer_id(&self) -> PeerId {
self.my_peer_id
}
async fn send(&self, msg: Bytes, dst_peer_id: PeerId) -> Result<(), Error> {
let foreign_peers = self
.foreign_peers
.lock()
.await
.as_ref()
.ok_or(Error::Unknown)?
.upgrade()
.ok_or(Error::Unknown)?;
let peers = self.peers.upgrade().ok_or(Error::Unknown)?;
let ret = peers.send_msg(msg.clone(), dst_peer_id).await;
if matches!(ret, Err(Error::RouteError(..))) && foreign_peers.has_next_hop(dst_peer_id) {
tracing::info!(
?dst_peer_id,
?self.my_peer_id,
"failed to send msg to peer, try foreign network",
);
return foreign_peers.send_msg(msg, dst_peer_id).await;
}
ret
}
async fn recv(&self) -> Result<Bytes, Error> {
if let Some(o) = self.packet_recv.lock().await.recv().await {
Ok(o)
} else {
Err(Error::Unknown)
}
}
}
/// Routing algorithm selection, chosen by the caller of `PeerManager::new`.
pub enum RouteAlgoType {
    Rip,
    Ospf,
    None,
}
/// The instantiated routing algorithm corresponding to `RouteAlgoType`.
enum RouteAlgoInst {
    Rip(Arc<BasicRoute>),
    Ospf(Arc<PeerRoute>),
    None,
}
/// Central coordinator: owns the peer map, routing instance, RPC manager,
/// the packet-processing pipelines, and the foreign-network components.
pub struct PeerManager {
    my_peer_id: PeerId,
    global_ctx: ArcGlobalCtx,
    // Data packets destined to the local node are delivered here.
    nic_channel: mpsc::Sender<SinkItem>,
    tasks: Arc<Mutex<JoinSet<()>>>,
    // Taken (set to None) by `start_peer_recv`; carries all inbound packets.
    packet_recv: Arc<Mutex<Option<mpsc::Receiver<Bytes>>>>,
    peers: Arc<PeerMap>,
    peer_rpc_mgr: Arc<PeerRpcManager>,
    peer_rpc_tspt: Arc<RpcTransport>,
    // Filters appended last run first (iterated with `.rev()`).
    peer_packet_process_pipeline: Arc<RwLock<Vec<BoxPeerPacketFilter>>>,
    nic_packet_process_pipeline: Arc<RwLock<Vec<BoxNicPacketFilter>>>,
    route_algo_inst: RouteAlgoInst,
    foreign_network_manager: Arc<ForeignNetworkManager>,
    foreign_network_client: Arc<ForeignNetworkClient>,
}
impl Debug for PeerManager {
    // Debug view limited to identifiers (peer id, instance name, netns).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("PeerManager")
            .field("my_peer_id", &self.my_peer_id())
            .field("instance_name", &self.global_ctx.inst_name)
            .field("net_ns", &self.global_ctx.net_ns.name())
            .finish()
    }
}
impl PeerManager {
    /// Build a peer manager with the chosen routing algorithm, global
    /// context and NIC channel (where local-bound data packets are pushed).
    pub fn new(
        route_algo: RouteAlgoType,
        global_ctx: ArcGlobalCtx,
        nic_channel: mpsc::Sender<SinkItem>,
    ) -> Self {
        // The local peer id is random per constructed manager.
        let my_peer_id = rand::random();
        let (packet_send, packet_recv) = mpsc::channel(100);
        let peers = Arc::new(PeerMap::new(
            packet_send.clone(),
            global_ctx.clone(),
            my_peer_id,
        ));
        // TODO: remove these because we have impl pipeline processor.
        let (peer_rpc_tspt_sender, peer_rpc_tspt_recv) = mpsc::unbounded_channel();
        let rpc_tspt = Arc::new(RpcTransport {
            my_peer_id,
            peers: Arc::downgrade(&peers),
            foreign_peers: Mutex::new(None),
            packet_recv: Mutex::new(peer_rpc_tspt_recv),
            peer_rpc_tspt_sender,
        });
        let peer_rpc_mgr = Arc::new(PeerRpcManager::new(rpc_tspt.clone()));
        let route_algo_inst = match route_algo {
            RouteAlgoType::Rip => {
                RouteAlgoInst::Rip(Arc::new(BasicRoute::new(my_peer_id, global_ctx.clone())))
            }
            RouteAlgoType::Ospf => RouteAlgoInst::Ospf(PeerRoute::new(
                my_peer_id,
                global_ctx.clone(),
                peer_rpc_mgr.clone(),
            )),
            RouteAlgoType::None => RouteAlgoInst::None,
        };
        let foreign_network_manager = Arc::new(ForeignNetworkManager::new(
            my_peer_id,
            global_ctx.clone(),
            packet_send.clone(),
        ));
        let foreign_network_client = Arc::new(ForeignNetworkClient::new(
            global_ctx.clone(),
            packet_send.clone(),
            peer_rpc_mgr.clone(),
            my_peer_id,
        ));
        PeerManager {
            my_peer_id,
            global_ctx,
            nic_channel,
            tasks: Arc::new(Mutex::new(JoinSet::new())),
            packet_recv: Arc::new(Mutex::new(Some(packet_recv))),
            peers: peers.clone(),
            peer_rpc_mgr,
            peer_rpc_tspt: rpc_tspt,
            peer_packet_process_pipeline: Arc::new(RwLock::new(Vec::new())),
            nic_packet_process_pipeline: Arc::new(RwLock::new(Vec::new())),
            route_algo_inst,
            foreign_network_manager,
            foreign_network_client,
        }
    }

    /// Handshake over `tunnel` as the client side and register the
    /// connection. Connections whose network identity differs from ours are
    /// handed to the foreign-network client instead of the local peer map.
    pub async fn add_client_tunnel(
        &self,
        tunnel: Box<dyn Tunnel>,
    ) -> Result<(PeerId, PeerConnId), Error> {
        let mut peer = PeerConn::new(self.my_peer_id, self.global_ctx.clone(), tunnel);
        peer.do_handshake_as_client().await?;
        let conn_id = peer.get_conn_id();
        let peer_id = peer.get_peer_id();
        if peer.get_network_identity() == self.global_ctx.get_network_identity() {
            self.peers.add_new_peer_conn(peer).await;
        } else {
            self.foreign_network_client.add_new_peer_conn(peer).await;
        }
        Ok((peer_id, conn_id))
    }

    /// Connect via `connector` (inside our network namespace) and add the
    /// resulting tunnel as a client connection.
    #[tracing::instrument]
    pub async fn try_connect<C>(&self, mut connector: C) -> Result<(PeerId, PeerConnId), Error>
    where
        C: TunnelConnector + Debug,
    {
        let ns = self.global_ctx.net_ns.clone();
        let t = ns
            .run_async(|| async move { connector.connect().await })
            .await?;
        self.add_client_tunnel(t).await
    }

    /// Handshake over `tunnel` as the server side; peers from a different
    /// network identity are delegated to the foreign-network manager.
    #[tracing::instrument]
    pub async fn add_tunnel_as_server(&self, tunnel: Box<dyn Tunnel>) -> Result<(), Error> {
        tracing::info!("add tunnel as server start");
        let mut peer = PeerConn::new(self.my_peer_id, self.global_ctx.clone(), tunnel);
        peer.do_handshake_as_server().await?;
        if peer.get_network_identity() == self.global_ctx.get_network_identity() {
            self.peers.add_new_peer_conn(peer).await;
        } else {
            self.foreign_network_manager.add_peer_conn(peer).await?;
        }
        tracing::info!("add tunnel as server done");
        Ok(())
    }

    /// Spawn the central receive loop: packets addressed to other peers are
    /// forwarded (relay role); packets for us go through the peer packet
    /// pipeline, newest filter first.
    ///
    /// Must be called at most once — it `take()`s the receiver out of
    /// `packet_recv` and unwraps it. The spawned task panics (by design)
    /// when the packet channel closes.
    async fn start_peer_recv(&self) {
        let mut recv = ReceiverStream::new(self.packet_recv.lock().await.take().unwrap());
        let my_peer_id = self.my_peer_id;
        let peers = self.peers.clone();
        let pipe_line = self.peer_packet_process_pipeline.clone();
        self.tasks.lock().await.spawn(async move {
            log::trace!("start_peer_recv");
            while let Some(ret) = recv.next().await {
                log::trace!("peer recv a packet...: {:?}", ret);
                let packet = packet::Packet::decode(&ret);
                let from_peer_id: PeerId = packet.from_peer.into();
                let to_peer_id: PeerId = packet.to_peer.into();
                if to_peer_id != my_peer_id {
                    // Not for us — forward toward the destination.
                    log::trace!(
                        "need forward: to_peer_id: {:?}, my_peer_id: {:?}",
                        to_peer_id,
                        my_peer_id
                    );
                    let ret = peers.send_msg(ret.clone(), to_peer_id).await;
                    if ret.is_err() {
                        log::error!(
                            "forward packet error: {:?}, dst: {:?}, from: {:?}",
                            ret,
                            to_peer_id,
                            from_peer_id
                        );
                    }
                } else {
                    // For us — first filter that returns Some consumed it.
                    let mut processed = false;
                    for pipeline in pipe_line.read().await.iter().rev() {
                        if let Some(_) = pipeline.try_process_packet_from_peer(&packet, &ret).await
                        {
                            processed = true;
                            break;
                        }
                    }
                    if !processed {
                        tracing::error!("unexpected packet: {:?}", ret);
                    }
                }
            }
            panic!("done_peer_recv");
        });
    }

    /// Append a peer-packet filter; the newest addition runs first.
    pub async fn add_packet_process_pipeline(&self, pipeline: BoxPeerPacketFilter) {
        // newest pipeline will be executed first
        self.peer_packet_process_pipeline
            .write()
            .await
            .push(pipeline);
    }

    /// Append a NIC-packet filter; the newest addition runs first.
    pub async fn add_nic_packet_process_pipeline(&self, pipeline: BoxNicPacketFilter) {
        // newest pipeline will be executed first
        self.nic_packet_process_pipeline
            .write()
            .await
            .push(pipeline);
    }

    /// Install the built-in filters: Data packets go to the NIC channel,
    /// TaRpc packets go to the RPC transport.
    async fn init_packet_process_pipeline(&self) {
        // for tun/tap ip/eth packet.
        struct NicPacketProcessor {
            nic_channel: mpsc::Sender<SinkItem>,
        }
        #[async_trait::async_trait]
        impl PeerPacketFilter for NicPacketProcessor {
            async fn try_process_packet_from_peer(
                &self,
                packet: &packet::ArchivedPacket,
                data: &Bytes,
            ) -> Option<()> {
                if packet.packet_type == packet::PacketType::Data {
                    // TODO: use a function to get the body ref directly for zero copy
                    self.nic_channel
                        .send(extract_bytes_from_archived_string(data, &packet.payload))
                        .await
                        .unwrap();
                    Some(())
                } else {
                    None
                }
            }
        }
        self.add_packet_process_pipeline(Box::new(NicPacketProcessor {
            nic_channel: self.nic_channel.clone(),
        }))
        .await;

        // for peer rpc packet
        struct PeerRpcPacketProcessor {
            peer_rpc_tspt_sender: UnboundedSender<Bytes>,
        }
        #[async_trait::async_trait]
        impl PeerPacketFilter for PeerRpcPacketProcessor {
            async fn try_process_packet_from_peer(
                &self,
                packet: &packet::ArchivedPacket,
                data: &Bytes,
            ) -> Option<()> {
                if packet.packet_type == packet::PacketType::TaRpc {
                    self.peer_rpc_tspt_sender.send(data.clone()).unwrap();
                    Some(())
                } else {
                    None
                }
            }
        }
        self.add_packet_process_pipeline(Box::new(PeerRpcPacketProcessor {
            peer_rpc_tspt_sender: self.peer_rpc_tspt.peer_rpc_tspt_sender.clone(),
        }))
        .await;
    }

    /// Register a route implementation: hook it into the packet pipeline,
    /// open it with an interface backed by the peer map and foreign-network
    /// client, and add it to the peer map's route list.
    pub async fn add_route<T>(&self, route: T)
    where
        T: Route + PeerPacketFilter + Send + Sync + Clone + 'static,
    {
        // for route
        self.add_packet_process_pipeline(Box::new(route.clone()))
            .await;

        // Route's view of the world: neighbor listing and packet sending.
        struct Interface {
            my_peer_id: PeerId,
            peers: Weak<PeerMap>,
            foreign_network_client: Weak<ForeignNetworkClient>,
        }
        #[async_trait]
        impl RouteInterface for Interface {
            async fn list_peers(&self) -> Vec<PeerId> {
                let Some(foreign_client) = self.foreign_network_client.upgrade() else {
                    return vec![];
                };
                let Some(peer_map) = self.peers.upgrade() else {
                    return vec![];
                };
                let mut peers = foreign_client.list_foreign_peers();
                peers.extend(peer_map.list_peers_with_conn().await);
                peers
            }
            async fn send_route_packet(
                &self,
                msg: Bytes,
                route_id: u8,
                dst_peer_id: PeerId,
            ) -> Result<(), Error> {
                let foreign_client = self
                    .foreign_network_client
                    .upgrade()
                    .ok_or(Error::Unknown)?;
                let peer_map = self.peers.upgrade().ok_or(Error::Unknown)?;
                let packet_bytes: Bytes =
                    packet::Packet::new_route_packet(self.my_peer_id, dst_peer_id, route_id, &msg)
                        .into();
                // Prefer the foreign network when it knows a next hop.
                if foreign_client.has_next_hop(dst_peer_id) {
                    return foreign_client.send_msg(packet_bytes, dst_peer_id).await;
                }
                peer_map.send_msg_directly(packet_bytes, dst_peer_id).await
            }
            fn my_peer_id(&self) -> PeerId {
                self.my_peer_id
            }
        }

        let my_peer_id = self.my_peer_id;
        let _route_id = route
            .open(Box::new(Interface {
                my_peer_id,
                peers: Arc::downgrade(&self.peers),
                foreign_network_client: Arc::downgrade(&self.foreign_network_client),
            }))
            .await
            .unwrap();
        let arc_route: ArcRoute = Arc::new(Box::new(route));
        self.peers.add_route(arc_route).await;
    }

    /// Boxed clone of the active route. Panics when constructed with
    /// `RouteAlgoType::None`.
    pub fn get_route(&self) -> Box<dyn Route + Send + Sync + 'static> {
        match &self.route_algo_inst {
            RouteAlgoInst::Rip(route) => Box::new(route.clone()),
            RouteAlgoInst::Ospf(route) => Box::new(route.clone()),
            RouteAlgoInst::None => panic!("no route"),
        }
    }

    pub async fn list_routes(&self) -> Vec<crate::rpc::Route> {
        self.get_route().list_routes().await
    }

    /// Run a NIC packet through the filters, newest first.
    async fn run_nic_packet_process_pipeline(&self, mut data: BytesMut) -> BytesMut {
        for pipeline in self.nic_packet_process_pipeline.read().await.iter().rev() {
            data = pipeline.try_process_packet_from_nic(data).await;
        }
        data
    }

    /// Route-aware send to a specific peer.
    pub async fn send_msg(&self, msg: Bytes, dst_peer_id: PeerId) -> Result<(), Error> {
        self.peers.send_msg(msg, dst_peer_id).await
    }

    /// Send an IP packet to the peer owning `ipv4_addr`; broadcast-like
    /// addresses fan out to every routed peer. Missing destinations are
    /// silently dropped (returns `Ok`).
    pub async fn send_msg_ipv4(&self, msg: BytesMut, ipv4_addr: Ipv4Addr) -> Result<(), Error> {
        log::trace!(
            "do send_msg in peer manager, msg: {:?}, ipv4_addr: {}",
            msg,
            ipv4_addr
        );
        let mut dst_peers = vec![];
        // NOTE: currently we only support ipv4 and cidr is 24
        if ipv4_addr.is_broadcast() || ipv4_addr.is_multicast() || ipv4_addr.octets()[3] == 255 {
            dst_peers.extend(
                self.peers
                    .list_routes()
                    .await
                    .iter()
                    .map(|x| x.key().clone()),
            );
        } else if let Some(peer_id) = self.peers.get_peer_id_by_ipv4(&ipv4_addr).await {
            dst_peers.push(peer_id);
        }
        if dst_peers.is_empty() {
            tracing::info!("no peer id for ipv4: {}", ipv4_addr);
            return Ok(());
        }
        let msg = self.run_nic_packet_process_pipeline(msg).await;
        let mut errs: Vec<Error> = vec![];
        for peer_id in dst_peers.iter() {
            let msg: Bytes =
                packet::Packet::new_data_packet(self.my_peer_id, peer_id.clone(), &msg).into();
            let send_ret = self.peers.send_msg(msg.clone(), *peer_id).await;
            // No direct route — try the foreign network as a relay.
            if matches!(send_ret, Err(Error::RouteError(..)))
                && self.foreign_network_client.has_next_hop(*peer_id)
            {
                let foreign_send_ret = self.foreign_network_client.send_msg(msg, *peer_id).await;
                if foreign_send_ret.is_ok() {
                    continue;
                }
            }
            if let Err(send_ret) = send_ret {
                errs.push(send_ret);
            }
        }
        tracing::trace!(?dst_peers, "do send_msg in peer manager done");
        if errs.is_empty() {
            Ok(())
        } else {
            tracing::error!(?errs, "send_msg has error");
            Err(anyhow::anyhow!("send_msg has error: {:?}", errs).into())
        }
    }

    /// Spawn a periodic task that evicts peers with no live connections.
    async fn run_clean_peer_without_conn_routine(&self) {
        let peer_map = self.peers.clone();
        self.tasks.lock().await.spawn(async move {
            loop {
                peer_map.clean_peer_without_conn().await;
                tokio::time::sleep(std::time::Duration::from_secs(3)).await;
            }
        });
    }

    /// Wire the foreign-network client into the RPC transport and start both
    /// foreign-network components.
    /// NOTE(review): "foriegn" is a typo for "foreign"; the method is
    /// private, renaming is a trivial follow-up.
    async fn run_foriegn_network(&self) {
        self.peer_rpc_tspt
            .foreign_peers
            .lock()
            .await
            .replace(Arc::downgrade(&self.foreign_network_client));

        self.foreign_network_manager.run().await;
        self.foreign_network_client.run().await;
    }

    /// Start everything: route, pipelines, RPC, receive loop, maintenance
    /// routines and the foreign network.
    pub async fn run(&self) -> Result<(), Error> {
        match &self.route_algo_inst {
            RouteAlgoInst::Ospf(route) => self.add_route(route.clone()).await,
            RouteAlgoInst::Rip(route) => self.add_route(route.clone()).await,
            RouteAlgoInst::None => {}
        };

        self.init_packet_process_pipeline().await;
        self.peer_rpc_mgr.run();

        self.start_peer_recv().await;
        self.run_clean_peer_without_conn_routine().await;
        self.run_foriegn_network().await;
        Ok(())
    }

    pub fn get_peer_map(&self) -> Arc<PeerMap> {
        self.peers.clone()
    }

    pub fn get_peer_rpc_mgr(&self) -> Arc<PeerRpcManager> {
        self.peer_rpc_mgr.clone()
    }

    /// Stable node id from the global context (unlike the random peer id).
    pub fn my_node_id(&self) -> uuid::Uuid {
        self.global_ctx.get_id()
    }

    pub fn my_peer_id(&self) -> PeerId {
        self.my_peer_id
    }

    pub fn get_global_ctx(&self) -> ArcGlobalCtx {
        self.global_ctx.clone()
    }

    pub fn get_nic_channel(&self) -> mpsc::Sender<SinkItem> {
        self.nic_channel.clone()
    }

    /// Panics unless constructed with `RouteAlgoType::Rip`.
    pub fn get_basic_route(&self) -> Arc<BasicRoute> {
        match &self.route_algo_inst {
            RouteAlgoInst::Rip(route) => route.clone(),
            _ => panic!("not rip route"),
        }
    }

    pub fn get_foreign_network_manager(&self) -> Arc<ForeignNetworkManager> {
        self.foreign_network_manager.clone()
    }

    pub fn get_foreign_network_client(&self) -> Arc<ForeignNetworkClient> {
        self.foreign_network_client.clone()
    }
}
#[cfg(test)]
mod tests {
    use crate::{
        connector::udp_hole_punch::tests::create_mock_peer_manager_with_mock_stun,
        peers::tests::{connect_peer_manager, wait_for_condition, wait_route_appear},
        rpc::NatType,
    };

    /// Dropping a peer manager must eventually make it disappear from the
    /// other managers' connected-peer lists (via the cleanup routine).
    #[tokio::test]
    async fn drop_peer_manager() {
        let peer_mgr_a = create_mock_peer_manager_with_mock_stun(NatType::Unknown).await;
        let peer_mgr_b = create_mock_peer_manager_with_mock_stun(NatType::Unknown).await;
        let peer_mgr_c = create_mock_peer_manager_with_mock_stun(NatType::Unknown).await;
        connect_peer_manager(peer_mgr_a.clone(), peer_mgr_b.clone()).await;
        connect_peer_manager(peer_mgr_b.clone(), peer_mgr_c.clone()).await;
        connect_peer_manager(peer_mgr_a.clone(), peer_mgr_c.clone()).await;
        wait_route_appear(peer_mgr_a.clone(), peer_mgr_b.clone())
            .await
            .unwrap();
        wait_route_appear(peer_mgr_a.clone(), peer_mgr_c.clone())
            .await
            .unwrap();
        // wait mgr_a have 2 peers
        wait_for_condition(
            || async { peer_mgr_a.get_peer_map().list_peers_with_conn().await.len() == 2 },
            std::time::Duration::from_secs(5),
        )
        .await;
        drop(peer_mgr_b);
        wait_for_condition(
            || async { peer_mgr_a.get_peer_map().list_peers_with_conn().await.len() == 1 },
            std::time::Duration::from_secs(5),
        )
        .await;
    }
}
+234
View File
@@ -0,0 +1,234 @@
use std::{net::Ipv4Addr, sync::Arc};
use anyhow::Context;
use dashmap::DashMap;
use tokio::sync::{mpsc, RwLock};
use tokio_util::bytes::Bytes;
use crate::{
common::{
error::Error,
global_ctx::{ArcGlobalCtx, GlobalCtxEvent},
PeerId,
},
rpc::PeerConnInfo,
tunnels::TunnelError,
};
use super::{
peer::Peer,
peer_conn::{PeerConn, PeerConnId},
route_trait::ArcRoute,
};
/// Map of known peers plus the registered routes used to pick a next hop.
pub struct PeerMap {
    global_ctx: ArcGlobalCtx,
    my_peer_id: PeerId,
    peer_map: DashMap<PeerId, Arc<Peer>>,
    // Loopback path: messages addressed to ourselves go here.
    packet_send: mpsc::Sender<Bytes>,
    // Newest-first route list (see `add_route`).
    routes: RwLock<Vec<ArcRoute>>,
}
impl PeerMap {
    /// Create an empty peer map.
    pub fn new(
        packet_send: mpsc::Sender<Bytes>,
        global_ctx: ArcGlobalCtx,
        my_peer_id: PeerId,
    ) -> Self {
        PeerMap {
            global_ctx,
            my_peer_id,
            peer_map: DashMap::new(),
            packet_send,
            routes: RwLock::new(Vec::new()),
        }
    }

    /// Insert a freshly created peer and emit a `PeerAdded` event.
    async fn add_new_peer(&self, peer: Peer) {
        let peer_id = peer.peer_node_id.clone();
        self.peer_map.insert(peer_id.clone(), Arc::new(peer));
        self.global_ctx
            .issue_event(GlobalCtxEvent::PeerAdded(peer_id));
    }

    /// Attach a connection to the matching peer, creating the peer entry if
    /// this is the first connection from that peer id.
    pub async fn add_new_peer_conn(&self, peer_conn: PeerConn) {
        let peer_id = peer_conn.get_peer_id();
        // Clone the Arc out of the map so no shard guard is held across await.
        if let Some(peer) = self.get_peer_by_id(peer_id) {
            peer.add_peer_conn(peer_conn).await;
        } else {
            let new_peer = Peer::new(peer_id, self.packet_send.clone(), self.global_ctx.clone());
            new_peer.add_peer_conn(peer_conn).await;
            self.add_new_peer(new_peer).await;
        }
    }

    fn get_peer_by_id(&self, peer_id: PeerId) -> Option<Arc<Peer>> {
        self.peer_map.get(&peer_id).map(|v| v.clone())
    }

    pub fn has_peer(&self, peer_id: PeerId) -> bool {
        self.peer_map.contains_key(&peer_id)
    }

    /// Send `msg` to a directly connected peer (no routing). Messages to
    /// ourselves are looped back through `packet_send`.
    pub async fn send_msg_directly(&self, msg: Bytes, dst_peer_id: PeerId) -> Result<(), Error> {
        if dst_peer_id == self.my_peer_id {
            return Ok(self
                .packet_send
                .send(msg)
                .await
                .with_context(|| "send msg to self failed")?);
        }

        match self.get_peer_by_id(dst_peer_id) {
            Some(peer) => {
                peer.send_msg(msg).await?;
            }
            None => {
                log::error!("no peer for dst_peer_id: {}", dst_peer_id);
                return Err(Error::RouteError(None));
            }
        }

        Ok(())
    }

    /// Route-aware send: ask each registered route (newest first) for a next
    /// hop; fall back to a direct connection when we have one.
    pub async fn send_msg(&self, msg: Bytes, dst_peer_id: PeerId) -> Result<(), Error> {
        if dst_peer_id == self.my_peer_id {
            return Ok(self
                .packet_send
                .send(msg)
                .await
                .with_context(|| "send msg to self failed")?);
        }

        // The first route that knows a next hop wins.
        let mut gateway_peer_id = None;
        for route in self.routes.read().await.iter() {
            gateway_peer_id = route.get_next_hop(dst_peer_id).await;
            if gateway_peer_id.is_some() {
                break;
            }
        }

        if gateway_peer_id.is_none() && self.has_peer(dst_peer_id) {
            gateway_peer_id = Some(dst_peer_id);
        }

        let Some(gateway_peer_id) = gateway_peer_id else {
            tracing::trace!(
                "no gateway for dst_peer_id: {}, peers: {:?}, my_peer_id: {}",
                dst_peer_id,
                self.peer_map.iter().map(|v| *v.key()).collect::<Vec<_>>(),
                self.my_peer_id
            );
            return Err(Error::RouteError(None));
        };
        self.send_msg_directly(msg, gateway_peer_id).await
    }

    /// First route that can resolve `ipv4` to a peer id wins.
    pub async fn get_peer_id_by_ipv4(&self, ipv4: &Ipv4Addr) -> Option<PeerId> {
        for route in self.routes.read().await.iter() {
            let peer_id = route.get_peer_id_by_ipv4(ipv4).await;
            if peer_id.is_some() {
                return peer_id;
            }
        }
        None
    }

    pub fn is_empty(&self) -> bool {
        self.peer_map.is_empty()
    }

    pub async fn list_peers(&self) -> Vec<PeerId> {
        self.peer_map.iter().map(|item| *item.key()).collect()
    }

    /// Peers that currently have at least one live connection.
    pub async fn list_peers_with_conn(&self) -> Vec<PeerId> {
        let mut ret = Vec::new();
        for peer_id in self.list_peers().await {
            let Some(peer) = self.get_peer_by_id(peer_id) else {
                continue;
            };
            if !peer.list_peer_conns().await.is_empty() {
                ret.push(peer_id);
            }
        }
        ret
    }

    /// Connection info for one peer, or `None` if the peer is unknown.
    pub async fn list_peer_conns(&self, peer_id: PeerId) -> Option<Vec<PeerConnInfo>> {
        match self.get_peer_by_id(peer_id) {
            Some(p) => Some(p.list_peer_conns().await),
            None => None,
        }
    }

    pub async fn close_peer_conn(
        &self,
        peer_id: PeerId,
        conn_id: &PeerConnId,
    ) -> Result<(), Error> {
        match self.get_peer_by_id(peer_id) {
            Some(p) => p.close_peer_conn(conn_id).await,
            None => Err(Error::NotFound),
        }
    }

    /// Remove a peer from the map and emit a `PeerRemoved` event. The event
    /// is emitted even when the peer was already gone.
    pub async fn close_peer(&self, peer_id: PeerId) -> Result<(), TunnelError> {
        let remove_ret = self.peer_map.remove(&peer_id);
        self.global_ctx
            .issue_event(GlobalCtxEvent::PeerRemoved(peer_id));
        tracing::info!(
            ?peer_id,
            has_old_value = ?remove_ret.is_some(),
            peer_ref_counter = ?remove_ret.map(|v| Arc::strong_count(&v.1)),
            "peer is closed"
        );
        Ok(())
    }

    /// Register a route implementation; newest routes take priority.
    pub async fn add_route(&self, route: ArcRoute) {
        self.routes.write().await.insert(0, route);
    }

    /// Evict peer entries that no longer have any live connection.
    pub async fn clean_peer_without_conn(&self) {
        let mut to_remove = vec![];
        for peer_id in self.list_peers().await {
            let conns = self.list_peer_conns(peer_id).await;
            if conns.map_or(true, |c| c.is_empty()) {
                to_remove.push(peer_id);
            }
        }
        for peer_id in to_remove {
            // `close_peer` currently always returns Ok.
            self.close_peer(peer_id).await.unwrap();
        }
    }

    /// Snapshot of (destination -> next hop) aggregated over all routes.
    pub async fn list_routes(&self) -> DashMap<PeerId, PeerId> {
        let route_map = DashMap::new();
        for route in self.routes.read().await.iter() {
            for item in route.list_routes().await.iter() {
                route_map.insert(item.peer_id, item.next_hop_peer_id);
            }
        }
        route_map
    }
}
File diff suppressed because it is too large Load Diff
+770
View File
@@ -0,0 +1,770 @@
use std::{
net::Ipv4Addr,
sync::{atomic::AtomicU32, Arc},
time::{Duration, Instant},
};
use async_trait::async_trait;
use dashmap::DashMap;
use tokio::{
sync::{Mutex, RwLock},
task::JoinSet,
};
use tokio_util::bytes::Bytes;
use tracing::Instrument;
use crate::{
common::{error::Error, global_ctx::ArcGlobalCtx, stun::StunInfoCollectorTrait, PeerId},
peers::{
packet,
route_trait::{Route, RouteInterfaceBox},
},
rpc::{NatType, StunInfo},
};
use super::{packet::CtrlPacketPayload, PeerPacketFilter};
const SEND_ROUTE_PERIOD_SEC: u64 = 60;
const SEND_ROUTE_FAST_REPLY_SEC: u64 = 5;
const ROUTE_EXPIRED_SEC: u64 = 70;
type Version = u32;
/// One entry exchanged in route sync packets: either a node's self
/// description (cost 0) or a route-table row.
#[derive(serde::Deserialize, serde::Serialize, Clone, Debug, PartialEq)]
// Derives can be passed through to the generated type:
pub struct SyncPeerInfo {
    // means next hop in route table.
    pub peer_id: PeerId,
    // Hop count; 0 for the node itself (see `new_self`).
    pub cost: u32,
    pub ipv4_addr: Option<Ipv4Addr>,
    // CIDRs this peer proxies (plus its VPN portal CIDR), as strings.
    pub proxy_cidrs: Vec<String>,
    pub hostname: Option<String>,
    // `NatType` of the peer's UDP STUN probe, stored as i8.
    pub udp_stun_info: i8,
}
impl SyncPeerInfo {
    /// Describe the local node (`cost` 0) using live data from the global
    /// context: ipv4, proxied CIDRs (incl. the VPN portal CIDR), hostname,
    /// and the current UDP NAT type.
    pub fn new_self(from_peer: PeerId, global_ctx: &ArcGlobalCtx) -> Self {
        SyncPeerInfo {
            peer_id: from_peer,
            cost: 0,
            ipv4_addr: global_ctx.get_ipv4(),
            proxy_cidrs: global_ctx
                .get_proxy_cidrs()
                .iter()
                .map(|x| x.to_string())
                .chain(global_ctx.get_vpn_portal_cidr().map(|x| x.to_string()))
                .collect(),
            hostname: global_ctx.get_hostname(),
            udp_stun_info: global_ctx
                .get_stun_info_collector()
                .get_stun_info()
                .udp_nat_type as i8,
        }
    }

    /// Build a route-table row pointing at `next_hop` with `cost`, copying
    /// the descriptive fields from `from`.
    ///
    /// NOTE: the returned `peer_id` field holds the NEXT HOP, not the
    /// destination — consistent with the struct's field comment.
    pub fn clone_for_route_table(&self, next_hop: PeerId, cost: u32, from: &Self) -> Self {
        SyncPeerInfo {
            peer_id: next_hop,
            cost,
            ipv4_addr: from.ipv4_addr.clone(),
            proxy_cidrs: from.proxy_cidrs.clone(),
            hostname: from.hostname.clone(),
            udp_stun_info: from.udp_stun_info,
        }
    }
}
/// A full route-sync packet: the sender's self description, its neighbor
/// table, and version counters used to suppress redundant replies.
#[derive(serde::Deserialize, serde::Serialize, Clone, Debug)]
pub struct SyncPeer {
    pub myself: SyncPeerInfo,
    pub neighbors: Vec<SyncPeerInfo>,
    // the route table version of myself
    pub version: Version,
    // the route table version of peer that we have received last time
    pub peer_version: Option<Version>,
    // if we do not have latest peer version, need_reply is true
    pub need_reply: bool,
}
impl SyncPeer {
    /// Assemble a sync packet; `myself` is rebuilt from the global context.
    /// `_to_peer` is unused today but kept for signature stability.
    pub fn new(
        from_peer: PeerId,
        _to_peer: PeerId,
        neighbors: Vec<SyncPeerInfo>,
        global_ctx: ArcGlobalCtx,
        version: Version,
        peer_version: Option<Version>,
        need_reply: bool,
    ) -> Self {
        SyncPeer {
            myself: SyncPeerInfo::new_self(from_peer, &global_ctx),
            neighbors,
            version,
            peer_version,
            need_reply,
        }
    }
}
/// A neighbor's most recent sync packet plus when we received it
/// (used for expiry against `ROUTE_EXPIRED_SEC`).
#[derive(Debug)]
struct SyncPeerFromRemote {
    packet: SyncPeer,
    last_update: std::time::Instant,
}
type SyncPeerFromRemoteMap = Arc<DashMap<PeerId, SyncPeerFromRemote>>;
/// The computed routing state: per-destination route rows plus the derived
/// ipv4 -> peer and proxy-CIDR -> peer lookup maps.
#[derive(Debug)]
struct RouteTable {
    route_info: DashMap<PeerId, SyncPeerInfo>,
    ipv4_peer_id_map: DashMap<Ipv4Addr, PeerId>,
    cidr_peer_id_map: DashMap<cidr::IpCidr, PeerId>,
}
impl RouteTable {
    /// Create an empty route table.
    fn new() -> Self {
        Self {
            route_info: DashMap::new(),
            ipv4_peer_id_map: DashMap::new(),
            cidr_peer_id_map: DashMap::new(),
        }
    }

    /// Replace this table's contents with a copy of `other`'s.
    ///
    /// NOTE(review): clear-then-insert is not atomic; concurrent readers may
    /// briefly observe a partially repopulated table.
    fn copy_from(&self, other: &Self) {
        self.route_info.clear();
        for entry in other.route_info.iter() {
            self.route_info.insert(*entry.key(), entry.value().clone());
        }

        self.ipv4_peer_id_map.clear();
        for entry in other.ipv4_peer_id_map.iter() {
            self.ipv4_peer_id_map.insert(*entry.key(), *entry.value());
        }

        self.cidr_peer_id_map.clear();
        for entry in other.cidr_peer_id_map.iter() {
            self.cidr_peer_id_map.insert(*entry.key(), *entry.value());
        }
    }
}
/// Monotonically increasing route-table version counter, cheaply cloneable
/// via the shared `Arc`.
#[derive(Debug, Clone)]
struct RouteVersion(Arc<AtomicU32>);
impl RouteVersion {
fn new() -> Self {
// RouteVersion(Arc::new(AtomicU32::new(rand::random())))
RouteVersion(Arc::new(AtomicU32::new(0)))
}
fn get(&self) -> Version {
self.0.load(std::sync::atomic::Ordering::Relaxed)
}
fn inc(&self) {
self.0.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
}
}
/// RIP-like distance-vector route implementation.
pub struct BasicRoute {
    my_peer_id: PeerId,
    global_ctx: ArcGlobalCtx,
    // Set when the route is opened; used to list neighbors and send packets.
    interface: Arc<Mutex<Option<RouteInterfaceBox>>>,
    route_table: Arc<RouteTable>,
    // Latest sync packet per neighbor, with receive time for expiry.
    sync_peer_from_remote: SyncPeerFromRemoteMap,
    tasks: Mutex<JoinSet<()>>,
    need_sync_notifier: Arc<tokio::sync::Notify>,
    // Bumped when our own info changes (see `update_myself`'s caller).
    version: RouteVersion,
    // Cached self description, compared against fresh data each cycle.
    myself: Arc<RwLock<SyncPeerInfo>>,
    // Per-peer (our version, their version, timestamp) of the last send.
    last_send_time_map: Arc<DashMap<PeerId, (Version, Option<Version>, Instant)>>,
}
impl BasicRoute {
/// Construct a RIP-style route instance for `my_peer_id`; the interface is
/// filled in later when the route is opened.
pub fn new(my_peer_id: PeerId, global_ctx: ArcGlobalCtx) -> Self {
    BasicRoute {
        my_peer_id,
        global_ctx: global_ctx.clone(),
        interface: Arc::new(Mutex::new(None)),
        route_table: Arc::new(RouteTable::new()),
        sync_peer_from_remote: Arc::new(DashMap::new()),
        tasks: Mutex::new(JoinSet::new()),
        need_sync_notifier: Arc::new(tokio::sync::Notify::new()),
        version: RouteVersion::new(),
        myself: Arc::new(RwLock::new(SyncPeerInfo::new_self(
            // NOTE(review): `.into()` looks like a no-op PeerId conversion.
            my_peer_id.into(),
            &global_ctx,
        ))),
        last_send_time_map: Arc::new(DashMap::new()),
    }
}
/// Rebuild the route table from scratch out of all neighbors' latest sync
/// packets, then swap the result into the shared table via `copy_from`.
fn update_route_table(
    my_id: PeerId,
    sync_peer_reqs: SyncPeerFromRemoteMap,
    route_table: Arc<RouteTable>,
) {
    tracing::trace!(my_id = ?my_id, route_table = ?route_table, "update route table");

    let new_route_table = Arc::new(RouteTable::new());
    for item in sync_peer_reqs.iter() {
        Self::update_route_table_with_req(my_id, &item.value().packet, new_route_table.clone());
    }

    route_table.copy_from(&new_route_table);
}
/// Refresh the cached local `SyncPeerInfo` from the global context; returns
/// true when it changed (the caller bumps the route version in that case).
///
/// NOTE(review): the read-compare then write is not atomic; concurrent
/// callers could both observe "changed". Callers appear to be a single
/// periodic task — confirm if that ever changes.
async fn update_myself(
    my_peer_id: PeerId,
    myself: &Arc<RwLock<SyncPeerInfo>>,
    global_ctx: &ArcGlobalCtx,
) -> bool {
    let new_myself = SyncPeerInfo::new_self(my_peer_id, global_ctx);
    if *myself.read().await != new_myself {
        *myself.write().await = new_myself;
        true
    } else {
        false
    }
}
/// Merge one neighbor's sync packet into `route_table`.
///
/// Every announced entry is inserted with the sender as its next hop; an
/// existing entry is replaced only when the new cost is strictly lower.
/// Entries whose resulting cost exceeds 6 are treated as unreachable
/// (RIP-style cost horizon) and removed again.
fn update_route_table_with_req(my_id: PeerId, packet: &SyncPeer, route_table: Arc<RouteTable>) {
    // The sender becomes the gateway for everything it announced.
    let peer_id = packet.myself.peer_id.clone();
    let update = |cost: u32, peer_info: &SyncPeerInfo| {
        let node_id: PeerId = peer_info.peer_id.into();
        let ret = route_table
            .route_info
            .entry(node_id.clone().into())
            .and_modify(|info| {
                if info.cost > cost {
                    *info = info.clone_for_route_table(peer_id, cost, &peer_info);
                }
            })
            .or_insert(
                peer_info
                    .clone()
                    .clone_for_route_table(peer_id, cost, &peer_info),
            )
            .value()
            .clone();
        if ret.cost > 6 {
            log::error!(
                "cost too large: {}, may lost connection, remove it",
                ret.cost
            );
            route_table.route_info.remove(&node_id);
        }
        log::trace!(
            "update route info, to: {:?}, gateway: {:?}, cost: {}, peer: {:?}",
            node_id,
            peer_id,
            cost,
            &peer_info
        );
        // Keep the ipv4 -> peer and proxy-CIDR -> peer lookup maps in sync.
        if let Some(ipv4) = peer_info.ipv4_addr {
            route_table
                .ipv4_peer_id_map
                .insert(ipv4.clone(), node_id.clone().into());
        }
        for cidr in peer_info.proxy_cidrs.iter() {
            // NOTE(review): `.parse().unwrap()` panics on a malformed CIDR
            // string received from a remote peer — confirm inputs are
            // validated before they reach this point.
            let cidr: cidr::IpCidr = cidr.parse().unwrap();
            route_table
                .cidr_peer_id_map
                .insert(cidr, node_id.clone().into());
        }
    };

    for neighbor in packet.neighbors.iter() {
        if neighbor.peer_id == my_id {
            continue;
        }
        // One extra hop to reach the neighbor's entries through the sender.
        update(neighbor.cost + 1, &neighbor);
        log::trace!("route info: {:?}", neighbor);
    }

    // add the sender peer to route info
    update(1, &packet.myself);

    log::trace!("my_id: {:?}, current route table: {:?}", my_id, route_table);
}
async fn send_sync_peer_request(
interface: &RouteInterfaceBox,
my_peer_id: PeerId,
global_ctx: ArcGlobalCtx,
peer_id: PeerId,
route_table: Arc<RouteTable>,
my_version: Version,
peer_version: Option<Version>,
need_reply: bool,
) -> Result<(), Error> {
let mut route_info_copy: Vec<SyncPeerInfo> = Vec::new();
// copy the route info
for item in route_table.route_info.iter() {
let (k, v) = item.pair();
route_info_copy.push(v.clone().clone_for_route_table(*k, v.cost, &v));
}
let msg = SyncPeer::new(
my_peer_id,
peer_id,
route_info_copy,
global_ctx,
my_version,
peer_version,
need_reply,
);
// TODO: this may exceed the MTU of the tunnel
interface
.send_route_packet(postcard::to_allocvec(&msg).unwrap().into(), 1, peer_id)
.await
}
    /// Spawn the background task that advertises our route table to every
    /// directly connected peer, either on a timer or when explicitly notified.
    async fn sync_peer_periodically(&self) {
        let route_table = self.route_table.clone();
        let global_ctx = self.global_ctx.clone();
        let my_peer_id = self.my_peer_id.clone();
        let interface = self.interface.clone();
        let notifier = self.need_sync_notifier.clone();
        let sync_peer_from_remote = self.sync_peer_from_remote.clone();
        let myself = self.myself.clone();
        let version = self.version.clone();
        let last_send_time_map = self.last_send_time_map.clone();
        self.tasks.lock().await.spawn(
            async move {
                loop {
                    // Bump the table version whenever our own info changed.
                    if Self::update_myself(my_peer_id,&myself, &global_ctx).await {
                        version.inc();
                        tracing::info!(
                            my_id = ?my_peer_id,
                            version = version.get(),
                            "update route table version when myself changed"
                        );
                    }
                    let lockd_interface = interface.lock().await;
                    let interface = lockd_interface.as_ref().unwrap();
                    // Rebuilt every round so peers that disappeared do not
                    // linger in the send-time bookkeeping.
                    let last_send_time_map_new = DashMap::new();
                    let peers = interface.list_peers().await;
                    for peer in peers.iter() {
                        // Tuple: (version we last sent, peer version we had seen,
                        // send instant). The one-hour-old fallback instant forces
                        // an immediate send for peers we have never synced with.
                        let last_send_time = last_send_time_map.get(peer).map(|v| *v).unwrap_or((0, None, Instant::now() - Duration::from_secs(3600)));
                        let my_version_peer_saved = sync_peer_from_remote.get(peer).and_then(|v| v.packet.peer_version);
                        let peer_have_latest_version = my_version_peer_saved == Some(version.get());
                        // Skip peers that already acknowledged our latest version
                        // and were refreshed within the send period.
                        if peer_have_latest_version && last_send_time.2.elapsed().as_secs() < SEND_ROUTE_PERIOD_SEC {
                            last_send_time_map_new.insert(*peer, last_send_time);
                            continue;
                        }
                        tracing::trace!(
                            my_id = ?my_peer_id,
                            dst_peer_id = ?peer,
                            version = version.get(),
                            ?my_version_peer_saved,
                            last_send_version = ?last_send_time.0,
                            last_send_peer_version = ?last_send_time.1,
                            last_send_elapse = ?last_send_time.2.elapsed().as_secs(),
                            "need send route info"
                        );
                        let peer_version_we_saved = sync_peer_from_remote.get(&peer).and_then(|v| Some(v.packet.version));
                        last_send_time_map_new.insert(*peer, (version.get(), peer_version_we_saved, Instant::now()));
                        // Ask for a reply only if the peer does not yet have our
                        // latest version.
                        let ret = Self::send_sync_peer_request(
                            interface,
                            my_peer_id.clone(),
                            global_ctx.clone(),
                            *peer,
                            route_table.clone(),
                            version.get(),
                            peer_version_we_saved,
                            !peer_have_latest_version,
                        )
                        .await;
                        match &ret {
                            Ok(_) => {
                                log::trace!("send sync peer request to peer: {}", peer);
                            }
                            Err(Error::PeerNoConnectionError(_)) => {
                                log::trace!("peer {} no connection", peer);
                            }
                            Err(e) => {
                                log::error!(
                                    "send sync peer request to peer: {} error: {:?}",
                                    peer,
                                    e
                                );
                            }
                        };
                    }
                    // Replace the old bookkeeping map with the freshly built one.
                    last_send_time_map.clear();
                    for item in last_send_time_map_new.iter() {
                        let (k, v) = item.pair();
                        last_send_time_map.insert(*k, *v);
                    }
                    // Wake up on explicit request or at least once per second.
                    tokio::select! {
                        _ = notifier.notified() => {
                            log::trace!("sync peer request triggered by notifier");
                        }
                        _ = tokio::time::sleep(Duration::from_secs(1)) => {
                            log::trace!("sync peer request triggered by timeout");
                        }
                    }
                }
            }
            .instrument(
                tracing::info_span!("sync_peer_periodically", my_id = ?self.my_peer_id, global_ctx = ?self.global_ctx),
            ),
        );
    }
async fn check_expired_sync_peer_from_remote(&self) {
let route_table = self.route_table.clone();
let my_peer_id = self.my_peer_id.clone();
let sync_peer_from_remote = self.sync_peer_from_remote.clone();
let notifier = self.need_sync_notifier.clone();
let interface = self.interface.clone();
let version = self.version.clone();
self.tasks.lock().await.spawn(async move {
loop {
let mut need_update_route = false;
let now = std::time::Instant::now();
let mut need_remove = Vec::new();
let connected_peers = interface.lock().await.as_ref().unwrap().list_peers().await;
for item in sync_peer_from_remote.iter() {
let (k, v) = item.pair();
if now.duration_since(v.last_update).as_secs() > ROUTE_EXPIRED_SEC
|| !connected_peers.contains(k)
{
need_update_route = true;
need_remove.insert(0, k.clone());
}
}
for k in need_remove.iter() {
log::warn!("remove expired sync peer: {:?}", k);
sync_peer_from_remote.remove(k);
}
if need_update_route {
Self::update_route_table(
my_peer_id,
sync_peer_from_remote.clone(),
route_table.clone(),
);
version.inc();
tracing::info!(
my_id = ?my_peer_id,
version = version.get(),
"update route table when check expired peer"
);
notifier.notify_one();
}
tokio::time::sleep(Duration::from_secs(1)).await;
}
});
}
fn get_peer_id_for_proxy(&self, ipv4: &Ipv4Addr) -> Option<PeerId> {
let ipv4 = std::net::IpAddr::V4(*ipv4);
for item in self.route_table.cidr_peer_id_map.iter() {
let (k, v) = item.pair();
if k.contains(&ipv4) {
return Some(*v);
}
}
None
}
    /// Handle a route-sync packet received from `src_peer_id`: store it,
    /// rebuild the route table if the advertised routes changed, and arrange
    /// a prompt reply when the sender requested one.
    #[tracing::instrument(skip(self, packet), fields(my_id = ?self.my_peer_id, ctx = ?self.global_ctx))]
    async fn handle_route_packet(&self, src_peer_id: PeerId, packet: Bytes) {
        // NOTE(review): a malformed packet panics here — confirm peers are trusted.
        let packet = postcard::from_bytes::<SyncPeer>(&packet).unwrap();
        let p = &packet;
        let mut updated = true;
        assert_eq!(packet.myself.peer_id, src_peer_id);
        self.sync_peer_from_remote
            .entry(packet.myself.peer_id.into())
            .and_modify(|v| {
                // Only a change in the advertised routes counts as an update;
                // version counters and the freshness timestamp are always
                // refreshed regardless.
                if v.packet.myself == p.myself && v.packet.neighbors == p.neighbors {
                    updated = false;
                } else {
                    v.packet = p.clone();
                }
                v.packet.version = p.version;
                v.packet.peer_version = p.peer_version;
                v.last_update = std::time::Instant::now();
            })
            .or_insert(SyncPeerFromRemote {
                packet: p.clone(),
                last_update: std::time::Instant::now(),
            });
        if updated {
            Self::update_route_table(
                self.my_peer_id.clone(),
                self.sync_peer_from_remote.clone(),
                self.route_table.clone(),
            );
            self.version.inc();
            tracing::info!(
                my_id = ?self.my_peer_id,
                ?p,
                version = self.version.get(),
                "update route table when receive route packet"
            );
        }
        if packet.need_reply {
            // Backdate our recorded send instant so the periodic sync task
            // replies quickly instead of waiting out the full send period.
            self.last_send_time_map
                .entry(packet.myself.peer_id.into())
                .and_modify(|v| {
                    const FAST_REPLY_DURATION: u64 =
                        SEND_ROUTE_PERIOD_SEC - SEND_ROUTE_FAST_REPLY_SEC;
                    if v.0 != self.version.get() || v.1 != Some(p.version) {
                        v.2 = Instant::now() - Duration::from_secs(3600);
                    } else if v.2.elapsed().as_secs() < FAST_REPLY_DURATION {
                        // do not send same version route info too frequently
                        v.2 = Instant::now() - Duration::from_secs(FAST_REPLY_DURATION);
                    }
                });
        }
        if updated || packet.need_reply {
            self.need_sync_notifier.notify_one();
        }
    }
}
#[async_trait]
impl Route for BasicRoute {
    /// Attach to the peer interface and start the background tasks
    /// (periodic sync and expiry checking).
    async fn open(&self, interface: RouteInterfaceBox) -> Result<u8, ()> {
        *self.interface.lock().await = Some(interface);
        self.sync_peer_periodically().await;
        self.check_expired_sync_peer_from_remote().await;
        Ok(1)
    }
    async fn close(&self) {}
    /// Next hop toward `dst_peer_id`, or `None` (with an error log) when no
    /// route is known.
    async fn get_next_hop(&self, dst_peer_id: PeerId) -> Option<PeerId> {
        let Some(info) = self.route_table.route_info.get(&dst_peer_id) else {
            log::error!("no route info for dst_peer_id: {}", dst_peer_id);
            return None;
        };
        Some(info.peer_id.clone().into())
    }
    /// Render every known route into its RPC representation.
    async fn list_routes(&self) -> Vec<crate::rpc::Route> {
        let build_route = |real_peer_id: PeerId, route_info: &SyncPeerInfo| {
            let mut stun_info = StunInfo::default();
            if let Ok(udp_nat_type) = NatType::try_from(route_info.udp_stun_info as i32) {
                stun_info.set_udp_nat_type(udp_nat_type);
            }
            let mut route = crate::rpc::Route::default();
            route.ipv4_addr = route_info
                .ipv4_addr
                .map(|addr| addr.to_string())
                .unwrap_or_default();
            route.peer_id = real_peer_id;
            route.next_hop_peer_id = route_info.peer_id;
            route.cost = route_info.cost as i32;
            route.proxy_cidrs = route_info.proxy_cidrs.clone();
            route.hostname = route_info.hostname.clone().unwrap_or_default();
            route.stun_info = Some(stun_info);
            route
        };
        self.route_table
            .route_info
            .iter()
            .map(|item| build_route(*item.key(), item.value()))
            .collect()
    }
    /// Resolve a virtual IPv4 to its owning peer, falling back to proxy CIDRs.
    async fn get_peer_id_by_ipv4(&self, ipv4_addr: &Ipv4Addr) -> Option<PeerId> {
        if let Some(peer_id) = self.route_table.ipv4_peer_id_map.get(ipv4_addr) {
            return Some(*peer_id);
        }
        if let Some(peer_id) = self.get_peer_id_for_proxy(ipv4_addr) {
            return Some(peer_id);
        }
        log::info!("no peer id for ipv4: {}", ipv4_addr);
        None
    }
}
#[async_trait::async_trait]
impl PeerPacketFilter for BasicRoute {
    /// Consume route-sync packets addressed to this route instance.
    ///
    /// Returns `Some(())` when the packet was a route packet and has been
    /// handled, `None` so other filters get a chance at it.
    async fn try_process_packet_from_peer(
        &self,
        packet: &packet::ArchivedPacket,
        _data: &Bytes,
    ) -> Option<()> {
        if packet.packet_type != packet::PacketType::RoutePacket {
            return None;
        }
        let payload = CtrlPacketPayload::from_packet(packet);
        let CtrlPacketPayload::RoutePacket(route_packet) = payload else {
            return None;
        };
        let body: Bytes = route_packet.body.into_boxed_slice().into();
        self.handle_route_packet(packet.from_peer.into(), body).await;
        Some(())
    }
}
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use crate::{
        common::{global_ctx::tests::get_mock_global_ctx, PeerId},
        connector::udp_hole_punch::tests::replace_stun_info_collector,
        peers::{
            peer_manager::{PeerManager, RouteAlgoType},
            peer_rip_route::Version,
            tests::{connect_peer_manager, wait_route_appear},
        },
        rpc::NatType,
    };
    /// Build a RIP-routed peer manager with a mocked global context and a
    /// stubbed STUN info collector.
    async fn create_mock_pmgr() -> Arc<PeerManager> {
        let (s, _r) = tokio::sync::mpsc::channel(1000);
        let peer_mgr = Arc::new(PeerManager::new(
            RouteAlgoType::Rip,
            get_mock_global_ctx(),
            s,
        ));
        replace_stun_info_collector(peer_mgr.clone(), NatType::Unknown);
        peer_mgr.run().await.unwrap();
        peer_mgr
    }
    /// Topology a <-> b <-> c: after routes propagate, every manager must hold
    /// consistent version info for its peers and the versions must stay stable
    /// (no endless re-sync churn).
    #[tokio::test]
    async fn test_rip_route() {
        let peer_mgr_a = create_mock_pmgr().await;
        let peer_mgr_b = create_mock_pmgr().await;
        let peer_mgr_c = create_mock_pmgr().await;
        connect_peer_manager(peer_mgr_a.clone(), peer_mgr_b.clone()).await;
        connect_peer_manager(peer_mgr_b.clone(), peer_mgr_c.clone()).await;
        wait_route_appear(peer_mgr_a.clone(), peer_mgr_b.clone())
            .await
            .unwrap();
        wait_route_appear(peer_mgr_a.clone(), peer_mgr_c.clone())
            .await
            .unwrap();
        let mgrs = vec![peer_mgr_a.clone(), peer_mgr_b.clone(), peer_mgr_c.clone()];
        tokio::time::sleep(tokio::time::Duration::from_secs(4)).await;
        // Assert every listed manager saw `version` from `peer_id` and that the
        // acknowledged peer_version matches the local table version.
        let check_version = |version: Version, peer_id: PeerId, mgrs: &Vec<Arc<PeerManager>>| {
            for mgr in mgrs.iter() {
                tracing::warn!(
                    "check version: {:?}, {:?}, {:?}, {:?}",
                    version,
                    peer_id,
                    mgr,
                    mgr.get_basic_route().sync_peer_from_remote
                );
                assert_eq!(
                    version,
                    mgr.get_basic_route()
                        .sync_peer_from_remote
                        .get(&peer_id)
                        .unwrap()
                        .packet
                        .version,
                );
                assert_eq!(
                    mgr.get_basic_route()
                        .sync_peer_from_remote
                        .get(&peer_id)
                        .unwrap()
                        .packet
                        .peer_version
                        .unwrap(),
                    mgr.get_basic_route().version.get()
                );
            }
        };
        let check_sanity = || {
            // check peer version in other peer mgr are correct.
            check_version(
                peer_mgr_b.get_basic_route().version.get(),
                peer_mgr_b.my_peer_id(),
                &vec![peer_mgr_a.clone(), peer_mgr_c.clone()],
            );
            check_version(
                peer_mgr_a.get_basic_route().version.get(),
                peer_mgr_a.my_peer_id(),
                &vec![peer_mgr_b.clone()],
            );
            check_version(
                peer_mgr_c.get_basic_route().version.get(),
                peer_mgr_c.my_peer_id(),
                &vec![peer_mgr_b.clone()],
            );
        };
        check_sanity();
        // Versions must not change while the topology is idle.
        let versions = mgrs
            .iter()
            .map(|x| x.get_basic_route().version.get())
            .collect::<Vec<_>>();
        tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
        let versions2 = mgrs
            .iter()
            .map(|x| x.get_basic_route().version.get())
            .collect::<Vec<_>>();
        assert_eq!(versions, versions2);
        check_sanity();
        // Upper bounds on how many version bumps convergence may take.
        assert!(peer_mgr_a.get_basic_route().version.get() <= 3);
        assert!(peer_mgr_b.get_basic_route().version.get() <= 6);
        assert!(peer_mgr_c.get_basic_route().version.get() <= 3);
    }
}
+581
View File
@@ -0,0 +1,581 @@
use std::sync::{atomic::AtomicU32, Arc};
use dashmap::DashMap;
use futures::{SinkExt, StreamExt};
use rkyv::Deserialize;
use tarpc::{server::Channel, transport::channel::UnboundedChannel};
use tokio::{
sync::mpsc::{self, UnboundedSender},
task::JoinSet,
};
use tokio_util::bytes::Bytes;
use tracing::Instrument;
use crate::{
common::{error::Error, PeerId},
peers::packet::Packet,
};
use super::packet::CtrlPacketPayload;
/// Identifies a registered RPC service on a peer.
type PeerRpcServiceId = u32;
/// Correlates an in-flight RPC request with its response.
type PeerRpcTransactId = u32;
/// Transport abstraction the RPC manager uses to exchange serialized RPC
/// packets with remote peers.
#[async_trait::async_trait]
#[auto_impl::auto_impl(Arc)]
pub trait PeerRpcManagerTransport: Send + Sync + 'static {
    /// Peer id of the local node.
    fn my_peer_id(&self) -> PeerId;
    /// Deliver a serialized RPC packet to `dst_peer_id`.
    async fn send(&self, msg: Bytes, dst_peer_id: PeerId) -> Result<(), Error>;
    /// Receive the next serialized RPC packet addressed to this node.
    async fn recv(&self) -> Result<Bytes, Error>;
}
/// Channel used to feed raw packets into a service/client pump task.
type PacketSender = UnboundedSender<Packet>;
/// Server-side endpoint for one (peer, service) pair: the channel feeding it
/// and the tasks pumping requests/responses through tarpc.
struct PeerRpcEndPoint {
    peer_id: PeerId,
    packet_sender: PacketSender,
    tasks: JoinSet<()>,
}
/// Factory that lazily builds an endpoint the first time a peer calls a service.
type PeerRpcEndPointCreator = Box<dyn Fn(PeerId) -> PeerRpcEndPoint + Send + Sync + 'static>;
/// Key identifying one in-flight client call: (dst peer, service, transaction).
#[derive(Hash, Eq, PartialEq, Clone)]
struct PeerRpcClientCtxKey(PeerId, PeerRpcServiceId, PeerRpcTransactId);
/// Multiplexes tarpc services and clients over a peer-to-peer transport,
/// handling RPC requests from and to remote peers.
pub struct PeerRpcManager {
    service_map: Arc<DashMap<PeerRpcServiceId, PacketSender>>,
    tasks: JoinSet<()>,
    tspt: Arc<Box<dyn PeerRpcManagerTransport>>,
    // Registered service factories, keyed by service id.
    service_registry: Arc<DashMap<PeerRpcServiceId, PeerRpcEndPointCreator>>,
    // Lazily created server endpoints, one per (caller peer, service).
    peer_rpc_endpoints: Arc<DashMap<(PeerId, PeerRpcServiceId), PeerRpcEndPoint>>,
    // Routes response packets back to the waiting client call.
    client_resp_receivers: Arc<DashMap<PeerRpcClientCtxKey, PacketSender>>,
    // Monotonic counter used to tag client transactions.
    transact_id: AtomicU32,
}
impl std::fmt::Debug for PeerRpcManager {
    // Keep Debug output minimal: only the local node id is shown.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("PeerRpcManager")
            .field("node_id", &self.tspt.my_peer_id())
            .finish()
    }
}
/// Parsed form of a tarpc control packet exchanged between peers.
#[derive(Debug)]
struct TaRpcPacketInfo {
    from_peer: PeerId,
    to_peer: PeerId,
    service_id: PeerRpcServiceId,
    transact_id: PeerRpcTransactId,
    // true for a request, false for a response
    is_req: bool,
    // postcard-serialized tarpc message
    content: Vec<u8>,
}
impl PeerRpcManager {
    /// Create a manager bound to the given transport. Call [`Self::run`] to
    /// start dispatching incoming packets.
    pub fn new(tspt: impl PeerRpcManagerTransport) -> Self {
        Self {
            service_map: Arc::new(DashMap::new()),
            tasks: JoinSet::new(),
            tspt: Arc::new(Box::new(tspt)),
            service_registry: Arc::new(DashMap::new()),
            peer_rpc_endpoints: Arc::new(DashMap::new()),
            client_resp_receivers: Arc::new(DashMap::new()),
            transact_id: AtomicU32::new(0),
        }
    }
    /// Register a tarpc service under `service_id`.
    ///
    /// What is actually stored is a factory closure; a dedicated endpoint
    /// (tarpc channel plus pump task) is created lazily for each calling peer
    /// by [`Self::run`]. Panics if the service id is already registered.
    pub fn run_service<S, Req>(self: &Self, service_id: PeerRpcServiceId, s: S) -> ()
    where
        S: tarpc::server::Serve<Req> + Clone + Send + Sync + 'static,
        Req: Send + 'static + serde::Serialize + for<'a> serde::Deserialize<'a>,
        S::Resp:
            Send + std::fmt::Debug + 'static + serde::Serialize + for<'a> serde::Deserialize<'a>,
        S::Fut: Send + 'static,
    {
        let tspt = self.tspt.clone();
        let creator = Box::new(move |peer_id: PeerId| {
            let mut tasks = JoinSet::new();
            let (packet_sender, mut packet_receiver) = mpsc::unbounded_channel::<Packet>();
            let (mut client_transport, server_transport) = tarpc::transport::channel::unbounded();
            let server = tarpc::server::BaseChannel::with_defaults(server_transport);
            let my_peer_id_clone = tspt.my_peer_id();
            let peer_id_clone = peer_id.clone();
            let o = server.execute(s.clone());
            tasks.spawn(o);
            let tspt = tspt.clone();
            // Pump task: forwards incoming request packets into the tarpc
            // server channel and sends its responses back over the transport.
            // Only one request per peer is tracked at a time via
            // `cur_req_peer_id` / `cur_transact_id`.
            tasks.spawn(async move {
                let mut cur_req_peer_id = None;
                let mut cur_transact_id = 0;
                loop {
                    tokio::select! {
                        // A response produced by the tarpc server.
                        Some(resp) = client_transport.next() => {
                            let Some(cur_req_peer_id) = cur_req_peer_id.take() else {
                                tracing::error!("[PEER RPC MGR] cur_req_peer_id is none, ignore this resp");
                                continue;
                            };
                            tracing::trace!(resp = ?resp, "recv packet from client");
                            if resp.is_err() {
                                tracing::warn!(err = ?resp.err(),
                                    "[PEER RPC MGR] client_transport in server side got channel error, ignore it.");
                                continue;
                            }
                            let resp = resp.unwrap();
                            let serialized_resp = postcard::to_allocvec(&resp);
                            if serialized_resp.is_err() {
                                tracing::error!(error = ?serialized_resp.err(), "serialize resp failed");
                                continue;
                            }
                            let msg = Packet::new_tarpc_packet(
                                tspt.my_peer_id(),
                                cur_req_peer_id,
                                service_id,
                                cur_transact_id,
                                false,
                                serialized_resp.unwrap(),
                            );
                            if let Err(e) = tspt.send(msg.into(), peer_id).await {
                                tracing::error!(error = ?e, peer_id = ?peer_id, service_id = ?service_id, "send resp to peer failed");
                            }
                        }
                        // A request packet routed to this endpoint.
                        Some(packet) = packet_receiver.recv() => {
                            let info = Self::parse_rpc_packet(&packet);
                            if let Err(e) = info {
                                tracing::error!(error = ?e, packet = ?packet, "parse rpc packet failed");
                                continue;
                            }
                            let info = info.unwrap();
                            if info.from_peer != peer_id {
                                tracing::warn!("recv packet from peer, but peer_id not match, ignore it");
                                continue;
                            }
                            // A previous request is still pending; drop this one.
                            if cur_req_peer_id.is_some() {
                                tracing::warn!("cur_req_peer_id is not none, ignore this packet");
                                continue;
                            }
                            assert_eq!(info.service_id, service_id);
                            cur_req_peer_id = Some(packet.from_peer.clone().into());
                            cur_transact_id = info.transact_id;
                            tracing::trace!("recv packet from peer, packet: {:?}", packet);
                            let decoded_ret = postcard::from_bytes(&info.content.as_slice());
                            if let Err(e) = decoded_ret {
                                tracing::error!(error = ?e, "decode rpc packet failed");
                                continue;
                            }
                            let decoded: tarpc::ClientMessage<Req> = decoded_ret.unwrap();
                            if let Err(e) = client_transport.send(decoded).await {
                                tracing::error!(error = ?e, "send to req to client transport failed");
                            }
                        }
                        else => {
                            tracing::warn!("[PEER RPC MGR] service runner destroy, peer_id: {}, service_id: {}", peer_id, service_id);
                        }
                    }
                }
            }.instrument(tracing::info_span!("service_runner", my_id = ?my_peer_id_clone, peer_id = ?peer_id_clone, service_id = ?service_id)));
            tracing::info!(
                "[PEER RPC MGR] create new service endpoint for peer {}, service {}",
                peer_id,
                service_id
            );
            return PeerRpcEndPoint {
                peer_id,
                packet_sender,
                tasks,
            };
        });
        if let Some(_) = self.service_registry.insert(service_id, creator) {
            panic!(
                "[PEER RPC MGR] service {} is already registered",
                service_id
            );
        }
        log::info!(
            "[PEER RPC MGR] register service {} succeed, my_node_id {}",
            service_id,
            self.tspt.my_peer_id()
        )
    }
    /// Decode the tarpc control payload of `packet` into a [`TaRpcPacketInfo`].
    /// Returns an error for any non-tarpc payload.
    fn parse_rpc_packet(packet: &Packet) -> Result<TaRpcPacketInfo, Error> {
        let ctrl_packet_payload = CtrlPacketPayload::from_packet2(&packet);
        match &ctrl_packet_payload {
            CtrlPacketPayload::TaRpc(id, tid, is_req, body) => Ok(TaRpcPacketInfo {
                from_peer: packet.from_peer.into(),
                to_peer: packet.to_peer.into(),
                service_id: *id,
                transact_id: *tid,
                is_req: *is_req,
                content: body.clone(),
            }),
            _ => Err(Error::ShellCommandError("invalid packet".to_owned())),
        }
    }
    /// Start the dispatch loop: requests go to the (lazily created) service
    /// endpoint for the calling peer, responses go back to the waiting client
    /// call identified by (peer, service, transaction).
    pub fn run(&self) {
        let tspt = self.tspt.clone();
        let service_registry = self.service_registry.clone();
        let peer_rpc_endpoints = self.peer_rpc_endpoints.clone();
        let client_resp_receivers = self.client_resp_receivers.clone();
        tokio::spawn(async move {
            loop {
                let Ok(o) = tspt.recv().await else {
                    tracing::warn!("peer rpc transport read aborted, exiting");
                    break;
                };
                let packet = Packet::decode(&o);
                let packet: Packet = packet.deserialize(&mut rkyv::Infallible).unwrap();
                // NOTE(review): a non-tarpc packet panics the loop here — confirm
                // only RPC traffic can reach this transport.
                let info = Self::parse_rpc_packet(&packet).unwrap();
                if info.is_req {
                    if !service_registry.contains_key(&info.service_id) {
                        log::warn!(
                            "service {} not found, my_node_id: {}",
                            info.service_id,
                            tspt.my_peer_id()
                        );
                        continue;
                    }
                    // One endpoint per (caller, service); created on first use.
                    let endpoint = peer_rpc_endpoints
                        .entry((info.from_peer, info.service_id))
                        .or_insert_with(|| {
                            service_registry.get(&info.service_id).unwrap()(info.from_peer)
                        });
                    endpoint.packet_sender.send(packet).unwrap();
                } else {
                    if let Some(a) = client_resp_receivers.get(&PeerRpcClientCtxKey(
                        info.from_peer,
                        info.service_id,
                        info.transact_id,
                    )) {
                        log::trace!("recv resp: {:?}", packet);
                        if let Err(e) = a.send(packet) {
                            tracing::error!(error = ?e, "send resp to client failed");
                        }
                    } else {
                        log::warn!("client resp receiver not found, info: {:?}", info);
                    }
                }
            }
        });
    }
    /// Run one scoped client RPC against `service_id` on `dst_peer_id`.
    ///
    /// Builds an in-memory tarpc client channel, bridges it over the peer
    /// transport for the duration of `f`, and tears the bridge down before
    /// returning `f`'s result.
    #[tracing::instrument(skip(f))]
    pub async fn do_client_rpc_scoped<CM, Req, RpcRet, Fut>(
        &self,
        service_id: PeerRpcServiceId,
        dst_peer_id: PeerId,
        f: impl FnOnce(UnboundedChannel<CM, Req>) -> Fut,
    ) -> RpcRet
    where
        CM: serde::Serialize + for<'a> serde::Deserialize<'a> + Send + Sync + 'static,
        Req: serde::Serialize + for<'a> serde::Deserialize<'a> + Send + Sync + 'static,
        Fut: std::future::Future<Output = RpcRet>,
    {
        let mut tasks = JoinSet::new();
        let (packet_sender, mut packet_receiver) = mpsc::unbounded_channel::<Packet>();
        let (client_transport, server_transport) =
            tarpc::transport::channel::unbounded::<CM, Req>();
        let (mut server_s, mut server_r) = server_transport.split();
        // Unique transaction id so concurrent calls to the same service can be
        // demultiplexed.
        let transact_id = self
            .transact_id
            .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
        let tspt = self.tspt.clone();
        // Outbound pump: serialize requests from the local channel and send
        // them to the destination peer.
        tasks.spawn(async move {
            while let Some(a) = server_r.next().await {
                if a.is_err() {
                    tracing::error!(error = ?a.err(), "channel error");
                    continue;
                }
                let a = postcard::to_allocvec(&a.unwrap());
                if a.is_err() {
                    tracing::error!(error = ?a.err(), "bincode serialize failed");
                    continue;
                }
                let a = Packet::new_tarpc_packet(
                    tspt.my_peer_id(),
                    dst_peer_id,
                    service_id,
                    transact_id,
                    true,
                    a.unwrap(),
                );
                if let Err(e) = tspt.send(a.into(), dst_peer_id).await {
                    tracing::error!(error = ?e, dst_peer_id = ?dst_peer_id, "send to peer failed");
                }
            }
            tracing::warn!("[PEER RPC MGR] server trasport read aborted");
        });
        // Inbound pump: decode response packets and feed them into the local
        // channel so the tarpc client sees them.
        tasks.spawn(async move {
            while let Some(packet) = packet_receiver.recv().await {
                tracing::trace!("tunnel recv: {:?}", packet);
                let info = PeerRpcManager::parse_rpc_packet(&packet);
                if let Err(e) = info {
                    tracing::error!(error = ?e, "parse rpc packet failed");
                    continue;
                }
                let decoded = postcard::from_bytes(&info.unwrap().content.as_slice());
                if let Err(e) = decoded {
                    tracing::error!(error = ?e, "decode rpc packet failed");
                    continue;
                }
                if let Err(e) = server_s.send(decoded.unwrap()).await {
                    tracing::error!(error = ?e, "send to rpc server channel failed");
                }
            }
            tracing::warn!("[PEER RPC MGR] server packet read aborted");
        });
        // Register the response route for the scope of `f`, then remove it.
        let key = PeerRpcClientCtxKey(dst_peer_id, service_id, transact_id);
        let _insert_ret = self
            .client_resp_receivers
            .insert(key.clone(), packet_sender);
        let ret = f(client_transport).await;
        self.client_resp_receivers.remove(&key);
        ret
    }
    /// Peer id of the local node, as reported by the transport.
    pub fn my_peer_id(&self) -> PeerId {
        self.tspt.my_peer_id()
    }
}
#[cfg(test)]
mod tests {
    use futures::{SinkExt, StreamExt};
    use tokio_util::bytes::Bytes;
    use crate::{
        common::{error::Error, new_peer_id, PeerId},
        peers::{
            peer_rpc::PeerRpcManager,
            tests::{connect_peer_manager, create_mock_peer_manager, wait_route_appear},
        },
        tunnels::{self, ring_tunnel::create_ring_tunnel_pair},
    };
    use super::PeerRpcManagerTransport;
    /// Minimal tarpc service used by all tests below.
    #[tarpc::service]
    pub trait TestRpcService {
        async fn hello(s: String) -> String;
    }
    #[derive(Clone)]
    struct MockService {
        prefix: String,
    }
    #[tarpc::server]
    impl TestRpcService for MockService {
        async fn hello(self, _: tarpc::context::Context, s: String) -> String {
            format!("{} {}", self.prefix, s)
        }
    }
    /// End-to-end RPC over a bare in-memory ring tunnel (no peer manager).
    #[tokio::test]
    async fn peer_rpc_basic_test() {
        // Transport stub that ignores the destination and pushes everything
        // through one tunnel end.
        struct MockTransport {
            tunnel: Box<dyn tunnels::Tunnel>,
            my_peer_id: PeerId,
        }
        #[async_trait::async_trait]
        impl PeerRpcManagerTransport for MockTransport {
            fn my_peer_id(&self) -> PeerId {
                self.my_peer_id
            }
            async fn send(&self, msg: Bytes, _dst_peer_id: PeerId) -> Result<(), Error> {
                println!("rpc mgr send: {:?}", msg);
                self.tunnel.pin_sink().send(msg).await.unwrap();
                Ok(())
            }
            async fn recv(&self) -> Result<Bytes, Error> {
                let ret = self.tunnel.pin_stream().next().await.unwrap();
                println!("rpc mgr recv: {:?}", ret);
                return ret.map(|v| v.freeze()).map_err(|_| Error::Unknown);
            }
        }
        let (ct, st) = create_ring_tunnel_pair();
        let server_rpc_mgr = PeerRpcManager::new(MockTransport {
            tunnel: st,
            my_peer_id: new_peer_id(),
        });
        server_rpc_mgr.run();
        let s = MockService {
            prefix: "hello".to_owned(),
        };
        server_rpc_mgr.run_service(1, s.serve());
        let client_rpc_mgr = PeerRpcManager::new(MockTransport {
            tunnel: ct,
            my_peer_id: new_peer_id(),
        });
        client_rpc_mgr.run();
        let ret = client_rpc_mgr
            .do_client_rpc_scoped(1, server_rpc_mgr.my_peer_id(), |c| async {
                let c = TestRpcServiceClient::new(tarpc::client::Config::default(), c).spawn();
                let ret = c.hello(tarpc::context::current(), "abc".to_owned()).await;
                ret
            })
            .await;
        println!("ret: {:?}", ret);
        assert_eq!(ret.unwrap(), "hello abc");
    }
    /// RPC routed through real peer managers, including a multi-hop call
    /// (c -> b via the routed mesh) and repeated calls on one endpoint.
    #[tokio::test]
    async fn test_rpc_with_peer_manager() {
        let peer_mgr_a = create_mock_peer_manager().await;
        let peer_mgr_b = create_mock_peer_manager().await;
        let peer_mgr_c = create_mock_peer_manager().await;
        connect_peer_manager(peer_mgr_a.clone(), peer_mgr_b.clone()).await;
        connect_peer_manager(peer_mgr_b.clone(), peer_mgr_c.clone()).await;
        wait_route_appear(peer_mgr_a.clone(), peer_mgr_b.clone())
            .await
            .unwrap();
        wait_route_appear(peer_mgr_a.clone(), peer_mgr_c.clone())
            .await
            .unwrap();
        assert_eq!(peer_mgr_a.get_peer_map().list_peers().await.len(), 1);
        assert_eq!(
            peer_mgr_a.get_peer_map().list_peers().await[0],
            peer_mgr_b.my_peer_id()
        );
        assert_eq!(peer_mgr_c.get_peer_map().list_peers().await.len(), 1);
        assert_eq!(
            peer_mgr_c.get_peer_map().list_peers().await[0],
            peer_mgr_b.my_peer_id()
        );
        let s = MockService {
            prefix: "hello".to_owned(),
        };
        peer_mgr_b.get_peer_rpc_mgr().run_service(1, s.serve());
        let ip_list = peer_mgr_a
            .get_peer_rpc_mgr()
            .do_client_rpc_scoped(1, peer_mgr_b.my_peer_id(), |c| async {
                let c = TestRpcServiceClient::new(tarpc::client::Config::default(), c).spawn();
                let ret = c.hello(tarpc::context::current(), "abc".to_owned()).await;
                ret
            })
            .await;
        println!("ip_list: {:?}", ip_list);
        assert_eq!(ip_list.as_ref().unwrap(), "hello abc");
        // call again
        let ip_list = peer_mgr_a
            .get_peer_rpc_mgr()
            .do_client_rpc_scoped(1, peer_mgr_b.my_peer_id(), |c| async {
                let c = TestRpcServiceClient::new(tarpc::client::Config::default(), c).spawn();
                let ret = c.hello(tarpc::context::current(), "abcd".to_owned()).await;
                ret
            })
            .await;
        println!("ip_list: {:?}", ip_list);
        assert_eq!(ip_list.as_ref().unwrap(), "hello abcd");
        // c reaches b through the routed mesh (not directly connected).
        let ip_list = peer_mgr_c
            .get_peer_rpc_mgr()
            .do_client_rpc_scoped(1, peer_mgr_b.my_peer_id(), |c| async {
                let c = TestRpcServiceClient::new(tarpc::client::Config::default(), c).spawn();
                let ret = c.hello(tarpc::context::current(), "bcd".to_owned()).await;
                ret
            })
            .await;
        println!("ip_list: {:?}", ip_list);
        assert_eq!(ip_list.as_ref().unwrap(), "hello bcd");
    }
    /// Two distinct services on one peer must be dispatched independently.
    #[tokio::test]
    async fn test_multi_service_with_peer_manager() {
        let peer_mgr_a = create_mock_peer_manager().await;
        let peer_mgr_b = create_mock_peer_manager().await;
        connect_peer_manager(peer_mgr_a.clone(), peer_mgr_b.clone()).await;
        wait_route_appear(peer_mgr_a.clone(), peer_mgr_b.clone())
            .await
            .unwrap();
        assert_eq!(peer_mgr_a.get_peer_map().list_peers().await.len(), 1);
        assert_eq!(
            peer_mgr_a.get_peer_map().list_peers().await[0],
            peer_mgr_b.my_peer_id()
        );
        let s = MockService {
            prefix: "hello_a".to_owned(),
        };
        peer_mgr_b.get_peer_rpc_mgr().run_service(1, s.serve());
        let b = MockService {
            prefix: "hello_b".to_owned(),
        };
        peer_mgr_b.get_peer_rpc_mgr().run_service(2, b.serve());
        let ip_list = peer_mgr_a
            .get_peer_rpc_mgr()
            .do_client_rpc_scoped(1, peer_mgr_b.my_peer_id(), |c| async {
                let c = TestRpcServiceClient::new(tarpc::client::Config::default(), c).spawn();
                let ret = c.hello(tarpc::context::current(), "abc".to_owned()).await;
                ret
            })
            .await;
        assert_eq!(ip_list.as_ref().unwrap(), "hello_a abc");
        let ip_list = peer_mgr_a
            .get_peer_rpc_mgr()
            .do_client_rpc_scoped(2, peer_mgr_b.my_peer_id(), |c| async {
                let c = TestRpcServiceClient::new(tarpc::client::Config::default(), c).spawn();
                let ret = c.hello(tarpc::context::current(), "abc".to_owned()).await;
                ret
            })
            .await;
        assert_eq!(ip_list.as_ref().unwrap(), "hello_b abc");
    }
}
+36
View File
@@ -0,0 +1,36 @@
use std::{net::Ipv4Addr, sync::Arc};
use async_trait::async_trait;
use tokio_util::bytes::Bytes;
use crate::common::{error::Error, PeerId};
/// Callbacks a route implementation uses to interact with the peer manager.
#[async_trait]
pub trait RouteInterface {
    /// Ids of all directly connected peers.
    async fn list_peers(&self) -> Vec<PeerId>;
    /// Send a route-protocol packet to `dst_peer_id` on behalf of route
    /// instance `route_id`.
    async fn send_route_packet(
        &self,
        msg: Bytes,
        route_id: u8,
        dst_peer_id: PeerId,
    ) -> Result<(), Error>;
    /// Peer id of the local node.
    fn my_peer_id(&self) -> PeerId;
}
/// Boxed, thread-safe route interface handed to [`Route::open`].
pub type RouteInterfaceBox = Box<dyn RouteInterface + Send + Sync>;
/// A routing algorithm implementation (the project provides RIP- and
/// OSPF-style variants).
#[async_trait]
#[auto_impl::auto_impl(Box, Arc)]
pub trait Route {
    /// Attach to the peer manager via `interface` and start background work.
    async fn open(&self, interface: RouteInterfaceBox) -> Result<u8, ()>;
    /// Stop background work and detach.
    async fn close(&self);
    /// Next hop to reach `peer_id`, if a route is known.
    async fn get_next_hop(&self, peer_id: PeerId) -> Option<PeerId>;
    /// All currently known routes in their RPC representation.
    async fn list_routes(&self) -> Vec<crate::rpc::Route>;
    /// Resolve a virtual IPv4 address to the owning peer; default is unknown.
    async fn get_peer_id_by_ipv4(&self, _ipv4: &Ipv4Addr) -> Option<PeerId> {
        None
    }
}
/// Shared, thread-safe handle to a route implementation.
pub type ArcRoute = Arc<Box<dyn Route + Send + Sync>>;
+63
View File
@@ -0,0 +1,63 @@
use std::sync::Arc;
use crate::rpc::{
cli::PeerInfo,
peer_manage_rpc_server::PeerManageRpc,
{ListPeerRequest, ListPeerResponse, ListRouteRequest, ListRouteResponse},
};
use tonic::{Request, Response, Status};
use super::peer_manager::PeerManager;
/// gRPC facade exposing peer and route information from a [`PeerManager`].
pub struct PeerManagerRpcService {
    peer_manager: Arc<PeerManager>,
}
impl PeerManagerRpcService {
    /// Wrap a peer manager so it can be served over gRPC.
    pub fn new(peer_manager: Arc<PeerManager>) -> Self {
        Self { peer_manager }
    }
    /// Collect a `PeerInfo` (peer id plus its connections, when available)
    /// for every directly connected peer.
    pub async fn list_peers(&self) -> Vec<PeerInfo> {
        let mut infos = Vec::new();
        for peer_id in self.peer_manager.get_peer_map().list_peers().await {
            let mut info = PeerInfo::default();
            info.peer_id = peer_id;
            if let Some(conns) = self
                .peer_manager
                .get_peer_map()
                .list_peer_conns(peer_id)
                .await
            {
                info.conns = conns;
            }
            infos.push(info);
        }
        infos
    }
}
#[tonic::async_trait]
impl PeerManageRpc for PeerManagerRpcService {
    /// gRPC handler: list directly connected peers.
    async fn list_peer(
        &self,
        _request: Request<ListPeerRequest>, // Accept request of type HelloRequest
    ) -> Result<Response<ListPeerResponse>, Status> {
        let mut reply = ListPeerResponse::default();
        reply.peer_infos = self.list_peers().await;
        Ok(Response::new(reply))
    }
    /// gRPC handler: list all known routes.
    async fn list_route(
        &self,
        _request: Request<ListRouteRequest>, // Accept request of type HelloRequest
    ) -> Result<Response<ListRouteResponse>, Status> {
        let mut reply = ListRouteResponse::default();
        reply.routes = self.peer_manager.list_routes().await;
        Ok(Response::new(reply))
    }
}
+75
View File
@@ -0,0 +1,75 @@
use std::sync::Arc;
use futures::Future;
use crate::{
common::{error::Error, global_ctx::tests::get_mock_global_ctx, PeerId},
tunnels::ring_tunnel::create_ring_tunnel_pair,
};
use super::peer_manager::{PeerManager, RouteAlgoType};
/// Build and start an OSPF-routed peer manager with a mocked global context.
pub async fn create_mock_peer_manager() -> Arc<PeerManager> {
    // The receiving half of the packet channel is unused by these tests.
    let (packet_tx, _packet_rx) = tokio::sync::mpsc::channel(1000);
    let mgr = Arc::new(PeerManager::new(
        RouteAlgoType::Ospf,
        get_mock_global_ctx(),
        packet_tx,
    ));
    mgr.run().await.unwrap();
    mgr
}
/// Wire two peer managers together over an in-memory ring tunnel; each side
/// consumes its tunnel end on its own task.
pub async fn connect_peer_manager(client: Arc<PeerManager>, server: Arc<PeerManager>) {
    let (client_ring, server_ring) = create_ring_tunnel_pair();
    tokio::spawn(async move {
        client.add_client_tunnel(client_ring).await.unwrap();
    });
    tokio::spawn(async move {
        server.add_tunnel_as_server(server_ring).await.unwrap();
    });
}
pub async fn wait_route_appear_with_cost(
peer_mgr: Arc<PeerManager>,
node_id: PeerId,
cost: Option<i32>,
) -> Result<(), Error> {
let now = std::time::Instant::now();
while now.elapsed().as_secs() < 5 {
let route = peer_mgr.list_routes().await;
if route
.iter()
.any(|r| r.peer_id == node_id && (cost.is_none() || r.cost == cost.unwrap()))
{
return Ok(());
}
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
}
return Err(Error::NotFound);
}
/// Wait until routes exist in both directions between the two managers.
pub async fn wait_route_appear(
    peer_mgr: Arc<PeerManager>,
    target_peer: Arc<PeerManager>,
) -> Result<(), Error> {
    let target_id = target_peer.my_peer_id();
    wait_route_appear_with_cost(peer_mgr.clone(), target_id, None).await?;
    wait_route_appear_with_cost(target_peer, peer_mgr.my_peer_id(), None).await
}
/// Poll `condition` every 50ms until it returns true or `timeout` elapses;
/// a final failing check panics the test with "Timeout".
pub async fn wait_for_condition<F, FRet>(mut condition: F, timeout: std::time::Duration) -> ()
where
    F: FnMut() -> FRet + Send,
    FRet: Future<Output = bool>,
{
    let started = std::time::Instant::now();
    loop {
        if started.elapsed() >= timeout {
            break;
        }
        if condition().await {
            return;
        }
        tokio::time::sleep(std::time::Duration::from_millis(50)).await;
    }
    assert!(condition().await, "Timeout")
}
+1
View File
@@ -0,0 +1 @@
tonic::include_proto!("cli"); // The string specified here must match the proto package name
+4
View File
@@ -0,0 +1,4 @@
pub mod cli;
pub use cli::*;
pub mod peer;
+22
View File
@@ -0,0 +1,22 @@
use serde::{Deserialize, Serialize};
/// Snapshot of a node's addresses and active listeners, exchanged over RPC.
#[derive(Clone, PartialEq, Debug, Serialize, Deserialize)]
pub struct GetIpListResponse {
    // empty string when no public address is known
    pub public_ipv4: String,
    // addresses assigned to local interfaces
    pub interface_ipv4s: Vec<String>,
    pub public_ipv6: String,
    pub interface_ipv6s: Vec<String>,
    // URLs of the listeners this node accepts connections on
    pub listeners: Vec<url::Url>,
}
impl GetIpListResponse {
pub fn new() -> Self {
GetIpListResponse {
public_ipv4: "".to_string(),
interface_ipv4s: vec![],
public_ipv6: "".to_string(),
interface_ipv6s: vec![],
listeners: vec![],
}
}
}
+181
View File
@@ -0,0 +1,181 @@
use crate::common::PeerId;
mod three_node;
/// Name of the in-namespace (guest) end of the veth pair for `net_ns`.
/// NOTE: leaks the formatted name to hand out a `&'static str`; acceptable
/// for short-lived test processes.
pub fn get_guest_veth_name(net_ns: &str) -> &str {
    let name = format!("veth_{}_g", net_ns);
    Box::leak(name.into_boxed_str())
}
/// Name of the host-side end of the veth pair for `net_ns`.
/// NOTE: leaks the formatted name to hand out a `&'static str`; acceptable
/// for short-lived test processes.
pub fn get_host_veth_name(net_ns: &str) -> &str {
    let name = format!("veth_{}_h", net_ns);
    Box::leak(name.into_boxed_str())
}
/// Best-effort teardown of a test network namespace and its host-side veth.
/// Errors are deliberately ignored: the namespace may not exist yet.
pub fn del_netns(name: &str) {
    // host-side veth first, then the namespace itself
    for args in [
        vec!["link", "del", get_host_veth_name(name)],
        vec!["netns", "del", name],
    ] {
        let _ = std::process::Command::new("ip").args(&args).output();
    }
}
/// Create network namespace `name`, wire it to the host with a veth pair,
/// bring both ends (plus the namespace's loopback) up and assign `ipv4`
/// (CIDR form) to the guest end.
///
/// Panics only if an `ip` invocation cannot be spawned at all; non-zero exit
/// statuses are ignored, matching the previous behavior.
pub fn create_netns(name: &str, ipv4: &str) {
    // Run one `ip` command; exit status intentionally not checked.
    fn run_ip(args: &[&str]) {
        let _ = std::process::Command::new("ip").args(args).output().unwrap();
    }
    // compute the (leaked) veth names once instead of once per command
    let guest = get_guest_veth_name(name);
    let host = get_host_veth_name(name);
    // create netns
    run_ip(&["netns", "add", name]);
    // set lo up inside the namespace
    run_ip(&["netns", "exec", name, "ip", "link", "set", "lo", "up"]);
    // create the veth pair on the host side
    run_ip(&["link", "add", host, "type", "veth", "peer", "name", guest]);
    // move the guest end into the namespace
    run_ip(&["link", "set", guest, "netns", name]);
    // bring the guest end up
    run_ip(&["netns", "exec", name, "ip", "link", "set", guest, "up"]);
    // bring the host end up
    run_ip(&["link", "set", host, "up"]);
    // assign the address to the guest end
    run_ip(&["netns", "exec", name, "ip", "addr", "add", ipv4, "dev", guest]);
}
/// (Re)create bridge `name`: delete any stale bridge of that name, then add
/// it fresh. Failures are ignored (it may not exist / may already exist).
pub fn prepare_bridge(name: &str) {
    for sub_cmd in ["delbr", "addbr"] {
        let _ = std::process::Command::new("brctl")
            .args(&[sub_cmd, name])
            .output();
    }
}
/// Attach the host-side veth of namespace `ns_name` to bridge `br_name`,
/// then bring the bridge up. Panics if a command cannot be spawned.
pub fn add_ns_to_bridge(br_name: &str, ns_name: &str) {
    // use brctl to enslave the host veth to the bridge
    std::process::Command::new("brctl")
        .args(&["addif", br_name, get_host_veth_name(ns_name)])
        .output()
        .unwrap();
    // set bridge up
    std::process::Command::new("ip")
        .args(&["link", "set", br_name, "up"])
        .output()
        .unwrap();
}
/// Install a pretty-printing tracing subscriber for tests: defaults to INFO,
/// honors the `RUST_LOG` environment variable, and silences noisy `tarpc`
/// logs below the error level.
pub fn enable_log() {
    let env_filter = tracing_subscriber::EnvFilter::builder()
        .with_default_directive(tracing::level_filters::LevelFilter::INFO.into())
        .from_env()
        .unwrap()
        .add_directive("tarpc=error".parse().unwrap());
    tracing_subscriber::fmt::fmt()
        .pretty()
        .with_env_filter(env_filter)
        .init();
}
/// Assert that `routes` contains an entry for virtual address `ipv4` and that
/// every such entry points at `dst_peer_id`. Panics with the full route dump
/// otherwise.
fn check_route(ipv4: &str, dst_peer_id: PeerId, routes: Vec<crate::rpc::Route>) {
    let mut found = false;
    for r in routes.iter() {
        // String == &str compares directly; the previous `ipv4.to_string()`
        // allocated on every iteration
        if r.ipv4_addr == ipv4 {
            found = true;
            assert_eq!(r.peer_id, dst_peer_id, "{:?}", routes);
        }
    }
    assert!(
        found,
        "routes: {:?}, dst_peer_id: {}, ipv4: {}",
        routes, dst_peer_id, ipv4
    );
}
/// Wait (up to 5s, polling every 100ms) until `mgr` learns a route whose
/// proxy CIDR list contains `proxy_cidr`, then sanity-check that this route
/// points at `dst_peer_id` / `ipv4`. Panics on timeout.
async fn wait_proxy_route_appear(
    mgr: &std::sync::Arc<crate::peers::peer_manager::PeerManager>,
    ipv4: &str,
    dst_peer_id: PeerId,
    proxy_cidr: &str,
) {
    let started = std::time::Instant::now();
    // hoisted: previously re-allocated on every route on every poll
    let wanted_cidr = proxy_cidr.to_owned();
    loop {
        // (also dropped a dead `let r = r;` rebinding)
        for r in mgr.list_routes().await.iter() {
            if r.proxy_cidrs.contains(&wanted_cidr) {
                assert_eq!(r.peer_id, dst_peer_id);
                assert_eq!(r.ipv4_addr, ipv4);
                return;
            }
        }
        if started.elapsed().as_secs() > 5 {
            panic!("wait proxy route appear timeout");
        }
        tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
    }
}
/// Flip the guest-side veth of namespace `net_ns` up or down, simulating
/// link loss/recovery inside tests. Panics if the command cannot be spawned.
fn set_link_status(net_ns: &str, up: bool) {
    let state = if up { "up" } else { "down" };
    std::process::Command::new("ip")
        .args(&[
            "netns",
            "exec",
            net_ns,
            "ip",
            "link",
            "set",
            get_guest_veth_name(net_ns),
            state,
        ])
        .output()
        .unwrap();
}
+409
View File
@@ -0,0 +1,409 @@
use std::{
sync::{atomic::AtomicU32, Arc},
time::Duration,
};
use tokio::{net::UdpSocket, task::JoinSet};
use super::*;
use crate::{
common::{
config::{ConfigLoader, NetworkIdentity, TomlConfigLoader},
netns::{NetNS, ROOT_NETNS_NAME},
},
instance::instance::Instance,
peers::tests::wait_for_condition,
tunnels::{
common::tests::_tunnel_pingpong_netns,
ring_tunnel::RingTunnelConnector,
tcp_tunnel::{TcpTunnelConnector, TcpTunnelListener},
udp_tunnel::{UdpTunnelConnector, UdpTunnelListener},
wireguard::{WgConfig, WgTunnelConnector},
},
};
/// Reset and rebuild the four test namespaces and two bridges:
/// br_a bridges net_a/net_b (10.1.1.0/24) and br_b bridges net_c/net_d
/// (10.1.2.0/24).
pub fn prepare_linux_namespaces() {
    let namespaces = [
        ("net_a", "10.1.1.1/24", "br_a"),
        ("net_b", "10.1.1.2/24", "br_a"),
        ("net_c", "10.1.2.3/24", "br_b"),
        ("net_d", "10.1.2.4/24", "br_b"),
    ];
    for (ns, _, _) in namespaces.iter() {
        del_netns(ns);
    }
    for (ns, ip, _) in namespaces.iter() {
        create_netns(ns, ip);
    }
    prepare_bridge("br_a");
    prepare_bridge("br_b");
    for (ns, _, br) in namespaces.iter() {
        add_ns_to_bridge(br, ns);
    }
}
/// Build a `TomlConfigLoader` for test instance `inst_name`, optionally bound
/// to network namespace `ns`, with virtual address `ipv4` and the standard
/// tcp/udp (port 11010) and wireguard (port 11011) listeners.
pub fn get_inst_config(inst_name: &str, ns: Option<&str>, ipv4: &str) -> TomlConfigLoader {
    let cfg = TomlConfigLoader::default();
    cfg.set_inst_name(inst_name.to_owned());
    cfg.set_netns(ns.map(str::to_owned));
    cfg.set_ipv4(ipv4.parse().unwrap());
    cfg.set_listeners(vec![
        "tcp://0.0.0.0:11010".parse().unwrap(),
        "udp://0.0.0.0:11010".parse().unwrap(),
        "wg://0.0.0.0:11011".parse().unwrap(),
    ]);
    cfg
}
/// Spin up three connected instances (inst1@net_a, inst2@net_b, inst3@net_c).
/// inst2 dials inst1 over `proto` ("tcp" | "udp" | "wg") and inst3 over an
/// in-memory ring tunnel, then we block until inst2 sees both routes
/// (panicking after 5s).
pub async fn init_three_node(proto: &str) -> Vec<Instance> {
    log::set_max_level(log::LevelFilter::Info);
    prepare_linux_namespaces();
    let mut inst1 = Instance::new(get_inst_config("inst1", Some("net_a"), "10.144.144.1"));
    let mut inst2 = Instance::new(get_inst_config("inst2", Some("net_b"), "10.144.144.2"));
    let mut inst3 = Instance::new(get_inst_config("inst3", Some("net_c"), "10.144.144.3"));
    inst1.run().await.unwrap();
    inst2.run().await.unwrap();
    inst3.run().await.unwrap();
    // inst2 -> inst1 over the requested protocol
    match proto {
        "tcp" => {
            inst2
                .get_conn_manager()
                .add_connector(TcpTunnelConnector::new(
                    "tcp://10.1.1.1:11010".parse().unwrap(),
                ));
        }
        "udp" => {
            inst2
                .get_conn_manager()
                .add_connector(UdpTunnelConnector::new(
                    "udp://10.1.1.1:11010".parse().unwrap(),
                ));
        }
        "wg" => {
            inst2
                .get_conn_manager()
                .add_connector(WgTunnelConnector::new(
                    "wg://10.1.1.1:11011".parse().unwrap(),
                    WgConfig::new_from_network_identity(
                        &inst1.get_global_ctx().get_network_identity().network_name,
                        &inst1.get_global_ctx().get_network_identity().network_secret,
                    ),
                ));
        }
        _ => {}
    }
    // inst2 -> inst3, always over an in-memory ring tunnel
    inst2
        .get_conn_manager()
        .add_connector(RingTunnelConnector::new(
            format!("ring://{}", inst3.id()).parse().unwrap(),
        ));
    // wait until inst2 has learned routes to both other peers
    let started = std::time::Instant::now();
    while inst2.get_peer_manager().list_routes().await.len() != 2 {
        if started.elapsed().as_secs() > 5 {
            panic!("wait inst2 have two route timeout");
        }
        tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
    }
    vec![inst1, inst2, inst3]
}
#[rstest::rstest]
#[tokio::test]
#[serial_test::serial]
// After init, inst1 must have learned routes to both inst2 and inst3.
pub async fn basic_three_node_test(#[values("tcp", "udp", "wg")] proto: &str) {
    let insts = init_three_node(proto).await;
    for (ip, idx) in [("10.144.144.2", 1), ("10.144.144.3", 2)] {
        check_route(
            ip,
            insts[idx].peer_id(),
            insts[0].get_peer_manager().list_routes().await,
        );
    }
}
#[rstest::rstest]
#[tokio::test]
#[serial_test::serial]
// inst3 proxies net_d's subnet; a TCP ping-pong from net_a to a net_d
// address must succeed through the proxy.
pub async fn tcp_proxy_three_node_test(#[values("tcp", "udp", "wg")] proto: &str) {
    let insts = init_three_node(proto).await;
    // inst3 advertises net_d's subnet
    insts[2]
        .get_global_ctx()
        .add_proxy_cidr("10.1.2.0/24".parse().unwrap())
        .unwrap();
    assert_eq!(insts[2].get_global_ctx().get_proxy_cidrs().len(), 1);
    wait_proxy_route_appear(
        &insts[0].get_peer_manager(),
        "10.144.144.3",
        insts[2].peer_id(),
        "10.1.2.0/24",
    )
    .await;
    // give the route updater time to install the proxy entries
    tokio::time::sleep(tokio::time::Duration::from_secs(6)).await;
    let listener = TcpTunnelListener::new("tcp://10.1.2.4:22223".parse().unwrap());
    let connector = TcpTunnelConnector::new("tcp://10.1.2.4:22223".parse().unwrap());
    _tunnel_pingpong_netns(
        listener,
        connector,
        NetNS::new(Some("net_d".into())),
        NetNS::new(Some("net_a".into())),
    )
    .await;
}
#[rstest::rstest]
#[tokio::test]
#[serial_test::serial]
// inst3 proxies net_d's subnet; a shell ping from inside net_a to a net_d
// address must succeed through the ICMP proxy.
pub async fn icmp_proxy_three_node_test(#[values("tcp", "udp", "wg")] proto: &str) {
    let insts = init_three_node(proto).await;
    // inst3 advertises net_d's subnet
    insts[2]
        .get_global_ctx()
        .add_proxy_cidr("10.1.2.0/24".parse().unwrap())
        .unwrap();
    assert_eq!(insts[2].get_global_ctx().get_proxy_cidrs().len(), 1);
    wait_proxy_route_appear(
        &insts[0].get_peer_manager(),
        "10.144.144.3",
        insts[2].peer_id(),
        "10.1.2.0/24",
    )
    .await;
    // give the route updater time to install the proxy entries
    tokio::time::sleep(tokio::time::Duration::from_secs(6)).await;
    // send ping with shell in net_a to net_d
    let _g = NetNS::new(Some(ROOT_NETNS_NAME.to_owned())).guard();
    let status = tokio::process::Command::new("ip")
        .args(&[
            "netns", "exec", "net_a", "ping", "-c", "1", "-W", "1", "10.1.2.4",
        ])
        .status()
        .await
        .unwrap();
    assert_eq!(status.code().unwrap(), 0);
}
#[rstest::rstest]
#[tokio::test]
#[serial_test::serial]
// Verify that a fourth peer (inst4, joining through inst3) appears in inst1's
// route table, disappears ~8s after its link goes down, and reappears once
// the link is restored. The down/up cycle is exercised twice.
pub async fn proxy_three_node_disconnect_test(#[values("tcp", "wg")] proto: &str) {
    let insts = init_three_node(proto).await;
    // inst4 lives in net_d and dials inst3 (10.1.2.3) over `proto`
    let mut inst4 = Instance::new(get_inst_config("inst4", Some("net_d"), "10.144.144.4"));
    if proto == "tcp" {
        inst4
            .get_conn_manager()
            .add_connector(TcpTunnelConnector::new(
                "tcp://10.1.2.3:11010".parse().unwrap(),
            ));
    } else if proto == "wg" {
        inst4
            .get_conn_manager()
            .add_connector(WgTunnelConnector::new(
                "wg://10.1.2.3:11011".parse().unwrap(),
                WgConfig::new_from_network_identity(
                    &inst4.get_global_ctx().get_network_identity().network_name,
                    &inst4.get_global_ctx().get_network_identity().network_secret,
                ),
            ));
    } else {
        unreachable!("not support");
    }
    inst4.run().await.unwrap();
    // NOTE(review): the 8s sleeps presumably cover route propagation plus
    // peer-timeout detection; confirm against the peer manager's timers
    // before tightening them.
    let task = tokio::spawn(async move {
        for _ in 1..=2 {
            tokio::time::sleep(tokio::time::Duration::from_secs(8)).await;
            // inst4 should be in inst1's route list
            let routes = insts[0].get_peer_manager().list_routes().await;
            assert!(
                routes
                    .iter()
                    .find(|r| r.peer_id == inst4.peer_id())
                    .is_some(),
                "inst4 should be in inst1's route list, {:?}",
                routes
            );
            // cut inst4's link and wait for the route to be withdrawn
            set_link_status("net_d", false);
            tokio::time::sleep(tokio::time::Duration::from_secs(8)).await;
            let routes = insts[0].get_peer_manager().list_routes().await;
            assert!(
                routes
                    .iter()
                    .find(|r| r.peer_id == inst4.peer_id())
                    .is_none(),
                "inst4 should not be in inst1's route list, {:?}",
                routes
            );
            // restore the link for the next cycle
            set_link_status("net_d", true);
        }
    });
    // propagate any assertion failure from the spawned task
    let (ret,) = tokio::join!(task);
    assert!(ret.is_ok());
}
#[rstest::rstest]
#[tokio::test]
#[serial_test::serial]
// inst3 proxies net_d's subnet; a UDP ping-pong from net_a to a net_d
// address must succeed through the proxy.
pub async fn udp_proxy_three_node_test(#[values("tcp", "udp", "wg")] proto: &str) {
    let insts = init_three_node(proto).await;
    // inst3 advertises net_d's subnet
    insts[2]
        .get_global_ctx()
        .add_proxy_cidr("10.1.2.0/24".parse().unwrap())
        .unwrap();
    assert_eq!(insts[2].get_global_ctx().get_proxy_cidrs().len(), 1);
    wait_proxy_route_appear(
        &insts[0].get_peer_manager(),
        "10.144.144.3",
        insts[2].peer_id(),
        "10.1.2.0/24",
    )
    .await;
    // give the route updater time to install the proxy entries
    tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
    let udp_listener = UdpTunnelListener::new("udp://10.1.2.4:22233".parse().unwrap());
    let udp_connector = UdpTunnelConnector::new("udp://10.1.2.4:22233".parse().unwrap());
    _tunnel_pingpong_netns(
        udp_listener,
        udp_connector,
        NetNS::new(Some("net_d".into())),
        NetNS::new(Some("net_a".into())),
    )
    .await;
}
#[tokio::test]
#[serial_test::serial]
// One broadcast datagram sent from net_a must be observed by responders in
// both net_b and net_c over the virtual network.
pub async fn udp_broadcast_test() {
    let _insts = init_three_node("tcp").await;
    // responder: count every datagram received on port 22111 in its namespace
    let udp_broadcast_responder = |net_ns: NetNS, counter: Arc<AtomicU32>| async move {
        let _g = net_ns.guard();
        let socket: UdpSocket = UdpSocket::bind("0.0.0.0:22111").await.unwrap();
        socket.set_broadcast(true).unwrap();
        println!("Awaiting responses...");
        let mut recv_buff = [0; 8092];
        while let Ok((n, addr)) = socket.recv_from(&mut recv_buff).await {
            println!("{} bytes response from {:?}", n, addr);
            counter.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
        }
    };
    let mut tasks = JoinSet::new();
    let counter = Arc::new(AtomicU32::new(0));
    for ns_name in ["net_b", "net_c"] {
        tasks.spawn(udp_broadcast_responder(
            NetNS::new(Some(ns_name.into())),
            counter.clone(),
        ));
    }
    tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
    // send one broadcast datagram from net_a
    let net_ns = NetNS::new(Some("net_a".into()));
    let _g = net_ns.guard();
    let sender: UdpSocket = UdpSocket::bind("0.0.0.0:0").await.unwrap();
    sender.set_broadcast(true).unwrap();
    let call: Vec<u8> = vec![1; 1024];
    println!("Sending call, {} bytes", call.len());
    if let Err(e) = sender.send_to(&call, "10.144.144.255:22111").await {
        panic!("Error sending call: {:?}", e);
    }
    tokio::time::sleep(tokio::time::Duration::from_secs(2)).await;
    // both responders must have seen the broadcast
    assert_eq!(counter.load(std::sync::atomic::Ordering::Relaxed), 2);
}
#[tokio::test]
#[serial_test::serial]
// Two nodes of the default network connect only through a center node that
// belongs to a different network ("center"); traffic between them must still
// be forwarded via the foreign-network path (verified with a shell ping).
pub async fn foreign_network_forward_nic_data() {
    prepare_linux_namespaces();
    let center_node_config = get_inst_config("inst1", Some("net_a"), "10.144.144.1");
    // NOTE(review): the center node reuses the instance name "inst1" that the
    // first leaf node also uses — looks unintentional; confirm whether
    // instance names need to be unique.
    center_node_config.set_network_identity(NetworkIdentity {
        network_name: "center".to_string(),
        network_secret: "".to_string(),
    });
    let mut center_inst = Instance::new(center_node_config);
    let mut inst1 = Instance::new(get_inst_config("inst1", Some("net_b"), "10.144.145.1"));
    let mut inst2 = Instance::new(get_inst_config("inst2", Some("net_c"), "10.144.145.2"));
    center_inst.run().await.unwrap();
    inst1.run().await.unwrap();
    inst2.run().await.unwrap();
    assert_ne!(inst1.id(), center_inst.id());
    assert_ne!(inst2.id(), center_inst.id());
    // both leaf nodes dial the center over in-memory ring tunnels
    inst1
        .get_conn_manager()
        .add_connector(RingTunnelConnector::new(
            format!("ring://{}", center_inst.id()).parse().unwrap(),
        ));
    inst2
        .get_conn_manager()
        .add_connector(RingTunnelConnector::new(
            format!("ring://{}", center_inst.id()).parse().unwrap(),
        ));
    // each leaf should see exactly one route (to the other leaf, learned via
    // the foreign center node)
    wait_for_condition(
        || async {
            inst1.get_peer_manager().list_routes().await.len() == 1
                && inst2.get_peer_manager().list_routes().await.len() == 1
        },
        Duration::from_secs(5),
    )
    .await;
    // ping inst2's virtual address from inside inst1's namespace
    let _g = NetNS::new(Some(ROOT_NETNS_NAME.to_owned())).guard();
    let code = tokio::process::Command::new("ip")
        .args(&[
            "netns",
            "exec",
            "net_b",
            "ping",
            "-c",
            "1",
            "-W",
            "1",
            "10.144.145.2",
        ])
        .status()
        .await
        .unwrap();
    assert_eq!(code.code().unwrap(), 0);
}
+54
View File
@@ -0,0 +1,54 @@
use std::result::Result;
use tokio::io;
use tokio_util::{
bytes::{BufMut, Bytes, BytesMut},
codec::{Decoder, Encoder},
};
/// A pass-through codec: decode hands over whatever bytes are buffered,
/// encode appends bytes verbatim.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash, Default)]
pub struct BytesCodec {
    // number of bytes to re-reserve in the read buffer after each decode
    capacity: usize,
}
impl BytesCodec {
/// Creates a new `BytesCodec` for shipping around raw bytes.
pub fn new(capacity: usize) -> BytesCodec {
BytesCodec { capacity }
}
}
impl Decoder for BytesCodec {
    type Item = BytesMut;
    type Error = io::Error;
    /// Hand the caller everything currently buffered, or `None` when empty.
    fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<BytesMut>, io::Error> {
        if buf.is_empty() {
            return Ok(None);
        }
        let whole = buf.split_to(buf.len());
        // keep the read buffer primed so the next read has room
        buf.reserve(self.capacity);
        Ok(Some(whole))
    }
}
impl Encoder<Bytes> for BytesCodec {
    type Error = io::Error;
    /// Append `data` verbatim to the output buffer; never fails.
    fn encode(&mut self, data: Bytes, buf: &mut BytesMut) -> Result<(), io::Error> {
        buf.extend_from_slice(&data);
        Ok(())
    }
}
impl Encoder<BytesMut> for BytesCodec {
    type Error = io::Error;
    /// Append `data` verbatim to the output buffer; never fails.
    fn encode(&mut self, data: BytesMut, buf: &mut BytesMut) -> Result<(), io::Error> {
        buf.extend_from_slice(&data);
        Ok(())
    }
}
+460
View File
@@ -0,0 +1,460 @@
use std::{
collections::VecDeque,
net::{IpAddr, SocketAddr},
sync::Arc,
task::{ready, Context, Poll},
};
use async_stream::stream;
use futures::{Future, FutureExt, Sink, SinkExt, Stream, StreamExt};
use network_interface::NetworkInterfaceConfig;
use tokio::{sync::Mutex, time::error::Elapsed};
use std::pin::Pin;
use crate::tunnels::{SinkError, TunnelError};
use super::{DatagramSink, DatagramStream, SinkItem, StreamT, Tunnel, TunnelInfo};
/// A tunnel built from a framed read half `R` and write half `W`.
/// Both halves are wrapped in `Arc<Mutex<..>>` so the tunnel can hand out
/// multiple stream/sink objects sharing the same underlying transport.
pub struct FramedTunnel<R, W> {
    read: Arc<Mutex<R>>,
    write: Arc<Mutex<W>>,
    // optional metadata reported through `Tunnel::info`
    info: Option<TunnelInfo>,
}
impl<R, RE, W, WE> FramedTunnel<R, W>
where
    R: Stream<Item = Result<StreamT, RE>> + Send + Sync + Unpin + 'static,
    W: Sink<SinkItem, Error = WE> + Send + Sync + Unpin + 'static,
    RE: std::error::Error + std::fmt::Debug + Send + Sync + 'static,
    WE: std::error::Error + std::fmt::Debug + Send + Sync + 'static + From<Elapsed>,
{
    /// Wrap an already-framed read half and write half into a tunnel.
    pub fn new(read: R, write: W, info: Option<TunnelInfo>) -> Self {
        FramedTunnel {
            read: Arc::new(Mutex::new(read)),
            write: Arc::new(Mutex::new(write)),
            info,
        }
    }
    /// Convenience constructor: build the tunnel, attach `info`, box it.
    pub fn new_tunnel_with_info(read: R, write: W, info: TunnelInfo) -> Box<dyn Tunnel> {
        Box::new(FramedTunnel::new(read, write, Some(info)))
    }
    /// Produce a datagram stream backed by the shared read half.
    /// NOTE(review): on EOF or read error this yields `Err` but keeps
    /// looping, so a consumer that polls again will keep receiving errors —
    /// confirm callers stop on the first `Err`.
    pub fn recv_stream(&self) -> impl DatagramStream {
        let read = self.read.clone();
        let info = self.info.clone();
        stream! {
            loop {
                let read_ret = read.lock().await.next().await;
                if read_ret.is_none() {
                    // underlying stream finished (EOF)
                    tracing::info!(?info, "read_ret is none");
                    yield Err(TunnelError::CommonError("recv stream closed".to_string()));
                } else {
                    let read_ret = read_ret.unwrap();
                    if read_ret.is_err() {
                        let err = read_ret.err().unwrap();
                        tracing::info!(?info, "recv stream read error");
                        yield Err(TunnelError::CommonError(err.to_string()));
                    } else {
                        yield Ok(read_ret.unwrap());
                    }
                }
            }
        }
    }
    /// Produce a datagram sink backed by the shared write half.
    ///
    /// Items are queued in `sending_buffers`; `poll_flush` moves the whole
    /// queue into a boxed future that drains it into the underlying sink
    /// with a 1s per-item send timeout.
    pub fn send_sink(&self) -> impl DatagramSink {
        struct SendSink<W, WE> {
            write: Arc<Mutex<W>>,
            // a flush is forced once the queue grows past this many items
            max_buffer_size: usize,
            // `None` while a flush future owns the queued items
            sending_buffers: Option<VecDeque<SinkItem>>,
            send_task:
                Option<Pin<Box<dyn Future<Output = Result<(), WE>> + Send + Sync + 'static>>>,
            close_task:
                Option<Pin<Box<dyn Future<Output = Result<(), WE>> + Send + Sync + 'static>>>,
        }
        impl<W, WE> SendSink<W, WE>
        where
            W: Sink<SinkItem, Error = WE> + Send + Sync + Unpin + 'static,
            WE: std::error::Error + std::fmt::Debug + Send + Sync + From<Elapsed>,
        {
            // Drive (and lazily create) the flush future that sends every
            // queued buffer. Resets the queue to empty once the future
            // completes. NOTE: "buffser" is a historical typo, kept as-is.
            fn try_send_buffser(
                &mut self,
                cx: &mut Context<'_>,
            ) -> Poll<std::result::Result<(), WE>> {
                if self.send_task.is_none() {
                    // take ownership of the queued items for the new future
                    let mut buffers = self.sending_buffers.take().unwrap();
                    let tun = self.write.clone();
                    let send_task = async move {
                        if buffers.is_empty() {
                            return Ok(());
                        }
                        let mut locked_tun = tun.lock_owned().await;
                        while let Some(buf) = buffers.front() {
                            log::trace!(
                                "try_send buffer, len: {:?}, buf: {:?}",
                                buffers.len(),
                                &buf
                            );
                            // bound each send with a 1s timeout; a timeout is
                            // surfaced as WE via From<Elapsed>
                            let timeout_task = tokio::time::timeout(
                                std::time::Duration::from_secs(1),
                                locked_tun.send(buf.clone()),
                            );
                            let send_res = timeout_task.await;
                            let Ok(send_res) = send_res else {
                                // panic!("send timeout");
                                let err = send_res.err().unwrap();
                                return Err(err.into());
                            };
                            let Ok(_) = send_res else {
                                let err = send_res.err().unwrap();
                                println!("send error: {:?}", err);
                                return Err(err);
                            };
                            buffers.pop_front();
                        }
                        return Ok(());
                    };
                    self.send_task = Some(Box::pin(send_task));
                }
                let ret = ready!(self.send_task.as_mut().unwrap().poll_unpin(cx));
                // flush finished: drop the future and start a fresh queue
                self.send_task = None;
                self.sending_buffers = Some(VecDeque::new());
                return Poll::Ready(ret);
            }
        }
        impl<W, WE> Sink<SinkItem> for SendSink<W, WE>
        where
            W: Sink<SinkItem, Error = WE> + Send + Sync + Unpin + 'static,
            WE: std::error::Error + std::fmt::Debug + Send + Sync + From<Elapsed>,
        {
            type Error = SinkError;
            // Ready unless a flush is in progress or the queue is over the
            // size limit; in those cases, drive the flush instead.
            fn poll_ready(
                self: Pin<&mut Self>,
                cx: &mut std::task::Context<'_>,
            ) -> Poll<Result<(), Self::Error>> {
                let self_mut = self.get_mut();
                let sending_buf = self_mut.sending_buffers.as_ref();
                // if sending_buffers is None, must already be doing flush
                if sending_buf.is_none() || sending_buf.unwrap().len() > self_mut.max_buffer_size {
                    return self_mut.poll_flush_unpin(cx);
                } else {
                    return Poll::Ready(Ok(()));
                }
            }
            // Queue the item; poll_ready guarantees the queue is available.
            fn start_send(self: Pin<&mut Self>, item: SinkItem) -> Result<(), Self::Error> {
                assert!(self.send_task.is_none());
                let self_mut = self.get_mut();
                self_mut.sending_buffers.as_mut().unwrap().push_back(item);
                Ok(())
            }
            fn poll_flush(
                self: Pin<&mut Self>,
                cx: &mut std::task::Context<'_>,
            ) -> Poll<Result<(), Self::Error>> {
                let self_mut = self.get_mut();
                let ret = self_mut.try_send_buffser(cx);
                match ret {
                    Poll::Ready(Ok(())) => Poll::Ready(Ok(())),
                    Poll::Ready(Err(e)) => Poll::Ready(Err(SinkError::CommonError(e.to_string()))),
                    Poll::Pending => {
                        return Poll::Pending;
                    }
                }
            }
            // Close the underlying sink via a lazily-created boxed future.
            // NOTE(review): does not flush queued items first — confirm
            // callers flush before closing.
            fn poll_close(
                self: Pin<&mut Self>,
                cx: &mut std::task::Context<'_>,
            ) -> Poll<Result<(), Self::Error>> {
                let self_mut = self.get_mut();
                if self_mut.close_task.is_none() {
                    let tun = self_mut.write.clone();
                    let close_task = async move {
                        let mut locked_tun = tun.lock_owned().await;
                        return locked_tun.close().await;
                    };
                    self_mut.close_task = Some(Box::pin(close_task));
                }
                let ret = ready!(self_mut.close_task.as_mut().unwrap().poll_unpin(cx));
                self_mut.close_task = None;
                if ret.is_err() {
                    return Poll::Ready(Err(SinkError::CommonError(
                        ret.err().unwrap().to_string(),
                    )));
                } else {
                    return Poll::Ready(Ok(()));
                }
            }
        }
        SendSink {
            write: self.write.clone(),
            max_buffer_size: 1000,
            sending_buffers: Some(VecDeque::new()),
            send_task: None,
            close_task: None,
        }
    }
}
impl<R, RE, W, WE> Tunnel for FramedTunnel<R, W>
where
    R: Stream<Item = Result<StreamT, RE>> + Send + Sync + Unpin + 'static,
    W: Sink<SinkItem, Error = WE> + Send + Sync + Unpin + 'static,
    RE: std::error::Error + std::fmt::Debug + Send + Sync + 'static,
    WE: std::error::Error + std::fmt::Debug + Send + Sync + 'static + From<Elapsed>,
{
    fn stream(&self) -> Box<dyn DatagramStream> {
        Box::new(self.recv_stream())
    }
    fn sink(&self) -> Box<dyn DatagramSink> {
        Box::new(self.send_sink())
    }
    fn info(&self) -> Option<TunnelInfo> {
        // equivalent to the explicit is_none/unwrap dance
        self.info.clone()
    }
}
/// Wraps an existing tunnel but overrides the `TunnelInfo` it reports.
pub struct TunnelWithCustomInfo {
    tunnel: Box<dyn Tunnel>,
    // info returned instead of the inner tunnel's own info
    info: TunnelInfo,
}
impl TunnelWithCustomInfo {
    /// Wrap `tunnel`, reporting `info` from `Tunnel::info` instead of the
    /// inner tunnel's own metadata.
    pub fn new(tunnel: Box<dyn Tunnel>, info: TunnelInfo) -> Self {
        Self { tunnel, info }
    }
}
impl Tunnel for TunnelWithCustomInfo {
    // data path delegates to the wrapped tunnel...
    fn stream(&self) -> Box<dyn DatagramStream> {
        self.tunnel.stream()
    }
    fn sink(&self) -> Box<dyn DatagramSink> {
        self.tunnel.sink()
    }
    // ...but the reported info is the override
    fn info(&self) -> Option<TunnelInfo> {
        Some(self.info.clone())
    }
}
/// Find the name of the local network interface that carries `local_ip`.
/// Returns `None` (and logs an error) when no interface matches; unspecified
/// and multicast addresses are rejected up front.
pub(crate) fn get_interface_name_by_ip(local_ip: &IpAddr) -> Option<String> {
    if local_ip.is_unspecified() || local_ip.is_multicast() {
        return None;
    }
    let ifaces = network_interface::NetworkInterface::show().ok()?;
    let found = ifaces
        .into_iter()
        .find(|iface| iface.addr.iter().any(|addr| addr.ip() == *local_ip))
        .map(|iface| iface.name);
    if found.is_none() {
        tracing::error!(?local_ip, "can not find interface name by ip");
    }
    found
}
/// Configure a freshly created socket2 socket: platform-specific device
/// binding, non-blocking mode, address reuse, and the bind itself.
///
/// `bind_dev` names the interface to pin the socket to; on macOS/iOS and
/// Linux-family targets this is applied explicitly because those systems do
/// not route by the bound address's interface. Windows handles it in
/// `setup_socket_for_win`. (Function name typo "sokcet" kept: callers use it.)
pub(crate) fn setup_sokcet2_ext(
    socket2_socket: &socket2::Socket,
    bind_addr: &SocketAddr,
    bind_dev: Option<String>,
) -> Result<(), TunnelError> {
    #[cfg(target_os = "windows")]
    {
        let is_udp = matches!(socket2_socket.r#type()?, socket2::Type::DGRAM);
        crate::arch::windows::setup_socket_for_win(socket2_socket, bind_addr, bind_dev, is_udp)?;
    }
    socket2_socket.set_nonblocking(true)?;
    socket2_socket.set_reuse_address(true)?;
    socket2_socket.bind(&socket2::SockAddr::from(*bind_addr))?;
    // #[cfg(all(unix, not(target_os = "solaris"), not(target_os = "illumos")))]
    // socket2_socket.set_reuse_port(true)?;
    if bind_addr.ip().is_unspecified() {
        return Ok(());
    }
    // linux/mac does not use interface of bind_addr to send packet, so we need to bind device
    // win can handle this with bind correctly
    #[cfg(any(target_os = "ios", target_os = "macos"))]
    if let Some(dev_name) = bind_dev {
        // BUGFIX: if_nametoindex requires a NUL-terminated C string; the old
        // code passed the raw `&str` pointer, which is NOT NUL-terminated and
        // could read past the buffer. Build a CString first.
        let c_dev_name = std::ffi::CString::new(dev_name.clone())
            .map_err(|e| TunnelError::CommonError(format!("invalid device name: {}", e)))?;
        unsafe {
            // SAFETY: c_dev_name is a valid NUL-terminated string for the
            // duration of the call.
            let dev_idx = nix::libc::if_nametoindex(c_dev_name.as_ptr());
            tracing::warn!(?dev_idx, ?dev_name, "bind device");
            socket2_socket.bind_device_by_index_v4(std::num::NonZeroU32::new(dev_idx))?;
            tracing::warn!(?dev_idx, ?dev_name, "bind device done");
        }
    }
    #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
    if let Some(dev_name) = bind_dev {
        tracing::trace!(dev_name = ?dev_name, "bind device");
        socket2_socket.bind_device(Some(dev_name.as_bytes()))?;
    }
    Ok(())
}
/// Convenience wrapper around `setup_sokcet2_ext` that derives the bind
/// device from the interface owning `bind_addr`'s IP.
pub(crate) fn setup_sokcet2(
    socket2_socket: &socket2::Socket,
    bind_addr: &SocketAddr,
) -> Result<(), TunnelError> {
    let bind_dev = super::common::get_interface_name_by_ip(&bind_addr.ip());
    setup_sokcet2_ext(socket2_socket, bind_addr, bind_dev)
}
/// Shared test helpers for tunnel implementations: an echo server, ping-pong
/// round-trips (optionally across network namespaces) and a throughput bench.
pub mod tests {
    use std::time::Instant;
    use futures::SinkExt;
    use tokio_stream::StreamExt;
    use tokio_util::bytes::{BufMut, Bytes, BytesMut};
    use crate::{
        common::netns::NetNS,
        tunnels::{close_tunnel, TunnelConnector, TunnelListener},
    };
    /// Echo every received datagram back on the tunnel. Stops on the first
    /// read error; with `once` set, also stops after the first echo.
    pub async fn _tunnel_echo_server(tunnel: Box<dyn super::Tunnel>, once: bool) {
        let mut recv = Box::into_pin(tunnel.stream());
        let mut send = Box::into_pin(tunnel.sink());
        while let Some(ret) = recv.next().await {
            if ret.is_err() {
                log::trace!("recv error: {:?}", ret.err().unwrap());
                break;
            }
            let res = ret.unwrap();
            log::trace!("recv a msg, try echo back: {:?}", res);
            send.send(Bytes::from(res)).await.unwrap();
            if once {
                break;
            }
        }
        log::warn!("echo server exit...");
    }
    /// Ping-pong between a listener and a connector in the current netns.
    pub(crate) async fn _tunnel_pingpong<L, C>(listener: L, connector: C)
    where
        L: TunnelListener + Send + Sync + 'static,
        C: TunnelConnector + Send + Sync + 'static,
    {
        _tunnel_pingpong_netns(listener, connector, NetNS::new(None), NetNS::new(None)).await
    }
    /// Run `listener` in `l_netns` and `connector` in `c_netns`, send one
    /// datagram, assert the echo comes back, then close the tunnel.
    /// Also checks that the reported tunnel info matches the listener's local
    /// URL and the connector's remote URL.
    pub(crate) async fn _tunnel_pingpong_netns<L, C>(
        mut listener: L,
        mut connector: C,
        l_netns: NetNS,
        c_netns: NetNS,
    ) where
        L: TunnelListener + Send + Sync + 'static,
        C: TunnelConnector + Send + Sync + 'static,
    {
        l_netns
            .run_async(|| async {
                listener.listen().await.unwrap();
            })
            .await;
        let lis = tokio::spawn(async move {
            let ret = listener.accept().await.unwrap();
            assert_eq!(
                ret.info().unwrap().local_addr,
                listener.local_url().to_string()
            );
            _tunnel_echo_server(ret, false).await
        });
        let tunnel = c_netns.run_async(|| connector.connect()).await.unwrap();
        assert_eq!(
            tunnel.info().unwrap().remote_addr,
            connector.remote_url().to_string()
        );
        let mut send = tunnel.pin_sink();
        let mut recv = tunnel.pin_stream();
        let send_data = Bytes::from("12345678abcdefg");
        send.send(send_data).await.unwrap();
        // the echo must arrive within one second
        let ret = tokio::time::timeout(tokio::time::Duration::from_secs(1), recv.next())
            .await
            .unwrap()
            .unwrap()
            .unwrap();
        println!("echo back: {:?}", ret);
        assert_eq!(ret, Bytes::from("12345678abcdefg"));
        close_tunnel(&tunnel).await.unwrap();
        // datagram-style transports do not observe the close, so the echo
        // task is aborted instead of awaited
        if ["udp", "wg"].contains(&connector.remote_url().scheme()) {
            lis.abort();
        } else {
            // lis should finish in 1 second
            let ret = tokio::time::timeout(tokio::time::Duration::from_secs(1), lis).await;
            assert!(ret.is_ok());
        }
    }
    /// Echo 1KiB (64 x i128) random payloads back and forth for 3 seconds and
    /// print the throughput.
    /// NOTE(review): the printed "bps" figure is `(count/1024)*4/secs`, which
    /// looks like MiB/s for 4KiB... the payload is 1KiB — confirm the
    /// intended unit before relying on the number.
    pub(crate) async fn _tunnel_bench<L, C>(mut listener: L, mut connector: C)
    where
        L: TunnelListener + Send + Sync + 'static,
        C: TunnelConnector + Send + Sync + 'static,
    {
        listener.listen().await.unwrap();
        let lis = tokio::spawn(async move {
            let ret = listener.accept().await.unwrap();
            _tunnel_echo_server(ret, false).await
        });
        let tunnel = connector.connect().await.unwrap();
        let mut send = tunnel.pin_sink();
        let mut recv = tunnel.pin_stream();
        // prepare a 4k buffer with random data
        let mut send_buf = BytesMut::new();
        for _ in 0..64 {
            send_buf.put_i128(rand::random::<i128>());
        }
        let now = Instant::now();
        let mut count = 0;
        while now.elapsed().as_secs() < 3 {
            send.send(send_buf.clone().freeze()).await.unwrap();
            let _ = recv.next().await.unwrap().unwrap();
            count += 1;
        }
        println!("bps: {}", (count / 1024) * 4 / now.elapsed().as_secs());
        lis.abort();
    }
}
+192
View File
@@ -0,0 +1,192 @@
pub mod codec;
pub mod common;
pub mod ring_tunnel;
pub mod stats;
pub mod tcp_tunnel;
pub mod tunnel_filter;
pub mod udp_tunnel;
pub mod wireguard;
use std::{fmt::Debug, net::SocketAddr, pin::Pin, sync::Arc};
use crate::rpc::TunnelInfo;
use async_trait::async_trait;
use futures::{Sink, SinkExt, Stream};
use thiserror::Error;
use tokio_util::bytes::{Bytes, BytesMut};
/// Errors produced by tunnel listeners, connectors and the data path.
#[derive(Error, Debug)]
pub enum TunnelError {
    /// Catch-all for errors that only carry a message.
    #[error("Error: {0}")]
    CommonError(String),
    #[error("io error")]
    IOError(#[from] std::io::Error),
    /// A handshake response was expected but not received/invalid.
    #[error("wait resp error")]
    WaitRespError(String),
    #[error("Connect Error: {0}")]
    ConnectError(String),
    /// URL scheme did not match the transport's expected protocol.
    #[error("Invalid Protocol: {0}")]
    InvalidProtocol(String),
    #[error("Invalid Addr: {0}")]
    InvalidAddr(String),
    /// Error from the TUN device layer.
    #[error("Tun Error: {0}")]
    TunError(String),
    #[error("timeout")]
    Timeout(#[from] tokio::time::error::Elapsed),
}
// Payload types flowing through tunnels: frames are received as `BytesMut`
// and sent as `Bytes`; both directions use `TunnelError` as the error type.
pub type StreamT = BytesMut;
pub type StreamItem = Result<StreamT, TunnelError>;
pub type SinkItem = Bytes;
pub type SinkError = TunnelError;
/// Receiving side of a tunnel: a stream of received frames.
pub trait DatagramStream: Stream<Item = StreamItem> + Send + Sync {}
impl<T> DatagramStream for T where T: Stream<Item = StreamItem> + Send + Sync {}
/// Sending side of a tunnel: a sink of outgoing frames.
pub trait DatagramSink: Sink<SinkItem, Error = SinkError> + Send + Sync {}
impl<T> DatagramSink for T where T: Sink<SinkItem, Error = SinkError> + Send + Sync {}
/// A bidirectional datagram transport. Implementations hand out independent
/// stream/sink objects that share the same underlying connection.
#[auto_impl::auto_impl(Box, Arc)]
pub trait Tunnel: Send + Sync {
    fn stream(&self) -> Box<dyn DatagramStream>;
    fn sink(&self) -> Box<dyn DatagramSink>;
    /// Pinned convenience wrapper around `stream()`.
    fn pin_stream(&self) -> Pin<Box<dyn DatagramStream>> {
        Box::into_pin(self.stream())
    }
    /// Pinned convenience wrapper around `sink()`.
    fn pin_sink(&self) -> Pin<Box<dyn DatagramSink>> {
        Box::into_pin(self.sink())
    }
    /// Transport metadata (addresses, tunnel type), if known.
    fn info(&self) -> Option<TunnelInfo>;
}
/// Close a tunnel by closing its sink side.
/// NOTE(review): takes `&Box<dyn Tunnel>` rather than `&dyn Tunnel`; kept
/// as-is for caller compatibility.
pub async fn close_tunnel(t: &Box<dyn Tunnel>) -> Result<(), TunnelError> {
    t.pin_sink().close().await
}
/// Reports the number of live connections held by a tunnel listener.
#[auto_impl::auto_impl(Arc)]
pub trait TunnelConnCounter: 'static + Send + Sync + Debug {
    fn get(&self) -> u32;
}
/// Accept side of a tunnel transport: bind with `listen`, then `accept`
/// incoming tunnels one at a time.
#[async_trait]
#[auto_impl::auto_impl(Box)]
pub trait TunnelListener: Send + Sync {
    async fn listen(&mut self) -> Result<(), TunnelError>;
    async fn accept(&mut self) -> Result<Box<dyn Tunnel>, TunnelError>;
    /// URL this listener is bound to.
    fn local_url(&self) -> url::Url;
    /// Live-connection counter; the default implementation always reports 0
    /// for transports that do not track connections.
    fn get_conn_counter(&self) -> Arc<Box<dyn TunnelConnCounter>> {
        #[derive(Debug)]
        struct FakeTunnelConnCounter {}
        impl TunnelConnCounter for FakeTunnelConnCounter {
            fn get(&self) -> u32 {
                0
            }
        }
        Arc::new(Box::new(FakeTunnelConnCounter {}))
    }
}
/// Dial side of a tunnel transport.
#[async_trait]
#[auto_impl::auto_impl(Box)]
pub trait TunnelConnector {
    async fn connect(&mut self) -> Result<Box<dyn Tunnel>, TunnelError>;
    /// URL this connector dials.
    fn remote_url(&self) -> url::Url;
    /// Optional local addresses to bind before dialing; no-op by default.
    fn set_bind_addrs(&mut self, _addrs: Vec<SocketAddr>) {}
}
/// Build a URL of the form `{scheme}://{addr}` from an already-formatted
/// socket address string. Panics if the result is not a valid URL.
/// (Parameter widened from `&String` to `&str`; existing `&String` call
/// sites still compile via deref coercion.)
pub fn build_url_from_socket_addr(addr: &str, scheme: &str) -> url::Url {
    url::Url::parse(format!("{}://{}", scheme, addr).as_str()).unwrap()
}
// Debug for trait objects: show only the tunnel's metadata.
impl std::fmt::Debug for dyn Tunnel {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Tunnel")
            .field("info", &self.info())
            .finish()
    }
}
// Debug for connector trait objects: show only the dialed URL.
impl std::fmt::Debug for dyn TunnelConnector + Sync + Send {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("TunnelConnector")
            .field("remote_url", &self.remote_url())
            .finish()
    }
}
// Debug for listener trait objects: show only the bound URL.
impl std::fmt::Debug for dyn TunnelListener {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("TunnelListener")
            .field("local_url", &self.local_url())
            .finish()
    }
}
/// Conversion from a tunnel URL into a transport-specific address type.
pub(crate) trait FromUrl {
    fn from_url(url: url::Url) -> Result<Self, TunnelError>
    where
        Self: Sized;
}
/// Verify that `url` uses the expected `scheme`, then parse it into `T`.
/// Returns `InvalidProtocol` on a scheme mismatch.
pub(crate) fn check_scheme_and_get_socket_addr<T>(
    url: &url::Url,
    scheme: &str,
) -> Result<T, TunnelError>
where
    T: FromUrl,
{
    if url.scheme() != scheme {
        return Err(TunnelError::InvalidProtocol(url.scheme().to_string()));
    }
    // needless `Ok(..?)` wrapper removed
    T::from_url(url.clone())
}
impl FromUrl for SocketAddr {
fn from_url(url: url::Url) -> Result<Self, TunnelError> {
Ok(url.socket_addrs(|| None)?.pop().unwrap())
}
}
impl FromUrl for uuid::Uuid {
    /// Parse the URL's host component as a UUID (used by ring tunnels).
    /// Returns `InvalidAddr` instead of panicking when the URL has no host
    /// (the previous `host_str().unwrap()` panicked).
    fn from_url(url: url::Url) -> Result<Self, TunnelError> {
        let host = url
            .host_str()
            .ok_or_else(|| TunnelError::InvalidAddr(url.to_string()))?;
        uuid::Uuid::parse_str(host).map_err(|e| TunnelError::InvalidAddr(e.to_string()))
    }
}
/// Thin wrapper around `url::Url` adding tunnel-specific accessors
/// (e.g. the bind device encoded in the URL path).
pub struct TunnelUrl {
    inner: url::Url,
}
// Wrap a plain URL.
impl From<url::Url> for TunnelUrl {
    fn from(url: url::Url) -> Self {
        TunnelUrl { inner: url }
    }
}
// Unwrap back into the plain URL.
impl From<TunnelUrl> for url::Url {
    fn from(url: TunnelUrl) -> Self {
        url.into_inner()
    }
}
impl TunnelUrl {
    /// Consume the wrapper and return the underlying `url::Url`.
    pub fn into_inner(self) -> url::Url {
        self.inner
    }
    /// Device name encoded in the URL path ("/dev%20name" -> "dev name"),
    /// or `None` when the path is empty. Panics if the percent-decoded
    /// bytes are not valid UTF-8 (unchanged from the original behavior).
    pub fn bind_dev(&self) -> Option<String> {
        let encoded = self.inner.path().strip_prefix("/")?;
        if encoded.is_empty() {
            return None;
        }
        let decoded: Vec<u8> = percent_encoding::percent_decode_str(encoded).collect();
        Some(String::from_utf8(decoded).unwrap())
    }
}
+396
View File
@@ -0,0 +1,396 @@
use std::{
collections::HashMap,
sync::{atomic::AtomicBool, Arc},
task::Poll,
};
use async_stream::stream;
use crossbeam_queue::ArrayQueue;
use async_trait::async_trait;
use futures::Sink;
use once_cell::sync::Lazy;
use tokio::sync::{
mpsc::{UnboundedReceiver, UnboundedSender},
Mutex, Notify,
};
use futures::FutureExt;
use tokio_util::bytes::BytesMut;
use uuid::Uuid;
use crate::tunnels::{SinkError, SinkItem};
use super::{
build_url_from_socket_addr, check_scheme_and_get_socket_addr, DatagramSink, DatagramStream,
Tunnel, TunnelConnector, TunnelError, TunnelInfo, TunnelListener,
};
static RING_TUNNEL_CAP: usize = 1000;
/// In-process tunnel endpoint backed by a lock-free ring buffer.
///
/// One `RingTunnel` carries data in a single direction; a bidirectional
/// link pairs two of them (see `Connection` below).
pub struct RingTunnel {
    // Identifies this ring endpoint; ring URLs are `ring://<uuid>`.
    id: Uuid,
    // Shared packet queue between the producing sink and consuming stream.
    ring: Arc<ArrayQueue<SinkItem>>,
    // Signaled by the stream after popping; wakes a sink waiting for space.
    consume_notify: Arc<Notify>,
    // Signaled by the sink after pushing; wakes a stream waiting for data.
    produce_notify: Arc<Notify>,
    // Set when the sink side is closed; makes the stream yield an error.
    closed: Arc<AtomicBool>,
}
impl RingTunnel {
    /// Create a ring endpoint with a fresh random UUID and `cap` slots.
    pub fn new(cap: usize) -> Self {
        RingTunnel {
            id: Uuid::new_v4(),
            ring: Arc::new(ArrayQueue::new(cap)),
            consume_notify: Arc::new(Notify::new()),
            produce_notify: Arc::new(Notify::new()),
            closed: Arc::new(AtomicBool::new(false)),
        }
    }
    /// Like `new`, but with a caller-chosen UUID (used so a connector can
    /// address a specific listener; see `RingTunnelConnector::connect`).
    pub fn new_with_id(id: Uuid, cap: usize) -> Self {
        let mut ret = Self::new(cap);
        ret.id = id;
        ret
    }
    /// Async stream that pops packets from the ring, waiting on
    /// `produce_notify` when empty and erroring once `closed` is set.
    fn recv_stream(&self) -> impl DatagramStream {
        let ring = self.ring.clone();
        let produce_notify = self.produce_notify.clone();
        let consume_notify = self.consume_notify.clone();
        let closed = self.closed.clone();
        let id = self.id;
        stream! {
            loop {
                // NOTE(review): after yielding the Closed error the loop
                // continues, so a polled-again stream keeps yielding errors
                // (or drains remaining packets) — confirm this is intended.
                if closed.load(std::sync::atomic::Ordering::Relaxed) {
                    log::warn!("ring recv tunnel {:?} closed", id);
                    yield Err(TunnelError::CommonError("Closed".to_owned()));
                }
                match ring.pop() {
                    Some(v) => {
                        // Copy into a fresh BytesMut so the caller owns it.
                        let mut out = BytesMut::new();
                        out.extend_from_slice(&v);
                        consume_notify.notify_one();
                        log::trace!("id: {}, recv buffer, len: {:?}, buf: {:?}", id, v.len(), &v);
                        yield Ok(out);
                    },
                    None => {
                        log::trace!("waiting recv buffer, id: {}", id);
                        produce_notify.notified().await;
                    }
                }
            }
        }
    }
    /// Sink that pushes packets into the ring; applies backpressure by
    /// waiting (via a spawned task) until the ring has room.
    fn send_sink(&self) -> impl DatagramSink {
        let ring = self.ring.clone();
        let produce_notify = self.produce_notify.clone();
        let consume_notify = self.consume_notify.clone();
        let closed = self.closed.clone();
        let id = self.id;
        // type T = RingTunnel;
        use tokio::task::JoinHandle;
        // Private sink state: a clone of the endpoint handles plus the
        // in-flight "wait for consumer" task, if any.
        struct T {
            ring: RingTunnel,
            wait_consume_task: Option<JoinHandle<()>>,
        }
        impl T {
            // Resolve Ready once the ring holds <= expected_size items.
            // Spawns (at most one) helper task that sleeps on
            // consume_notify, and polls its JoinHandle to get a waker.
            fn wait_ring_consume(
                self: std::pin::Pin<&mut Self>,
                cx: &mut std::task::Context<'_>,
                expected_size: usize,
            ) -> std::task::Poll<()> {
                let self_mut = self.get_mut();
                if self_mut.ring.ring.len() <= expected_size {
                    return Poll::Ready(());
                }
                if self_mut.wait_consume_task.is_none() {
                    let id = self_mut.ring.id;
                    let consume_notify = self_mut.ring.consume_notify.clone();
                    let ring = self_mut.ring.ring.clone();
                    let task = async move {
                        log::trace!(
                            "waiting ring consume done, expected_size: {}, id: {}",
                            expected_size,
                            id
                        );
                        while ring.len() > expected_size {
                            consume_notify.notified().await;
                        }
                        log::trace!(
                            "ring consume done, expected_size: {}, id: {}",
                            expected_size,
                            id
                        );
                    };
                    self_mut.wait_consume_task = Some(tokio::spawn(task));
                }
                let task = self_mut.wait_consume_task.as_mut().unwrap();
                match task.poll_unpin(cx) {
                    Poll::Ready(_) => {
                        self_mut.wait_consume_task = None;
                        Poll::Ready(())
                    }
                    Poll::Pending => Poll::Pending,
                }
            }
        }
        impl Sink<SinkItem> for T {
            type Error = SinkError;
            fn poll_ready(
                self: std::pin::Pin<&mut Self>,
                cx: &mut std::task::Context<'_>,
            ) -> std::task::Poll<Result<(), Self::Error>> {
                // Ready when at least one slot is free.
                let expected_size = self.ring.ring.capacity() - 1;
                match self.wait_ring_consume(cx, expected_size) {
                    Poll::Ready(_) => Poll::Ready(Ok(())),
                    Poll::Pending => Poll::Pending,
                }
            }
            fn start_send(
                self: std::pin::Pin<&mut Self>,
                item: SinkItem,
            ) -> Result<(), Self::Error> {
                log::trace!("id: {}, send buffer, buf: {:?}", self.ring.id, &item);
                // NOTE(review): push().unwrap() panics if the ring is full;
                // safe only if callers honor poll_ready first — confirm.
                self.ring.ring.push(item).unwrap();
                self.ring.produce_notify.notify_one();
                Ok(())
            }
            fn poll_flush(
                self: std::pin::Pin<&mut Self>,
                _cx: &mut std::task::Context<'_>,
            ) -> std::task::Poll<Result<(), Self::Error>> {
                // Nothing buffered outside the ring; flush is a no-op.
                Poll::Ready(Ok(()))
            }
            fn poll_close(
                self: std::pin::Pin<&mut Self>,
                _cx: &mut std::task::Context<'_>,
            ) -> std::task::Poll<Result<(), Self::Error>> {
                // Mark closed and wake the reader so it observes the flag.
                self.ring
                    .closed
                    .store(true, std::sync::atomic::Ordering::Relaxed);
                log::warn!("ring tunnel send {:?} closed", self.ring.id);
                self.ring.produce_notify.notify_one();
                Poll::Ready(Ok(()))
            }
        }
        T {
            ring: RingTunnel {
                id,
                ring,
                consume_notify,
                produce_notify,
                closed,
            },
            wait_consume_task: None,
        }
    }
}
/// Bidirectional ring link: the server endpoint reads from `server` and
/// writes to `client`; the client endpoint does the opposite (see the
/// `ConnectionForServer`/`ConnectionForClient` Tunnel impls).
struct Connection {
    client: RingTunnel,
    server: RingTunnel,
}
impl Tunnel for RingTunnel {
    fn stream(&self) -> Box<dyn DatagramStream> {
        Box::new(self.recv_stream())
    }

    fn sink(&self) -> Box<dyn DatagramSink> {
        Box::new(self.send_sink())
    }

    fn info(&self) -> Option<TunnelInfo> {
        // A bare ring endpoint has no endpoint info; the Connection
        // wrappers below report `ring://<uuid>` addresses instead.
        None
    }
}
// Global registry of ring listeners: maps a listener UUID to the channel
// used to deliver new `Connection`s to that listener's `accept()` loop.
static CONNECTION_MAP: Lazy<Arc<Mutex<HashMap<uuid::Uuid, UnboundedSender<Arc<Connection>>>>>> =
    Lazy::new(|| Arc::new(Mutex::new(HashMap::new())));
/// Listener side of the in-process ring transport; registered in
/// `CONNECTION_MAP` under the UUID from its `ring://<uuid>` URL.
#[derive(Debug)]
pub struct RingTunnelListener {
    // NOTE(review): "listerner" is a typo, kept because other code in this
    // file references the field by this name.
    listerner_addr: url::Url,
    // Sender half stored in CONNECTION_MAP so connectors can deliver conns.
    conn_sender: UnboundedSender<Arc<Connection>>,
    // Receiver half consumed by `accept()`.
    conn_receiver: UnboundedReceiver<Arc<Connection>>,
}
impl RingTunnelListener {
    /// Create a listener for the given `ring://<uuid>` URL.
    pub fn new(key: url::Url) -> Self {
        let (sender, receiver) = tokio::sync::mpsc::unbounded_channel();
        Self {
            listerner_addr: key,
            conn_sender: sender,
            conn_receiver: receiver,
        }
    }
}
/// Server-side view of a ring `Connection` (reads `server`, writes `client`).
struct ConnectionForServer {
    conn: Arc<Connection>,
}
impl Tunnel for ConnectionForServer {
    fn stream(&self) -> Box<dyn DatagramStream> {
        // The server endpoint reads from the `server` ring...
        Box::new(self.conn.server.recv_stream())
    }

    fn sink(&self) -> Box<dyn DatagramSink> {
        // ...and writes into the `client` ring.
        Box::new(self.conn.client.send_sink())
    }

    fn info(&self) -> Option<TunnelInfo> {
        let local_addr = build_url_from_socket_addr(&self.conn.server.id.into(), "ring").into();
        let remote_addr = build_url_from_socket_addr(&self.conn.client.id.into(), "ring").into();
        Some(TunnelInfo {
            tunnel_type: "ring".to_owned(),
            local_addr,
            remote_addr,
        })
    }
}
/// Client-side view of a ring `Connection` (reads `client`, writes `server`).
struct ConnectionForClient {
    conn: Arc<Connection>,
}
impl Tunnel for ConnectionForClient {
    fn stream(&self) -> Box<dyn DatagramStream> {
        // The client endpoint reads from the `client` ring...
        Box::new(self.conn.client.recv_stream())
    }

    fn sink(&self) -> Box<dyn DatagramSink> {
        // ...and writes into the `server` ring.
        Box::new(self.conn.server.send_sink())
    }

    fn info(&self) -> Option<TunnelInfo> {
        let local_addr = build_url_from_socket_addr(&self.conn.client.id.into(), "ring").into();
        let remote_addr = build_url_from_socket_addr(&self.conn.server.id.into(), "ring").into();
        Some(TunnelInfo {
            tunnel_type: "ring".to_owned(),
            local_addr,
            remote_addr,
        })
    }
}
impl RingTunnelListener {
    /// Parse the listener URL into the ring UUID, validating the scheme.
    fn get_addr(&self) -> Result<uuid::Uuid, TunnelError> {
        check_scheme_and_get_socket_addr::<Uuid>(&self.listerner_addr, "ring")
    }
}
#[async_trait]
impl TunnelListener for RingTunnelListener {
    /// Register this listener's UUID in the global connection map so
    /// connectors can reach it.
    async fn listen(&mut self) -> Result<(), TunnelError> {
        log::info!("listen new conn of key: {}", self.listerner_addr);
        let my_addr = self.get_addr()?;
        CONNECTION_MAP
            .lock()
            .await
            .insert(my_addr, self.conn_sender.clone());
        Ok(())
    }

    /// Wait for the next connection delivered by a connector.
    async fn accept(&mut self) -> Result<Box<dyn Tunnel>, TunnelError> {
        log::info!("waiting accept new conn of key: {}", self.listerner_addr);
        let my_addr = self.get_addr()?;
        let Some(conn) = self.conn_receiver.recv().await else {
            return Err(TunnelError::CommonError("conn receiver stopped".to_owned()));
        };
        if conn.server.id != my_addr {
            tracing::error!(?conn.server.id, ?my_addr, "got new conn with wrong id");
            return Err(TunnelError::CommonError(
                "accept got wrong ring server id".to_owned(),
            ));
        }
        log::info!("accept new conn of key: {}", self.listerner_addr);
        Ok(Box::new(ConnectionForServer { conn }))
    }

    fn local_url(&self) -> url::Url {
        self.listerner_addr.clone()
    }
}
/// Connector side of the in-process ring transport; targets a listener
/// registered under the UUID in its `ring://<uuid>` URL.
pub struct RingTunnelConnector {
    remote_addr: url::Url,
}
impl RingTunnelConnector {
    /// Create a connector targeting a `ring://<uuid>` listener.
    pub fn new(remote_addr: url::Url) -> Self {
        Self { remote_addr }
    }
}
#[async_trait]
impl TunnelConnector for RingTunnelConnector {
    /// Connect to the ring listener registered under the URL's UUID.
    ///
    /// Now returns an error (instead of panicking via `.unwrap()`) when no
    /// listener with the target UUID has called `listen()` yet.
    async fn connect(&mut self) -> Result<Box<dyn Tunnel>, super::TunnelError> {
        let remote_addr = check_scheme_and_get_socket_addr::<Uuid>(&self.remote_addr, "ring")?;
        let entry = CONNECTION_MAP
            .lock()
            .await
            .get(&remote_addr)
            .ok_or_else(|| {
                TunnelError::CommonError(format!("no ring listener for {}", remote_addr))
            })?
            .clone();
        log::info!("connecting");
        let conn = Arc::new(Connection {
            client: RingTunnel::new(RING_TUNNEL_CAP),
            // The server ring carries the listener's UUID so accept() can
            // verify the connection is addressed to it. Uuid is Copy, so no
            // clone is needed.
            server: RingTunnel::new_with_id(remote_addr, RING_TUNNEL_CAP),
        });
        entry
            .send(conn.clone())
            .map_err(|_| TunnelError::CommonError("send conn to listener failed".to_owned()))?;
        Ok(Box::new(ConnectionForClient { conn }))
    }

    fn remote_url(&self) -> url::Url {
        self.remote_addr.clone()
    }
}
/// Build a connected pair of in-process tunnels sharing one `Connection`.
/// Returns `(server_side, client_side)`.
pub fn create_ring_tunnel_pair() -> (Box<dyn Tunnel>, Box<dyn Tunnel>) {
    let conn = Arc::new(Connection {
        client: RingTunnel::new(RING_TUNNEL_CAP),
        server: RingTunnel::new(RING_TUNNEL_CAP),
    });
    let server_side = Box::new(ConnectionForServer { conn: conn.clone() });
    let client_side = Box::new(ConnectionForClient { conn });
    (server_side, client_side)
}
#[cfg(test)]
mod tests {
    use crate::tunnels::common::tests::{_tunnel_bench, _tunnel_pingpong};

    use super::*;

    // Round-trip a message over a ring listener/connector pair.
    #[tokio::test]
    async fn ring_pingpong() {
        let id: url::Url = format!("ring://{}", Uuid::new_v4()).parse().unwrap();
        let listener = RingTunnelListener::new(id.clone());
        let connector = RingTunnelConnector::new(id.clone());
        _tunnel_pingpong(listener, connector).await
    }

    // Throughput benchmark over the same in-process transport.
    #[tokio::test]
    async fn ring_bench() {
        let id: url::Url = format!("ring://{}", Uuid::new_v4()).parse().unwrap();
        let listener = RingTunnelListener::new(id.clone());
        let connector = RingTunnelConnector::new(id);
        _tunnel_bench(listener, connector).await
    }
}
+95
View File
@@ -0,0 +1,95 @@
use std::sync::atomic::{AtomicU32, AtomicU64, Ordering::Relaxed};
/// Sliding-window average of latency samples, usable from multiple threads.
///
/// Keeps the last `window_size` samples (microseconds) in a fixed ring of
/// atomics plus a running sum, so the average is O(1) to read.
pub struct WindowLatency {
    latency_us_window: Vec<AtomicU32>,
    latency_us_window_index: AtomicU32,
    latency_us_window_size: u32,
    sum: AtomicU32,
    count: AtomicU32,
}

impl WindowLatency {
    /// Create a window holding up to `window_size` samples.
    pub fn new(window_size: u32) -> Self {
        let slots = (0..window_size).map(|_| AtomicU32::new(0)).collect();
        Self {
            latency_us_window: slots,
            latency_us_window_index: AtomicU32::new(0),
            latency_us_window_size: window_size,
            sum: AtomicU32::new(0),
            count: AtomicU32::new(0),
        }
    }

    /// Record one latency sample (microseconds), evicting the oldest slot
    /// once the window is full.
    pub fn record_latency(&self, latency_us: u32) {
        let slot =
            self.latency_us_window_index.fetch_add(1, Relaxed) % self.latency_us_window_size;
        if self.count.load(Relaxed) < self.latency_us_window_size {
            self.count.fetch_add(1, Relaxed);
        }
        // Swap the new sample in and adjust the running sum by the delta,
        // so `sum` always equals the total of the stored samples.
        let evicted = self.latency_us_window[slot as usize].swap(latency_us, Relaxed);
        if latency_us >= evicted {
            self.sum.fetch_add(latency_us - evicted, Relaxed);
        } else {
            self.sum.fetch_sub(evicted - latency_us, Relaxed);
        }
    }

    /// Average latency over the recorded samples, or 0 when empty.
    pub fn get_latency_us<T: From<u32> + std::ops::Div<Output = T>>(&self) -> T {
        let count = self.count.load(Relaxed);
        if count == 0 {
            return 0.into();
        }
        let sum = self.sum.load(Relaxed);
        T::from(sum) / T::from(count)
    }
}
/// Cumulative transmit/receive counters (bytes and packets).
///
/// All counters are relaxed atomics; they only ever increase and can be
/// read from any thread. `Default` is derived (clippy `new_without_default`)
/// since `AtomicU64::default()` is zero.
#[derive(Default)]
pub struct Throughput {
    tx_bytes: AtomicU64,
    rx_bytes: AtomicU64,
    tx_packets: AtomicU64,
    rx_packets: AtomicU64,
}

impl Throughput {
    /// Create a zeroed counter set.
    pub fn new() -> Self {
        Self::default()
    }

    /// Total bytes sent.
    pub fn tx_bytes(&self) -> u64 {
        self.tx_bytes.load(Relaxed)
    }

    /// Total bytes received.
    pub fn rx_bytes(&self) -> u64 {
        self.rx_bytes.load(Relaxed)
    }

    /// Total packets sent.
    pub fn tx_packets(&self) -> u64 {
        self.tx_packets.load(Relaxed)
    }

    /// Total packets received.
    pub fn rx_packets(&self) -> u64 {
        self.rx_packets.load(Relaxed)
    }

    /// Account one sent packet of `bytes` bytes.
    pub fn record_tx_bytes(&self, bytes: u64) {
        self.tx_bytes.fetch_add(bytes, Relaxed);
        self.tx_packets.fetch_add(1, Relaxed);
    }

    /// Account one received packet of `bytes` bytes.
    pub fn record_rx_bytes(&self, bytes: u64) {
        self.rx_bytes.fetch_add(bytes, Relaxed);
        self.rx_packets.fetch_add(1, Relaxed);
    }
}
+273
View File
@@ -0,0 +1,273 @@
use std::net::SocketAddr;
use async_trait::async_trait;
use futures::{stream::FuturesUnordered, StreamExt};
use tokio::net::{TcpListener, TcpSocket, TcpStream};
use tokio_util::codec::{FramedRead, FramedWrite, LengthDelimitedCodec};
use crate::tunnels::common::setup_sokcet2;
use super::{
check_scheme_and_get_socket_addr, common::FramedTunnel, Tunnel, TunnelInfo, TunnelListener,
};
/// TCP tunnel listener; `listener` is populated by `listen()`.
#[derive(Debug)]
pub struct TcpTunnelListener {
    addr: url::Url,
    listener: Option<TcpListener>,
}
impl TcpTunnelListener {
    /// Create a listener for a `tcp://host:port` URL; call `listen()`
    /// to actually bind the socket.
    pub fn new(addr: url::Url) -> Self {
        Self {
            addr,
            listener: None,
        }
    }
}
#[async_trait]
impl TunnelListener for TcpTunnelListener {
    /// Bind the TCP socket described by `self.addr` and start listening.
    async fn listen(&mut self) -> Result<(), super::TunnelError> {
        let addr = check_scheme_and_get_socket_addr::<SocketAddr>(&self.addr, "tcp")?;

        let socket = if addr.is_ipv4() {
            TcpSocket::new_v4()?
        } else {
            TcpSocket::new_v6()?
        };
        socket.set_reuseaddr(true)?;
        // #[cfg(all(unix, not(target_os = "solaris"), not(target_os = "illumos")))]
        // socket.set_reuseport(true)?;
        socket.bind(addr)?;
        self.listener = Some(socket.listen(1024)?);
        Ok(())
    }

    /// Accept one connection and wrap it in a length-delimited framed tunnel.
    ///
    /// Panics if called before a successful `listen()`.
    async fn accept(&mut self) -> Result<Box<dyn Tunnel>, super::TunnelError> {
        let listener = self.listener.as_ref().unwrap();
        let (stream, _) = listener.accept().await?;
        // Propagate instead of unwrap: set_nodelay can fail (e.g. if the
        // peer reset the connection right after accept).
        stream.set_nodelay(true)?;
        let info = TunnelInfo {
            tunnel_type: "tcp".to_owned(),
            local_addr: self.local_url().into(),
            remote_addr: super::build_url_from_socket_addr(&stream.peer_addr()?.to_string(), "tcp")
                .into(),
        };

        let (r, w) = tokio::io::split(stream);
        Ok(FramedTunnel::new_tunnel_with_info(
            FramedRead::new(r, LengthDelimitedCodec::new()),
            FramedWrite::new(w, LengthDelimitedCodec::new()),
            info,
        ))
    }

    fn local_url(&self) -> url::Url {
        self.addr.clone()
    }
}
/// Wrap an established `TcpStream` into a length-delimited framed tunnel,
/// reporting `remote_url` as the remote endpoint.
fn get_tunnel_with_tcp_stream(
    stream: TcpStream,
    remote_url: url::Url,
) -> Result<Box<dyn Tunnel>, super::TunnelError> {
    // Propagate instead of unwrap: set_nodelay can fail if the peer has
    // already reset the connection.
    stream.set_nodelay(true)?;

    let info = TunnelInfo {
        tunnel_type: "tcp".to_owned(),
        local_addr: super::build_url_from_socket_addr(&stream.local_addr()?.to_string(), "tcp")
            .into(),
        remote_addr: remote_url.into(),
    };

    let (r, w) = tokio::io::split(stream);
    Ok(Box::new(FramedTunnel::new_tunnel_with_info(
        FramedRead::new(r, LengthDelimitedCodec::new()),
        FramedWrite::new(w, LengthDelimitedCodec::new()),
        info,
    )))
}
/// TCP tunnel connector; when `bind_addrs` is non-empty, one connect
/// attempt is made per bind address and the first completion wins.
#[derive(Debug)]
pub struct TcpTunnelConnector {
    addr: url::Url,
    bind_addrs: Vec<SocketAddr>,
}
impl TcpTunnelConnector {
    /// Create a connector for a `tcp://host:port` URL.
    pub fn new(addr: url::Url) -> Self {
        TcpTunnelConnector {
            addr,
            bind_addrs: vec![],
        }
    }

    /// Connect using the OS-chosen local address.
    async fn connect_with_default_bind(&mut self) -> Result<Box<dyn Tunnel>, super::TunnelError> {
        tracing::info!(addr = ?self.addr, "connect tcp start");
        let addr = check_scheme_and_get_socket_addr::<SocketAddr>(&self.addr, "tcp")?;
        let stream = TcpStream::connect(addr).await?;
        tracing::info!(addr = ?self.addr, "connect tcp succ");
        // `self.addr` is already a url::Url; the previous `.into()` was a
        // no-op conversion.
        get_tunnel_with_tcp_stream(stream, self.addr.clone())
    }

    /// Race one connect attempt per configured bind address and use the
    /// first attempt that completes.
    async fn connect_with_custom_bind(&mut self) -> Result<Box<dyn Tunnel>, super::TunnelError> {
        let mut futures = FuturesUnordered::new();
        let dst_addr = check_scheme_and_get_socket_addr::<SocketAddr>(&self.addr, "tcp")?;

        for bind_addr in self.bind_addrs.iter() {
            tracing::info!(bind_addr = ?bind_addr, ?dst_addr, "bind addr");
            let socket2_socket = socket2::Socket::new(
                socket2::Domain::for_address(dst_addr),
                socket2::Type::STREAM,
                Some(socket2::Protocol::TCP),
            )?;
            setup_sokcet2(&socket2_socket, bind_addr)?;
            let socket = TcpSocket::from_std_stream(socket2_socket.into());
            // SocketAddr is Copy; the previous `.clone()` was redundant.
            futures.push(socket.connect(dst_addr));
        }

        // NOTE(review): only the first completed attempt is inspected; if it
        // failed, the remaining in-flight attempts are dropped, not retried.
        let Some(ret) = futures.next().await else {
            return Err(super::TunnelError::CommonError(
                "join connect futures failed".to_owned(),
            ));
        };
        get_tunnel_with_tcp_stream(ret?, self.addr.clone())
    }
}
#[async_trait]
impl super::TunnelConnector for TcpTunnelConnector {
    /// Dispatch to the custom-bind path when bind addresses are configured,
    /// otherwise let the OS pick the local address.
    async fn connect(&mut self) -> Result<Box<dyn Tunnel>, super::TunnelError> {
        if !self.bind_addrs.is_empty() {
            self.connect_with_custom_bind().await
        } else {
            self.connect_with_default_bind().await
        }
    }

    fn remote_url(&self) -> url::Url {
        self.addr.clone()
    }

    fn set_bind_addrs(&mut self, addrs: Vec<SocketAddr>) {
        self.bind_addrs = addrs;
    }
}
#[cfg(test)]
mod tests {
    use futures::SinkExt;

    use crate::tunnels::{
        common::tests::{_tunnel_bench, _tunnel_pingpong},
        TunnelConnector,
    };

    use super::*;

    // Basic round trip over loopback TCP.
    #[tokio::test]
    async fn tcp_pingpong() {
        let listener = TcpTunnelListener::new("tcp://0.0.0.0:11011".parse().unwrap());
        let connector = TcpTunnelConnector::new("tcp://127.0.0.1:11011".parse().unwrap());
        _tunnel_pingpong(listener, connector).await
    }

    // Throughput benchmark over loopback TCP.
    #[tokio::test]
    async fn tcp_bench() {
        let listener = TcpTunnelListener::new("tcp://0.0.0.0:11012".parse().unwrap());
        let connector = TcpTunnelConnector::new("tcp://127.0.0.1:11012".parse().unwrap());
        _tunnel_bench(listener, connector).await
    }

    // Round trip with an explicit (valid) local bind address.
    #[tokio::test]
    async fn tcp_bench_with_bind() {
        let listener = TcpTunnelListener::new("tcp://127.0.0.1:11013".parse().unwrap());
        let mut connector = TcpTunnelConnector::new("tcp://127.0.0.1:11013".parse().unwrap());
        connector.set_bind_addrs(vec!["127.0.0.1:0".parse().unwrap()]);
        _tunnel_pingpong(listener, connector).await
    }

    // Binding to an address not present on this host must fail.
    #[tokio::test]
    #[should_panic]
    async fn tcp_bench_with_bind_fail() {
        let listener = TcpTunnelListener::new("tcp://127.0.0.1:11014".parse().unwrap());
        let mut connector = TcpTunnelConnector::new("tcp://127.0.0.1:11014".parse().unwrap());
        connector.set_bind_addrs(vec!["10.0.0.1:0".parse().unwrap()]);
        _tunnel_pingpong(listener, connector).await
    }

    // test slow send lock in framed tunnel
    // Two fast senders share one tunnel while the receiver drains slowly;
    // closing the sink after 5s should unblock both senders cleanly.
    #[tokio::test]
    async fn tcp_multiple_sender_and_slow_receiver() {
        // console_subscriber::init();
        let mut listener = TcpTunnelListener::new("tcp://127.0.0.1:11014".parse().unwrap());
        let mut connector = TcpTunnelConnector::new("tcp://127.0.0.1:11014".parse().unwrap());

        listener.listen().await.unwrap();

        // Slow receiver: 100ms per packet, gives up after ~5s.
        let t1 = tokio::spawn(async move {
            let t = listener.accept().await.unwrap();
            let mut stream = t.pin_stream();
            let now = tokio::time::Instant::now();
            while let Some(Ok(_)) = stream.next().await {
                tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
                if now.elapsed().as_secs() > 5 {
                    break;
                }
            }
            tracing::info!("t1 exit");
        });

        let tunnel = connector.connect().await.unwrap();

        // First fast sender; exits when the sink errors after close.
        let mut sink1 = tunnel.pin_sink();
        let t2 = tokio::spawn(async move {
            for i in 0..1000000 {
                let a = sink1.send(b"hello".to_vec().into()).await;
                if a.is_err() {
                    tracing::info!(?a, "t2 exit with err");
                    break;
                }
                if i % 5000 == 0 {
                    tracing::info!(i, "send2 1000");
                }
            }
            tracing::info!("t2 exit");
        });

        // Second fast sender on the same tunnel.
        let mut sink2 = tunnel.pin_sink();
        let t3 = tokio::spawn(async move {
            for i in 0..1000000 {
                let a = sink2.send(b"hello".to_vec().into()).await;
                if a.is_err() {
                    tracing::info!(?a, "t3 exit with err");
                    break;
                }
                if i % 5000 == 0 {
                    tracing::info!(i, "send2 1000");
                }
            }
            tracing::info!("t3 exit");
        });

        // Close the tunnel after 5s to release the blocked senders.
        let t4 = tokio::spawn(async move {
            tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
            tracing::info!("closing");
            let close_ret = tunnel.pin_sink().close().await;
            tracing::info!("closed {:?}", close_ret);
        });

        let _ = tokio::join!(t1, t2, t3, t4);
    }
}
+279
View File
@@ -0,0 +1,279 @@
use std::{
sync::Arc,
task::{Context, Poll},
};
use crate::rpc::TunnelInfo;
use futures::{Sink, SinkExt, Stream, StreamExt};
use self::stats::Throughput;
use super::*;
use crate::tunnels::{DatagramSink, DatagramStream, SinkError, SinkItem, StreamItem, Tunnel};
/// Hook pair for intercepting tunnel traffic.
///
/// `before_send` may transform or drop (return `None`) an outgoing item;
/// `after_received` may transform or drop an incoming result. The default
/// implementations pass everything through unchanged.
pub trait TunnelFilter {
    fn before_send(&self, data: SinkItem) -> Option<Result<SinkItem, SinkError>> {
        Some(Ok(data))
    }

    fn after_received(&self, data: StreamItem) -> Option<Result<BytesMut, TunnelError>> {
        match data {
            Ok(v) => Some(Ok(v)),
            Err(e) => Some(Err(e)),
        }
    }
}
/// Tunnel decorator that routes all traffic of `inner` through `filter`.
pub struct TunnelWithFilter<T, F> {
    inner: T,
    filter: Arc<F>,
}
impl<T, F> Tunnel for TunnelWithFilter<T, F>
where
    T: Tunnel + Send + Sync + 'static,
    F: TunnelFilter + Send + Sync + 'static,
{
    /// Sink that runs each item through `filter.before_send` before
    /// forwarding it to the inner tunnel's sink.
    fn sink(&self) -> Box<dyn DatagramSink> {
        struct SinkWrapper<F> {
            sink: Pin<Box<dyn DatagramSink>>,
            filter: Arc<F>,
        }
        impl<F> Sink<SinkItem> for SinkWrapper<F>
        where
            F: TunnelFilter + Send + Sync + 'static,
        {
            type Error = SinkError;

            fn poll_ready(
                self: Pin<&mut Self>,
                cx: &mut Context<'_>,
            ) -> Poll<Result<(), Self::Error>> {
                self.get_mut().sink.poll_ready_unpin(cx)
            }

            fn start_send(self: Pin<&mut Self>, item: SinkItem) -> Result<(), Self::Error> {
                // A `None` from the filter silently drops the item.
                let Some(item) = self.filter.before_send(item) else {
                    return Ok(());
                };
                self.get_mut().sink.start_send_unpin(item?)
            }

            fn poll_flush(
                self: Pin<&mut Self>,
                cx: &mut Context<'_>,
            ) -> Poll<Result<(), Self::Error>> {
                self.get_mut().sink.poll_flush_unpin(cx)
            }

            fn poll_close(
                self: Pin<&mut Self>,
                cx: &mut Context<'_>,
            ) -> Poll<Result<(), Self::Error>> {
                self.get_mut().sink.poll_close_unpin(cx)
            }
        }
        Box::new(SinkWrapper {
            sink: self.inner.pin_sink(),
            filter: self.filter.clone(),
        })
    }

    /// Stream that runs each received item through `filter.after_received`,
    /// skipping items the filter drops.
    fn stream(&self) -> Box<dyn DatagramStream> {
        struct StreamWrapper<F> {
            stream: Pin<Box<dyn DatagramStream>>,
            filter: Arc<F>,
        }
        impl<F> Stream for StreamWrapper<F>
        where
            F: TunnelFilter + Send + Sync + 'static,
        {
            type Item = StreamItem;
            fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
                let self_mut = self.get_mut();
                // Loop so a filtered-out (None) item polls the inner stream
                // again instead of surfacing a gap to the caller.
                loop {
                    match self_mut.stream.poll_next_unpin(cx) {
                        Poll::Ready(Some(ret)) => {
                            let Some(ret) = self_mut.filter.after_received(ret) else {
                                continue;
                            };
                            return Poll::Ready(Some(ret));
                        }
                        Poll::Ready(None) => {
                            return Poll::Ready(None);
                        }
                        Poll::Pending => {
                            return Poll::Pending;
                        }
                    }
                }
            }
        }
        Box::new(StreamWrapper {
            stream: self.inner.pin_stream(),
            filter: self.filter.clone(),
        })
    }

    fn info(&self) -> Option<TunnelInfo> {
        self.inner.info()
    }
}
impl<T, F> TunnelWithFilter<T, F>
where
    T: Tunnel + Send + Sync + 'static,
    F: TunnelFilter + Send + Sync + 'static,
{
    /// Wrap `inner` so that all of its traffic passes through `filter`.
    pub fn new(inner: T, filter: Arc<F>) -> Self {
        Self { inner, filter }
    }
}
/// Test helper that records every packet passing through the tunnel.
///
/// NOTE(review): the field names look swapped relative to the hooks —
/// `before_send` pushes into `received` and `after_received` pushes into
/// `sent` (see the TunnelFilter impl). The in-file test asserts on this
/// behavior, so it is documented here rather than changed.
pub struct PacketRecorderTunnelFilter {
    pub received: Arc<std::sync::Mutex<Vec<Bytes>>>,
    pub sent: Arc<std::sync::Mutex<Vec<Bytes>>>,
}
impl TunnelFilter for PacketRecorderTunnelFilter {
    // Records outgoing packets; note it pushes into `received` (the field
    // naming appears inverted — kept as-is for compatibility with tests).
    fn before_send(&self, data: SinkItem) -> Option<Result<SinkItem, SinkError>> {
        self.received.lock().unwrap().push(data.clone());
        Some(Ok(data))
    }

    // Records successfully received packets into `sent`; errors pass through.
    fn after_received(&self, data: StreamItem) -> Option<Result<BytesMut, TunnelError>> {
        match data {
            Ok(v) => {
                self.sent.lock().unwrap().push(v.clone().into());
                Some(Ok(v))
            }
            Err(e) => Some(Err(e)),
        }
    }
}
impl PacketRecorderTunnelFilter {
    /// Create a recorder with empty capture buffers.
    pub fn new() -> Self {
        let received = Arc::new(std::sync::Mutex::new(Vec::new()));
        let sent = Arc::new(std::sync::Mutex::new(Vec::new()));
        Self { received, sent }
    }
}
/// Filter that accumulates byte/packet counters for traffic it observes.
pub struct StatsRecorderTunnelFilter {
    throughput: Arc<Throughput>,
}
impl TunnelFilter for StatsRecorderTunnelFilter {
    fn before_send(&self, data: SinkItem) -> Option<Result<SinkItem, SinkError>> {
        // Count outgoing bytes, then pass the item through untouched.
        self.throughput.record_tx_bytes(data.len() as u64);
        Some(Ok(data))
    }

    fn after_received(&self, data: StreamItem) -> Option<Result<BytesMut, TunnelError>> {
        // Count incoming bytes on success; errors pass through unchanged.
        Some(data.map(|v| {
            self.throughput.record_rx_bytes(v.len() as u64);
            v
        }))
    }
}
impl StatsRecorderTunnelFilter {
    /// Create a filter with fresh zeroed counters.
    pub fn new() -> Self {
        let throughput = Arc::new(Throughput::new());
        Self { throughput }
    }

    /// Shared handle to the counters this filter updates.
    pub fn get_throughput(&self) -> Arc<Throughput> {
        Arc::clone(&self.throughput)
    }
}
/// Declare a struct bundling several filters and a `wrap_tunnel` helper
/// that applies them to a tunnel in declaration order (so the last-listed
/// filter ends up as the outermost wrapper).
///
/// Usage: `define_tunnel_filter_chain!(Name, field = FilterType, ...)`.
#[macro_export]
macro_rules! define_tunnel_filter_chain {
    ($type_name:ident $(, $field_name:ident = $filter_type:ty)+) => (
        pub struct $type_name {
            $($field_name: std::sync::Arc<$filter_type>,)+
        }

        impl $type_name {
            pub fn new() -> Self {
                Self {
                    $($field_name: std::sync::Arc::new(<$filter_type>::new()),)+
                }
            }

            pub fn wrap_tunnel(&self, tunnel: impl Tunnel + 'static) -> impl Tunnel {
                // Each step shadows `tunnel` with a filtered wrapper.
                $(
                    let tunnel = crate::tunnels::tunnel_filter::TunnelWithFilter::new(tunnel, self.$field_name.clone());
                )+
                tunnel
            }
        }
    )
}
#[cfg(test)]
pub mod tests {
    use std::sync::atomic::{AtomicU32, Ordering};

    use super::*;
    use crate::tunnels::ring_tunnel::RingTunnel;

    /// Test filter that drops outgoing packets number `start..end`
    /// (1-based, counted across the filter's lifetime).
    pub struct DropSendTunnelFilter {
        start: AtomicU32,
        end: AtomicU32,
        cur: AtomicU32,
    }

    impl TunnelFilter for DropSendTunnelFilter {
        fn before_send(&self, data: SinkItem) -> Option<Result<SinkItem, SinkError>> {
            self.cur.fetch_add(1, Ordering::SeqCst);
            if self.cur.load(Ordering::SeqCst) >= self.start.load(Ordering::SeqCst)
                && self.cur.load(std::sync::atomic::Ordering::SeqCst)
                    < self.end.load(Ordering::SeqCst)
            {
                tracing::trace!("drop packet: {:?}", data);
                return None;
            }
            Some(Ok(data))
        }
    }

    impl DropSendTunnelFilter {
        pub fn new(start: u32, end: u32) -> Self {
            Self {
                start: AtomicU32::new(start),
                end: AtomicU32::new(end),
                cur: AtomicU32::new(0),
            }
        }
    }

    // A packet sent through a three-filter chain should be recorded by
    // every filter in the chain exactly once.
    #[tokio::test]
    async fn test_nested_filter() {
        define_tunnel_filter_chain!(
            Filter,
            a = PacketRecorderTunnelFilter,
            b = PacketRecorderTunnelFilter,
            c = PacketRecorderTunnelFilter
        );

        let filter = Filter::new();
        let tunnel = filter.wrap_tunnel(RingTunnel::new(1));

        let mut s = tunnel.pin_sink();
        s.send(Bytes::from("hello")).await.unwrap();

        assert_eq!(1, filter.a.received.lock().unwrap().len());
        assert_eq!(1, filter.b.received.lock().unwrap().len());
        assert_eq!(1, filter.c.received.lock().unwrap().len());
    }
}
+772
View File
@@ -0,0 +1,772 @@
use std::{fmt::Debug, pin::Pin, sync::Arc};
use async_trait::async_trait;
use dashmap::DashMap;
use futures::{stream::FuturesUnordered, SinkExt, StreamExt};
use rkyv::{Archive, Deserialize, Serialize};
use std::net::SocketAddr;
use tokio::{net::UdpSocket, sync::Mutex, task::JoinSet};
use tokio_util::{
bytes::{Buf, Bytes, BytesMut},
udp::UdpFramed,
};
use tracing::Instrument;
use crate::{
common::{
join_joinset_background,
rkyv_util::{self, encode_to_bytes, vec_to_string},
},
rpc::TunnelInfo,
tunnels::{build_url_from_socket_addr, close_tunnel, TunnelConnCounter, TunnelConnector},
};
use super::{
codec::BytesCodec,
common::{setup_sokcet2, setup_sokcet2_ext, FramedTunnel, TunnelWithCustomInfo},
ring_tunnel::create_ring_tunnel_pair,
DatagramSink, DatagramStream, Tunnel, TunnelListener, TunnelUrl,
};
pub const UDP_DATA_MTU: usize = 65000;
#[derive(Archive, Deserialize, Serialize)]
#[archive(compare(PartialEq), check_bytes)]
// Derives can be passed through to the generated type:
/// Payload variants of the UDP tunnel handshake/data protocol:
/// `Syn`/`Sack` perform the connection handshake, `HolePunch` carries NAT
/// traversal data, `Data` carries tunnel payload bytes.
pub enum UdpPacketPayload {
    Syn,
    Sack,
    HolePunch(String),
    Data(String),
}
impl std::fmt::Debug for UdpPacketPayload {
    /// Manual Debug that prints string payloads as raw bytes.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Bug fix: this is the plain enum, not the rkyv-archived one — the
        // label previously read "ArchivedUdpPacketPayload" (copy-paste from
        // the Archived Debug impl below).
        let mut tmp = f.debug_struct("UdpPacketPayload");
        match self {
            UdpPacketPayload::Syn => tmp.field("Syn", &"").finish(),
            UdpPacketPayload::Sack => tmp.field("Sack", &"").finish(),
            UdpPacketPayload::HolePunch(s) => tmp.field("HolePunch", &s.as_bytes()).finish(),
            UdpPacketPayload::Data(s) => tmp.field("Data", &s.as_bytes()).finish(),
        }
    }
}
/// Wire format of one UDP tunnel packet: connection id, magic marker, and
/// a typed payload. Serialized/deserialized with rkyv.
#[derive(Archive, Deserialize, Serialize, Debug)]
#[archive(compare(PartialEq), check_bytes)]
#[archive_attr(derive(Debug))]
pub struct UdpPacket {
    pub conn_id: u32,
    pub magic: u32,
    pub payload: UdpPacketPayload,
}
const UDP_PACKET_MAGIC: u32 = 0x19941126;
// Debug for the rkyv-archived payload; prints string payloads as raw bytes.
impl std::fmt::Debug for ArchivedUdpPacketPayload {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut tmp = f.debug_struct("ArchivedUdpPacketPayload");
        match self {
            ArchivedUdpPacketPayload::Syn => tmp.field("Syn", &"").finish(),
            ArchivedUdpPacketPayload::Sack => tmp.field("Sack", &"").finish(),
            ArchivedUdpPacketPayload::HolePunch(s) => {
                tmp.field("HolePunch", &s.as_bytes()).finish()
            }
            ArchivedUdpPacketPayload::Data(s) => tmp.field("Data", &s.as_bytes()).finish(),
        }
    }
}
impl UdpPacket {
    /// Data packet carrying `data` on connection `conn_id`.
    pub fn new_data_packet(conn_id: u32, data: Vec<u8>) -> Self {
        Self {
            magic: UDP_PACKET_MAGIC,
            conn_id,
            payload: UdpPacketPayload::Data(vec_to_string(data)),
        }
    }

    /// NAT-hole-punch packet; uses the reserved connection id 0.
    pub fn new_hole_punch_packet(data: Vec<u8>) -> Self {
        Self {
            magic: UDP_PACKET_MAGIC,
            conn_id: 0,
            payload: UdpPacketPayload::HolePunch(vec_to_string(data)),
        }
    }

    /// Handshake request for a new connection `conn_id`.
    pub fn new_syn_packet(conn_id: u32) -> Self {
        Self {
            magic: UDP_PACKET_MAGIC,
            conn_id,
            payload: UdpPacketPayload::Syn,
        }
    }

    /// Handshake acknowledgement for connection `conn_id`.
    pub fn new_sack_packet(conn_id: u32) -> Self {
        Self {
            magic: UDP_PACKET_MAGIC,
            conn_id,
            payload: UdpPacketPayload::Sack,
        }
    }
}
/// Decode `buf` as a rkyv `UdpPacket` and, when it is a `Data` packet for
/// `conn_id` with the expected magic, shrink `buf` in place (zero-copy) to
/// just the payload bytes and return it. Returns `None` on any mismatch.
fn try_get_data_payload(mut buf: BytesMut, conn_id: u32) -> Option<BytesMut> {
    let Ok(udp_packet) = rkyv_util::decode_from_bytes::<UdpPacket>(&buf) else {
        tracing::warn!(?buf, "udp decode error");
        return None;
    };

    // u32 is Copy; the previous `conn_id.clone()` was redundant.
    if udp_packet.conn_id != conn_id {
        tracing::warn!(?udp_packet, ?conn_id, "udp conn id not match");
        return None;
    }

    if udp_packet.magic != UDP_PACKET_MAGIC {
        tracing::trace!(?udp_packet, "udp magic not match");
        return None;
    }

    let ArchivedUdpPacketPayload::Data(payload) = &udp_packet.payload else {
        tracing::warn!(?udp_packet, "udp payload not data");
        return None;
    };

    // The archived payload aliases `buf`'s memory; compute its offset
    // inside the buffer and bounds-check before slicing in place.
    let offset = payload.as_ptr() as usize - buf.as_ptr() as usize;
    let len = payload.len();
    if offset + len > buf.len() {
        tracing::warn!(?offset, ?len, ?buf, "udp payload data out of range");
        return None;
    }
    buf.advance(offset);
    buf.truncate(len);
    tracing::trace!(?offset, ?len, ?buf, "udp payload data");
    Some(buf)
}
/// Build a framed tunnel on top of a (shared) UDP socket for one peer
/// address and connection id: incoming datagrams from other peers or other
/// connection ids are silently filtered; outgoing payloads are wrapped in
/// `UdpPacket::Data`.
fn get_tunnel_from_socket(
    socket: Arc<UdpSocket>,
    addr: SocketAddr,
    conn_id: u32,
) -> Box<dyn super::Tunnel> {
    let udp = UdpFramed::new(socket.clone(), BytesCodec::new(UDP_DATA_MTU));
    let (sink, stream) = udp.split();

    let recv_addr = addr;
    let stream = stream.filter_map(move |v| async move {
        tracing::trace!(?v, "udp stream recv something");
        if v.is_err() {
            tracing::warn!(?v, "udp stream error");
            return Some(Err(super::TunnelError::CommonError(
                "udp stream error".to_owned(),
            )));
        }
        let (buf, addr) = v.unwrap();
        // Drop datagrams from any peer other than the connected one.
        if recv_addr != addr {
            tracing::warn!(?addr, ?recv_addr, "udp recv addr not match");
            return None;
        }
        Some(Ok(try_get_data_payload(buf, conn_id.clone())?))
    });
    let stream = Box::pin(stream);

    let sender_addr = addr;
    let sink = Box::pin(sink.with(move |v: Bytes| async move {
        // NOTE(review): this dead branch appears to exist only to pin the
        // async block's error type to TunnelError — confirm before removing.
        if false {
            return Err(super::TunnelError::CommonError("udp sink error".to_owned()));
        }

        // TODO: two copy here, how to avoid?
        let udp_packet = UdpPacket::new_data_packet(conn_id, v.to_vec());
        let v = encode_to_bytes::<_, UDP_DATA_MTU>(&udp_packet);
        tracing::trace!(?udp_packet, ?v, "udp send packet");
        Ok((v, sender_addr))
    }));

    FramedTunnel::new_tunnel_with_info(
        stream,
        sink,
        // TODO: this remote addr is not a url
        super::TunnelInfo {
            tunnel_type: "udp".to_owned(),
            local_addr: super::build_url_from_socket_addr(
                &socket.local_addr().unwrap().to_string(),
                "udp",
            )
            .into(),
            remote_addr: super::build_url_from_socket_addr(&addr.to_string(), "udp").into(),
        },
    )
}
/// Pinned stream/sink halves of a per-peer ring tunnel plus its connection
/// id; stored per remote address by the UDP listener for packet forwarding.
pub(crate) struct StreamSinkPair(
    pub Pin<Box<dyn DatagramStream>>,
    pub Pin<Box<dyn DatagramSink>>,
    pub u32,
);
// Shared, lock-protected handle to one pair (cloned into forward tasks).
pub(crate) type ArcStreamSinkPair = Arc<Mutex<StreamSinkPair>>;
/// UDP tunnel listener: demultiplexes one bound socket into per-peer
/// tunnels, using a Syn/Sack handshake to accept connections.
pub struct UdpTunnelListener {
    addr: url::Url,
    // Bound socket; populated by `listen()`.
    socket: Option<Arc<UdpSocket>>,
    // Per-remote-address forwarding state for established connections.
    sock_map: Arc<DashMap<SocketAddr, ArcStreamSinkPair>>,
    // Background tasks (socket reader + per-connection forwarders).
    forward_tasks: Arc<std::sync::Mutex<JoinSet<()>>>,
    // Channel delivering newly accepted tunnels to `accept()`.
    conn_recv: tokio::sync::mpsc::Receiver<Box<dyn Tunnel>>,
    // Sender half; taken by `listen()` and moved into the reader task.
    conn_send: Option<tokio::sync::mpsc::Sender<Box<dyn Tunnel>>>,
}
impl UdpTunnelListener {
    /// Create a listener for a `udp://host:port[/dev]` URL; call `listen()`
    /// to bind the socket and start the reader task.
    pub fn new(addr: url::Url) -> Self {
        let (conn_send, conn_recv) = tokio::sync::mpsc::channel(100);
        Self {
            addr,
            socket: None,
            sock_map: Arc::new(DashMap::new()),
            forward_tasks: Arc::new(std::sync::Mutex::new(JoinSet::new())),
            conn_recv,
            conn_send: Some(conn_send),
        }
    }

    /// Forward one raw datagram to the per-peer pair registered for `addr`,
    /// if any. Packets for unknown peers are logged and dropped.
    async fn try_forward_packet(
        sock_map: &DashMap<SocketAddr, ArcStreamSinkPair>,
        buf: BytesMut,
        addr: SocketAddr,
    ) -> Result<(), super::TunnelError> {
        let entry = sock_map.get_mut(&addr);
        if entry.is_none() {
            log::warn!("udp forward packet: {:?}, {:?}, no entry", addr, buf);
            return Ok(());
        }
        log::trace!("udp forward packet: {:?}, {:?}", addr, buf);
        let entry = entry.unwrap();
        let pair = entry.value().clone();
        // Release the DashMap guard before awaiting on the pair's mutex.
        drop(entry);
        let Some(buf) = try_get_data_payload(buf, pair.lock().await.2) else {
            return Ok(());
        };
        pair.lock().await.1.send(buf.freeze()).await?;
        Ok(())
    }

    /// Complete the handshake for a Syn from `addr`: reply with Sack,
    /// register a ring-tunnel pair for the peer, and spawn the task that
    /// forwards ring output back out over the UDP socket.
    async fn handle_connect(
        socket: Arc<UdpSocket>,
        addr: SocketAddr,
        forward_tasks: Arc<std::sync::Mutex<JoinSet<()>>>,
        sock_map: Arc<DashMap<SocketAddr, ArcStreamSinkPair>>,
        local_url: url::Url,
        conn_id: u32,
    ) -> Result<Box<dyn Tunnel>, super::TunnelError> {
        tracing::info!(?conn_id, ?addr, "udp connection accept handling",);
        let udp_packet = UdpPacket::new_sack_packet(conn_id);
        let sack_buf = encode_to_bytes::<_, UDP_DATA_MTU>(&udp_packet);
        socket.send_to(&sack_buf, addr).await?;

        // ctunnel bridges the UDP socket; stunnel is handed to accept().
        let (ctunnel, stunnel) = create_ring_tunnel_pair();
        let udp_tunnel = get_tunnel_from_socket(socket.clone(), addr, conn_id);
        let ss_pair = StreamSinkPair(ctunnel.pin_stream(), ctunnel.pin_sink(), conn_id);
        let addr_copy = addr.clone();
        sock_map.insert(addr, Arc::new(Mutex::new(ss_pair)));

        // Pump everything the local side writes into the UDP tunnel sink;
        // on completion, deregister the peer and close the ring tunnel.
        let ctunnel_stream = ctunnel.pin_stream();
        forward_tasks.lock().unwrap().spawn(async move {
            let ret = ctunnel_stream
                .map(|v| {
                    tracing::trace!(?v, "udp stream recv something in forward task");
                    if v.is_err() {
                        return Err(super::TunnelError::CommonError(
                            "udp stream error".to_owned(),
                        ));
                    }
                    Ok(v.unwrap().freeze())
                })
                .forward(udp_tunnel.pin_sink())
                .await;
            if let None = sock_map.remove(&addr_copy) {
                log::warn!("udp forward packet: {:?}, no entry", addr_copy);
            }
            close_tunnel(&ctunnel).await.unwrap();
            log::warn!("udp connection forward done: {:?}, {:?}", addr_copy, ret);
        });

        // Report the listener's URL as local and the peer's address as remote.
        Ok(Box::new(TunnelWithCustomInfo::new(
            stunnel,
            TunnelInfo {
                tunnel_type: "udp".to_owned(),
                local_addr: local_url.into(),
                remote_addr: build_url_from_socket_addr(&addr.to_string(), "udp").into(),
            },
        )))
    }

    /// The bound socket, if `listen()` has been called successfully.
    pub fn get_socket(&self) -> Option<Arc<UdpSocket>> {
        self.socket.clone()
    }
}
#[async_trait]
impl TunnelListener for UdpTunnelListener {
    /// Binds the UDP socket and spawns the background task that
    /// demultiplexes incoming datagrams into per-peer ring tunnels.
    async fn listen(&mut self) -> Result<(), super::TunnelError> {
        let addr = super::check_scheme_and_get_socket_addr::<SocketAddr>(&self.addr, "udp")?;

        let socket2_socket = socket2::Socket::new(
            socket2::Domain::for_address(addr),
            socket2::Type::DGRAM,
            Some(socket2::Protocol::UDP),
        )?;
        // Bind to a specific device when one is encoded in the tunnel url.
        let tunnel_url: TunnelUrl = self.addr.clone().into();
        if let Some(bind_dev) = tunnel_url.bind_dev() {
            setup_sokcet2_ext(&socket2_socket, &addr, Some(bind_dev))?;
        } else {
            setup_sokcet2(&socket2_socket, &addr)?;
        }

        self.socket = Some(Arc::new(UdpSocket::from_std(socket2_socket.into())?));

        let socket = self.socket.as_ref().unwrap().clone();
        let forward_tasks = self.forward_tasks.clone();
        let sock_map = self.sock_map.clone();
        let conn_send = self.conn_send.take().unwrap();
        let local_url = self.local_url().clone();

        self.forward_tasks.lock().unwrap().spawn(
            async move {
                loop {
                    let mut buf = BytesMut::new();
                    buf.resize(UDP_DATA_MTU, 0);
                    // NOTE(review): unwrap aborts the whole forward task on a
                    // socket error — confirm this is the intended behavior.
                    let (_size, addr) = socket.recv_from(&mut buf).await.unwrap();
                    // Trim the buffer down to the actual datagram length.
                    let _ = buf.split_off(_size);
                    log::trace!(
                        "udp recv packet: {:?}, buf: {:?}, size: {}",
                        addr,
                        buf,
                        _size
                    );

                    // Silently skip anything that does not decode as our
                    // packet type or lacks the magic value.
                    let Ok(udp_packet) = rkyv_util::decode_from_bytes::<UdpPacket>(&buf) else {
                        tracing::warn!(?buf, "udp decode error in forward task");
                        continue;
                    };
                    if udp_packet.magic != UDP_PACKET_MAGIC {
                        tracing::trace!(?udp_packet, "udp magic not match");
                        continue;
                    }

                    if matches!(udp_packet.payload, ArchivedUdpPacketPayload::Syn) {
                        // Handshake start: build a tunnel and push it to accept().
                        let Ok(conn) = Self::handle_connect(
                            socket.clone(),
                            addr,
                            forward_tasks.clone(),
                            sock_map.clone(),
                            local_url.clone(),
                            udp_packet.conn_id.into(),
                        )
                        .await
                        else {
                            tracing::error!(?addr, "udp handle connect error");
                            continue;
                        };
                        if let Err(e) = conn_send.send(conn).await {
                            tracing::warn!(?e, "udp send conn to accept channel error");
                        }
                    } else {
                        // Data packet: route to the existing per-peer tunnel.
                        Self::try_forward_packet(sock_map.as_ref(), buf, addr)
                            .await
                            .unwrap();
                    }
                }
            }
            .instrument(tracing::info_span!("udp forward task", ?self.socket)),
        );

        join_joinset_background(self.forward_tasks.clone(), "UdpTunnelListener".to_owned());

        Ok(())
    }

    /// Waits for the next handshaken connection produced by the forward task.
    /// Errors only after the channel is closed (listener shutting down).
    async fn accept(&mut self) -> Result<Box<dyn super::Tunnel>, super::TunnelError> {
        log::info!("start udp accept: {:?}", self.addr);
        while let Some(conn) = self.conn_recv.recv().await {
            return Ok(conn);
        }
        return Err(super::TunnelError::CommonError(
            "udp accept error".to_owned(),
        ));
    }

    fn local_url(&self) -> url::Url {
        self.addr.clone()
    }

    /// Returns a connection counter backed directly by the live peer map.
    fn get_conn_counter(&self) -> Arc<Box<dyn TunnelConnCounter>> {
        // Ad-hoc counter type; it shares the map, so it always reports the
        // current number of active peers.
        struct UdpTunnelConnCounter {
            sock_map: Arc<DashMap<SocketAddr, ArcStreamSinkPair>>,
        }

        impl Debug for UdpTunnelConnCounter {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                f.debug_struct("UdpTunnelConnCounter")
                    .field("sock_map_len", &self.sock_map.len())
                    .finish()
            }
        }

        impl TunnelConnCounter for UdpTunnelConnCounter {
            fn get(&self) -> u32 {
                self.sock_map.len() as u32
            }
        }

        Arc::new(Box::new(UdpTunnelConnCounter {
            sock_map: self.sock_map.clone(),
        }))
    }
}
/// Connector side of the UDP tunnel: dials `addr` and performs the SYN/SACK
/// handshake implemented by `UdpTunnelListener`.
pub struct UdpTunnelConnector {
    // Remote url, e.g. udp://1.2.3.4:5555.
    addr: url::Url,
    // Optional local bind addresses; when non-empty, one socket per address
    // is created and the first successful connect wins.
    bind_addrs: Vec<SocketAddr>,
}
impl UdpTunnelConnector {
    pub fn new(addr: url::Url) -> Self {
        Self {
            addr,
            bind_addrs: vec![],
        }
    }

    /// Receives one datagram and validates it as the SACK for `conn_id` from
    /// `addr`.
    ///
    /// Errors on timeout (3s), unexpected sender, decode failure, wrong
    /// magic/conn-id, or a payload that is not `Sack`.
    async fn wait_sack(
        socket: &UdpSocket,
        addr: SocketAddr,
        conn_id: u32,
    ) -> Result<(), super::TunnelError> {
        let mut buf = BytesMut::new();
        buf.resize(128, 0);
        // Renamed from `usize`, which shadowed the builtin type name.
        let (len, recv_addr) = tokio::time::timeout(
            tokio::time::Duration::from_secs(3),
            socket.recv_from(&mut buf),
        )
        .await??;

        if recv_addr != addr {
            return Err(super::TunnelError::ConnectError(format!(
                "udp connect error, unexpected sack addr: {:?}, {:?}",
                recv_addr, addr
            )));
        }

        // Trim to the received size before decoding.
        let _ = buf.split_off(len);

        let Ok(udp_packet) = rkyv_util::decode_from_bytes::<UdpPacket>(&buf) else {
            tracing::warn!(?buf, "udp decode error in wait sack");
            return Err(super::TunnelError::ConnectError(format!(
                "udp connect error, decode error. buf: {:?}",
                buf
            )));
        };

        if udp_packet.magic != UDP_PACKET_MAGIC {
            tracing::trace!(?udp_packet, "udp magic not match");
            return Err(super::TunnelError::ConnectError(format!(
                "udp connect error, magic not match. magic: {:?}",
                udp_packet.magic
            )));
        }

        if conn_id != udp_packet.conn_id {
            return Err(super::TunnelError::ConnectError(format!(
                "udp connect error, conn id not match. conn_id: {:?}, {:?}",
                conn_id, udp_packet.conn_id
            )));
        }

        if !matches!(udp_packet.payload, ArchivedUdpPacketPayload::Sack) {
            return Err(super::TunnelError::ConnectError(format!(
                "udp connect error, unexpected payload. payload: {:?}",
                udp_packet.payload
            )));
        }

        Ok(())
    }

    /// Retries `wait_sack` until it succeeds; callers bound the total time
    /// with an outer timeout.
    async fn wait_sack_loop(
        socket: &UdpSocket,
        addr: SocketAddr,
        conn_id: u32,
    ) -> Result<(), super::TunnelError> {
        while let Err(err) = Self::wait_sack(socket, addr, conn_id).await {
            tracing::warn!(?err, "udp wait sack error");
        }
        Ok(())
    }

    /// Performs the SYN/SACK handshake over an already-bound `socket` and
    /// wraps it into a tunnel on success.
    pub async fn try_connect_with_socket(
        &self,
        socket: UdpSocket,
    ) -> Result<Box<dyn super::Tunnel>, super::TunnelError> {
        let addr = super::check_scheme_and_get_socket_addr::<SocketAddr>(&self.addr, "udp")?;
        log::warn!("udp connect: {:?}", self.addr);

        #[cfg(target_os = "windows")]
        crate::arch::windows::disable_connection_reset(&socket)?;

        // send syn
        let conn_id = rand::random();
        let udp_packet = UdpPacket::new_syn_packet(conn_id);
        let b = encode_to_bytes::<_, UDP_DATA_MTU>(&udp_packet);
        let ret = socket.send_to(&b, &addr).await?;
        tracing::warn!(?udp_packet, ?ret, "udp send syn");

        // wait sack
        tokio::time::timeout(
            tokio::time::Duration::from_secs(3),
            Self::wait_sack_loop(&socket, addr, conn_id),
        )
        .await??;

        // sack done
        let local_addr = socket.local_addr().unwrap().to_string();
        Ok(Box::new(TunnelWithCustomInfo::new(
            get_tunnel_from_socket(Arc::new(socket), addr, conn_id),
            TunnelInfo {
                tunnel_type: "udp".to_owned(),
                local_addr: super::build_url_from_socket_addr(&local_addr, "udp").into(),
                remote_addr: self.remote_url().into(),
            },
        )))
    }

    async fn connect_with_default_bind(&mut self) -> Result<Box<dyn Tunnel>, super::TunnelError> {
        let socket = UdpSocket::bind("0.0.0.0:0").await?;
        self.try_connect_with_socket(socket).await
    }

    /// Races one connect attempt per configured bind address; the first to
    /// finish (success or failure) decides the result.
    async fn connect_with_custom_bind(&mut self) -> Result<Box<dyn Tunnel>, super::TunnelError> {
        let mut futures = FuturesUnordered::new();

        for bind_addr in self.bind_addrs.iter() {
            let socket2_socket = socket2::Socket::new(
                socket2::Domain::for_address(*bind_addr),
                socket2::Type::DGRAM,
                Some(socket2::Protocol::UDP),
            )?;
            setup_sokcet2(&socket2_socket, &bind_addr)?;
            let socket = UdpSocket::from_std(socket2_socket.into())?;
            futures.push(self.try_connect_with_socket(socket));
        }

        let Some(ret) = futures.next().await else {
            return Err(super::TunnelError::CommonError(
                "join connect futures failed".to_owned(),
            ));
        };
        ret
    }
}
#[async_trait]
impl super::TunnelConnector for UdpTunnelConnector {
    /// Dials the remote listener; uses the configured bind addresses when
    /// present, otherwise a default wildcard bind.
    async fn connect(&mut self) -> Result<Box<dyn super::Tunnel>, super::TunnelError> {
        if !self.bind_addrs.is_empty() {
            self.connect_with_custom_bind().await
        } else {
            self.connect_with_default_bind().await
        }
    }

    fn remote_url(&self) -> url::Url {
        self.addr.clone()
    }

    fn set_bind_addrs(&mut self, addrs: Vec<SocketAddr>) {
        self.bind_addrs = addrs;
    }
}
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use rand::Rng;
    use tokio::time::timeout;

    use crate::{
        common::global_ctx::tests::get_mock_global_ctx,
        tunnels::{
            check_scheme_and_get_socket_addr,
            common::{
                get_interface_name_by_ip, setup_sokcet2_ext,
                tests::{_tunnel_bench, _tunnel_echo_server, _tunnel_pingpong},
            },
        },
    };

    use super::*;

    // Basic handshake + echo round trip.
    #[tokio::test]
    async fn udp_pingpong() {
        let listener = UdpTunnelListener::new("udp://0.0.0.0:5556".parse().unwrap());
        let connector = UdpTunnelConnector::new("udp://127.0.0.1:5556".parse().unwrap());
        _tunnel_pingpong(listener, connector).await
    }

    // Throughput smoke test over loopback.
    #[tokio::test]
    async fn udp_bench() {
        let listener = UdpTunnelListener::new("udp://0.0.0.0:5555".parse().unwrap());
        let connector = UdpTunnelConnector::new("udp://127.0.0.1:5555".parse().unwrap());
        _tunnel_bench(listener, connector).await
    }

    // Connector bound to an explicit local address.
    #[tokio::test]
    async fn udp_bench_with_bind() {
        let listener = UdpTunnelListener::new("udp://127.0.0.1:5554".parse().unwrap());
        let mut connector = UdpTunnelConnector::new("udp://127.0.0.1:5554".parse().unwrap());
        connector.set_bind_addrs(vec!["127.0.0.1:0".parse().unwrap()]);
        _tunnel_pingpong(listener, connector).await
    }

    // Binding to an unroutable address must make the handshake fail.
    #[tokio::test]
    #[should_panic]
    async fn udp_bench_with_bind_fail() {
        let listener = UdpTunnelListener::new("udp://127.0.0.1:5553".parse().unwrap());
        let mut connector = UdpTunnelConnector::new("udp://127.0.0.1:5553".parse().unwrap());
        connector.set_bind_addrs(vec!["10.0.0.1:0".parse().unwrap()]);
        _tunnel_pingpong(listener, connector).await
    }

    // Continuously sends garbage datagrams at `remote_url` so the listener's
    // decode/magic filtering gets exercised alongside real traffic.
    async fn send_random_data_to_socket(remote_url: url::Url) {
        let socket = UdpSocket::bind("0.0.0.0:0").await.unwrap();
        socket
            .connect(format!(
                "{}:{}",
                remote_url.host().unwrap(),
                remote_url.port().unwrap()
            ))
            .await
            .unwrap();

        // get a random 100-len buf
        loop {
            let mut buf = vec![0u8; 100];
            rand::thread_rng().fill(&mut buf[..]);
            socket.send(&buf).await.unwrap();

            tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;
        }
    }

    // Two concurrent connections through one listener, with random noise
    // injected at both endpoints; each echo must stay on its own tunnel.
    #[tokio::test]
    async fn udp_multiple_conns() {
        let mut listener = UdpTunnelListener::new("udp://0.0.0.0:5557".parse().unwrap());
        listener.listen().await.unwrap();

        let _lis = tokio::spawn(async move {
            loop {
                let ret = listener.accept().await.unwrap();
                assert_eq!(
                    ret.info().unwrap().local_addr,
                    listener.local_url().to_string()
                );
                tokio::spawn(async move { _tunnel_echo_server(ret, false).await });
            }
        });

        let mut connector1 = UdpTunnelConnector::new("udp://127.0.0.1:5557".parse().unwrap());
        let mut connector2 = UdpTunnelConnector::new("udp://127.0.0.1:5557".parse().unwrap());

        let t1 = connector1.connect().await.unwrap();
        let t2 = connector2.connect().await.unwrap();

        // Noise generators run for 2s in the background.
        tokio::spawn(timeout(
            Duration::from_secs(2),
            send_random_data_to_socket(t1.info().unwrap().local_addr.parse().unwrap()),
        ));
        tokio::spawn(timeout(
            Duration::from_secs(2),
            send_random_data_to_socket(t1.info().unwrap().remote_addr.parse().unwrap()),
        ));
        tokio::spawn(timeout(
            Duration::from_secs(2),
            send_random_data_to_socket(t2.info().unwrap().remote_addr.parse().unwrap()),
        ));

        let sender1 = tokio::spawn(async move {
            let mut sink = t1.pin_sink();
            let mut stream = t1.pin_stream();

            for i in 0..10 {
                sink.send(Bytes::from("hello1")).await.unwrap();
                let recv = stream.next().await.unwrap().unwrap();
                println!("t1 recv: {:?}, {:?}", recv, i);
                assert_eq!(recv, Bytes::from("hello1"));
                tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
            }
        });

        let sender2 = tokio::spawn(async move {
            let mut sink = t2.pin_sink();
            let mut stream = t2.pin_stream();

            for i in 0..10 {
                sink.send(Bytes::from("hello2")).await.unwrap();
                let recv = stream.next().await.unwrap().unwrap();
                println!("t2 recv: {:?}, {:?}", recv, i);
                assert_eq!(recv, Bytes::from("hello2"));
                tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
            }
        });

        let _ = tokio::join!(sender1, sender2);
    }

    // Encode/decode round trip of a data packet.
    #[tokio::test]
    async fn udp_packet_print() {
        let udp_packet = UdpPacket::new_data_packet(1, vec![1, 2, 3, 4, 5]);
        let b = encode_to_bytes::<_, UDP_DATA_MTU>(&udp_packet);
        let a_udp_packet = rkyv_util::decode_from_bytes::<UdpPacket>(&b).unwrap();
        println!("{:?}, {:?}", udp_packet, a_udp_packet);
    }

    // Device-bound sockets: every local IPv4 on the first interface should be
    // bindable to the same device. Skips silently when no IPs are available.
    #[tokio::test]
    async fn bind_multi_ip_to_same_dev() {
        let global_ctx = get_mock_global_ctx();
        let ips = global_ctx
            .get_ip_collector()
            .collect_ip_addrs()
            .await
            .interface_ipv4s;
        if ips.is_empty() {
            return;
        }
        let bind_dev = get_interface_name_by_ip(&ips[0].parse().unwrap());

        for ip in ips {
            println!("bind to ip: {:?}, {:?}", ip, bind_dev);
            let addr = check_scheme_and_get_socket_addr::<SocketAddr>(
                &format!("udp://{}:11111", ip).parse().unwrap(),
                "udp",
            )
            .unwrap();
            let socket2_socket = socket2::Socket::new(
                socket2::Domain::for_address(addr),
                socket2::Type::DGRAM,
                Some(socket2::Protocol::UDP),
            )
            .unwrap();
            setup_sokcet2_ext(&socket2_socket, &addr, bind_dev.clone()).unwrap();
        }
    }
}
+797
View File
@@ -0,0 +1,797 @@
use std::{
collections::hash_map::DefaultHasher,
fmt::{Debug, Formatter},
hash::Hasher,
net::SocketAddr,
pin::Pin,
sync::Arc,
time::Duration,
};
use anyhow::Context;
use async_recursion::async_recursion;
use async_trait::async_trait;
use boringtun::{
noise::{errors::WireGuardError, Tunn, TunnResult},
x25519::{PublicKey, StaticSecret},
};
use dashmap::DashMap;
use futures::{stream::FuturesUnordered, SinkExt, StreamExt};
use rand::RngCore;
use tokio::{net::UdpSocket, sync::Mutex, task::JoinSet};
use crate::{
rpc::TunnelInfo,
tunnels::{build_url_from_socket_addr, common::TunnelWithCustomInfo},
};
use super::{
check_scheme_and_get_socket_addr,
common::{setup_sokcet2, setup_sokcet2_ext},
ring_tunnel::create_ring_tunnel_pair,
DatagramSink, DatagramStream, Tunnel, TunnelError, TunnelListener, TunnelUrl,
};
// Maximum size of a single datagram handled by the WireGuard code; just under
// the 65535 IP length limit.
const MAX_PACKET: usize = 65500;
/// Distinguishes the two deployments of the wg tunnel and therefore whether a
/// synthetic IP header must be wrapped around payloads.
#[derive(Debug, Clone)]
enum WgType {
    // used by easytier peer, need remove/add ip header for in/out wg msg
    InternalUse,
    // used by wireguard peer, keep original ip header
    ExternalUse,
}
/// Static key material for one side of a WireGuard tunnel. Holds both our
/// keypair and the peer's, so either endpoint of a pair can be constructed.
#[derive(Clone)]
pub struct WgConfig {
    my_secret_key: StaticSecret,
    my_public_key: PublicKey,
    peer_secret_key: StaticSecret,
    peer_public_key: PublicKey,
    // Controls the IP-header add/strip behavior (see WgType).
    wg_type: WgType,
}
impl WgConfig {
    /// Deterministically derives a symmetric pair of WireGuard keys from the
    /// network name and secret, so every node of the same network computes
    /// identical keys with no exchange.
    ///
    /// NOTE(review): DefaultHasher is not a cryptographic KDF and its output
    /// is not guaranteed stable across Rust releases; both sides must derive
    /// identical bytes, so any change here is a breaking protocol change.
    pub fn new_from_network_identity(network_name: &str, network_secret: &str) -> Self {
        let mut my_sec = [0u8; 32];
        let mut hasher = DefaultHasher::new();
        hasher.write(network_name.as_bytes());
        hasher.write(network_secret.as_bytes());
        // Fill the 32-byte seed 8 bytes at a time, feeding each prefix back
        // into the hasher.
        my_sec[0..8].copy_from_slice(&hasher.finish().to_be_bytes());
        hasher.write(&my_sec[0..8]);
        my_sec[8..16].copy_from_slice(&hasher.finish().to_be_bytes());
        hasher.write(&my_sec[0..16]);
        my_sec[16..24].copy_from_slice(&hasher.finish().to_be_bytes());
        hasher.write(&my_sec[0..24]);
        my_sec[24..32].copy_from_slice(&hasher.finish().to_be_bytes());

        // Both sides of the pair use the same key material here.
        let my_secret_key = StaticSecret::from(my_sec);
        let my_public_key = PublicKey::from(&my_secret_key);
        let peer_secret_key = StaticSecret::from(my_sec);
        let peer_public_key = my_public_key.clone();

        WgConfig {
            my_secret_key,
            my_public_key,
            peer_secret_key,
            peer_public_key,
            wg_type: WgType::InternalUse,
        }
    }

    /// Builds a portal-mode config: server and client keypairs are derived
    /// from independent seeds, and original IP headers are preserved.
    pub fn new_for_portal(server_key_seed: &str, client_key_seed: &str) -> Self {
        let server_cfg = Self::new_from_network_identity("server", server_key_seed);
        let client_cfg = Self::new_from_network_identity("client", client_key_seed);
        Self {
            my_secret_key: server_cfg.my_secret_key,
            my_public_key: server_cfg.my_public_key,
            peer_secret_key: client_cfg.my_secret_key,
            peer_public_key: client_cfg.my_public_key,
            wg_type: WgType::ExternalUse,
        }
    }

    // Raw-byte accessors for dumping configs (e.g. wg client config files).
    pub fn my_secret_key(&self) -> &[u8] {
        self.my_secret_key.as_bytes()
    }

    pub fn peer_secret_key(&self) -> &[u8] {
        self.peer_secret_key.as_bytes()
    }

    pub fn my_public_key(&self) -> &[u8] {
        self.my_public_key.as_bytes()
    }

    pub fn peer_public_key(&self) -> &[u8] {
        self.peer_public_key.as_bytes()
    }
}
/// Per-peer state shared between the forwarding tasks; cheaply clonable
/// (all fields are Arc or Copy).
#[derive(Clone)]
struct WgPeerData {
    udp: Arc<UdpSocket>, // only for send
    endpoint: SocketAddr,
    // boringtun session state; serialized behind a mutex.
    tunn: Arc<Mutex<Tunn>>,
    // Plaintext side (ring tunnel) of the peer.
    sink: Arc<Mutex<Pin<Box<dyn DatagramSink>>>>,
    stream: Arc<Mutex<Pin<Box<dyn DatagramStream>>>>,
    wg_type: WgType,
}
// Manual Debug: only the endpoint and local socket address are informative;
// the tunnel state and trait objects are omitted.
impl Debug for WgPeerData {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let mut dbg = f.debug_struct("WgPeerData");
        dbg.field("endpoint", &self.endpoint);
        dbg.field("local", &self.udp.local_addr());
        dbg.finish()
    }
}
impl WgPeerData {
    /// Encrypts one plaintext packet coming from our side and sends the
    /// resulting WireGuard datagram to the peer endpoint.
    #[tracing::instrument]
    async fn handle_one_packet_from_me(&self, packet: &[u8]) -> Result<(), anyhow::Error> {
        let mut send_buf = vec![0u8; MAX_PACKET];

        let encapsulate_result = {
            let mut peer = self.tunn.lock().await;
            // Internal-use tunnels carry raw payload, so a synthetic IPv4
            // header is prepended before encapsulation.
            if matches!(self.wg_type, WgType::InternalUse) {
                peer.encapsulate(&self.add_ip_header(&packet), &mut send_buf)
            } else {
                peer.encapsulate(&packet, &mut send_buf)
            }
        };

        tracing::trace!(
            ?encapsulate_result,
            "Received {} bytes from me",
            packet.len()
        );

        match encapsulate_result {
            TunnResult::WriteToNetwork(packet) => {
                self.udp
                    .send_to(packet, self.endpoint)
                    .await
                    .context("Failed to send encrypted IP packet to WireGuard endpoint.")?;
                tracing::debug!(
                    "Sent {} bytes to WireGuard endpoint (encrypted IP packet)",
                    packet.len()
                );
            }
            TunnResult::Err(e) => {
                tracing::error!("Failed to encapsulate IP packet: {:?}", e);
            }
            TunnResult::Done => {
                // Ignored
            }
            other => {
                tracing::error!(
                    "Unexpected WireGuard state during encapsulation: {:?}",
                    other
                );
            }
        };

        Ok(())
    }

    /// WireGuard consumption task. Receives encrypted packets from the WireGuard endpoint,
    /// decapsulates them, and dispatches newly received IP packets.
    #[tracing::instrument]
    pub async fn handle_one_packet_from_peer(&self, recv_buf: &[u8]) {
        let mut send_buf = vec![0u8; MAX_PACKET];

        let data = &recv_buf[..];

        let decapsulate_result = {
            let mut peer = self.tunn.lock().await;
            peer.decapsulate(None, data, &mut send_buf)
        };

        tracing::debug!("Decapsulation result: {:?}", decapsulate_result);

        match decapsulate_result {
            // Protocol traffic (e.g. handshake): boringtun asks us to send a
            // reply, then drain any queued packets by decapsulating with an
            // empty input until it stops producing output.
            TunnResult::WriteToNetwork(packet) => {
                match self.udp.send_to(packet, self.endpoint).await {
                    Ok(_) => {}
                    Err(e) => {
                        tracing::error!("Failed to send decapsulation-instructed packet to WireGuard endpoint: {:?}", e);
                        return;
                    }
                };
                let mut peer = self.tunn.lock().await;
                loop {
                    let mut send_buf = vec![0u8; MAX_PACKET];
                    match peer.decapsulate(None, &[], &mut send_buf) {
                        TunnResult::WriteToNetwork(packet) => {
                            match self.udp.send_to(packet, self.endpoint).await {
                                Ok(_) => {}
                                Err(e) => {
                                    tracing::error!("Failed to send decapsulation-instructed packet to WireGuard endpoint: {:?}", e);
                                    break;
                                }
                            };
                        }
                        _ => {
                            break;
                        }
                    }
                }
            }
            // Actual data packet: strip the IP header for internal-use
            // tunnels and forward the payload to the plaintext sink.
            TunnResult::WriteToTunnelV4(packet, _) | TunnResult::WriteToTunnelV6(packet, _) => {
                tracing::debug!(
                    "WireGuard endpoint sent IP packet of {} bytes",
                    packet.len()
                );
                let ret = self
                    .sink
                    .lock()
                    .await
                    .send(
                        if matches!(self.wg_type, WgType::InternalUse) {
                            // Version nibble decides v4 (20B) vs v6 (40B) header.
                            self.remove_ip_header(packet, packet[0] >> 4 == 4)
                        } else {
                            packet
                        }
                        .to_vec()
                        .into(),
                    )
                    .await;
                if ret.is_err() {
                    tracing::error!("Failed to send packet to tunnel: {:?}", ret);
                }
            }
            _ => {
                tracing::warn!(
                    "Unexpected WireGuard state during decapsulation: {:?}",
                    decapsulate_result
                );
            }
        }
    }

    /// Acts on the result of a timer update: sends pending protocol packets,
    /// restarts an expired handshake (recursing on the new result), or sleeps.
    #[tracing::instrument]
    #[async_recursion]
    async fn handle_routine_tun_result<'a: 'async_recursion>(&self, result: TunnResult<'a>) -> () {
        match result {
            TunnResult::WriteToNetwork(packet) => {
                tracing::debug!(
                    "Sending routine packet of {} bytes to WireGuard endpoint",
                    packet.len()
                );
                match self.udp.send_to(packet, self.endpoint).await {
                    Ok(_) => {}
                    Err(e) => {
                        tracing::error!(
                            "Failed to send routine packet to WireGuard endpoint: {:?}",
                            e
                        );
                    }
                };
            }
            TunnResult::Err(WireGuardError::ConnectionExpired) => {
                tracing::warn!("Wireguard handshake has expired!");
                let mut buf = vec![0u8; MAX_PACKET];
                let result = self
                    .tunn
                    .lock()
                    .await
                    .format_handshake_initiation(&mut buf[..], false);
                self.handle_routine_tun_result(result).await
            }
            TunnResult::Err(e) => {
                tracing::error!(
                    "Failed to prepare routine packet for WireGuard endpoint: {:?}",
                    e
                );
            }
            TunnResult::Done => {
                // Sleep for a bit
                tokio::time::sleep(Duration::from_millis(250)).await;
            }
            other => {
                tracing::warn!("Unexpected WireGuard routine task state: {:?}", other);
                tokio::time::sleep(Duration::from_millis(250)).await;
            }
        };
    }

    /// WireGuard Routine task. Handles Handshake, keep-alive, etc.
    pub async fn routine_task(self) {
        loop {
            let mut send_buf = vec![0u8; MAX_PACKET];
            let tun_result = { self.tunn.lock().await.update_timers(&mut send_buf) };
            self.handle_routine_tun_result(tun_result).await;
        }
    }

    /// Prepends a minimal 20-byte IPv4 header (zero addresses/checksum,
    /// TTL 64) so the payload looks like an IP packet to WireGuard.
    fn add_ip_header(&self, packet: &[u8]) -> Vec<u8> {
        let mut ret = vec![0u8; packet.len() + 20];
        let ip_header = ret.as_mut_slice();
        ip_header[0] = 0x45; // version 4, IHL 5
        ip_header[1] = 0;
        ip_header[2..4].copy_from_slice(&((packet.len() + 20) as u16).to_be_bytes());
        ip_header[4..6].copy_from_slice(&0u16.to_be_bytes());
        ip_header[6..8].copy_from_slice(&0u16.to_be_bytes());
        ip_header[8] = 64; // TTL
        ip_header[9] = 0;
        ip_header[10..12].copy_from_slice(&0u16.to_be_bytes());
        ip_header[12..16].copy_from_slice(&0u32.to_be_bytes());
        ip_header[16..20].copy_from_slice(&0u32.to_be_bytes());
        ip_header[20..].copy_from_slice(packet);
        ret
    }

    /// Strips the fixed-size IP header (20 bytes for v4, 40 for v6).
    /// NOTE(review): assumes no IPv4 options (IHL == 5) — confirm.
    fn remove_ip_header<'a>(&self, packet: &'a [u8], is_v4: bool) -> &'a [u8] {
        if is_v4 {
            return &packet[20..];
        } else {
            return &packet[40..];
        }
    }
}
/// One remote WireGuard peer: its session data plus the tasks pumping
/// packets for it.
struct WgPeer {
    udp: Arc<UdpSocket>, // only for send
    config: WgConfig,
    endpoint: SocketAddr,
    // Populated by start_and_get_tunnel().
    data: Option<WgPeerData>,
    tasks: JoinSet<()>,
    // Last time a packet arrived; used by the listener to evict idle peers.
    access_time: std::time::Instant,
}
impl WgPeer {
    fn new(udp: Arc<UdpSocket>, config: WgConfig, endpoint: SocketAddr) -> Self {
        WgPeer {
            udp,
            config,
            endpoint,
            data: None,
            tasks: JoinSet::new(),
            access_time: std::time::Instant::now(),
        }
    }

    /// Pumps plaintext packets from the ring-tunnel stream into the
    /// encrypting path until the stream ends or errors.
    async fn handle_packet_from_me(data: WgPeerData) {
        while let Some(Ok(packet)) = data.stream.lock().await.next().await {
            let ret = data.handle_one_packet_from_me(&packet).await;
            if let Err(e) = ret {
                tracing::error!("Failed to handle packet from me: {}", e);
            }
        }
    }

    /// Feeds one encrypted datagram from the network into the session and
    /// refreshes the idle timer.
    /// NOTE(review): panics if called before start_and_get_tunnel() — the
    /// listener always starts the tunnel first.
    async fn handle_packet_from_peer(&mut self, packet: &[u8]) {
        self.access_time = std::time::Instant::now();
        tracing::trace!("Received {} bytes from peer", packet.len());
        let data = self.data.as_ref().unwrap();
        data.handle_one_packet_from_peer(packet).await;
    }

    /// Creates the boringtun session and the ring-tunnel pair, spawns the
    /// outbound and routine tasks, and returns the caller-facing tunnel end.
    fn start_and_get_tunnel(&mut self) -> Box<dyn Tunnel> {
        let (stunnel, ctunnel) = create_ring_tunnel_pair();
        let data = WgPeerData {
            udp: self.udp.clone(),
            endpoint: self.endpoint,
            tunn: Arc::new(Mutex::new(
                Tunn::new(
                    self.config.my_secret_key.clone(),
                    self.config.peer_public_key.clone(),
                    None,
                    None,
                    rand::thread_rng().next_u32(),
                    None,
                )
                .unwrap(),
            )),
            sink: Arc::new(Mutex::new(stunnel.pin_sink())),
            stream: Arc::new(Mutex::new(stunnel.pin_stream())),
            wg_type: self.config.wg_type.clone(),
        };

        self.data = Some(data.clone());
        self.tasks.spawn(Self::handle_packet_from_me(data.clone()));
        self.tasks.spawn(data.routine_task());

        ctunnel
    }
}
impl Drop for WgPeer {
    /// Aborts all per-peer tasks and closes the plaintext sink in the
    /// background.
    fn drop(&mut self) {
        self.tasks.abort_all();
        // take() moves the data out instead of cloning — self is going away.
        if let Some(data) = self.data.take() {
            tokio::spawn(async move {
                let _ = data.sink.lock().await.close().await;
            });
        }
    }
}
// Channel carrying freshly handshaken tunnels from the background
// incoming-packet task to accept().
type ConnSender = tokio::sync::mpsc::UnboundedSender<Box<dyn Tunnel>>;
type ConnReceiver = tokio::sync::mpsc::UnboundedReceiver<Box<dyn Tunnel>>;
/// Accepts WireGuard tunnels on a single UDP socket; one `WgPeer` is created
/// per remote address.
pub struct WgTunnelListener {
    addr: url::Url,
    config: WgConfig,
    // Populated by listen().
    udp: Option<Arc<UdpSocket>>,
    conn_recv: ConnReceiver,
    // Moved into the background task on listen().
    conn_send: Option<ConnSender>,
    tasks: JoinSet<()>,
}
impl WgTunnelListener {
    pub fn new(addr: url::Url, config: WgConfig) -> Self {
        let (conn_send, conn_recv) = tokio::sync::mpsc::unbounded_channel();
        WgTunnelListener {
            addr,
            config,
            udp: None,
            conn_recv,
            conn_send: Some(conn_send),
            tasks: JoinSet::new(),
        }
    }

    // Panics if listen() has not run yet; only used internally after bind.
    fn get_udp_socket(&self) -> Arc<UdpSocket> {
        self.udp.as_ref().unwrap().clone()
    }

    /// Receive loop: creates a WgPeer (and a tunnel pushed to `conn_sender`)
    /// for each new source address, and feeds every datagram to its peer.
    /// Idle peers (>600s) are evicted by a background task once a minute.
    async fn handle_udp_incoming(
        socket: Arc<UdpSocket>,
        config: WgConfig,
        conn_sender: ConnSender,
    ) {
        let mut tasks = JoinSet::new();

        let peer_map: Arc<DashMap<SocketAddr, WgPeer>> = Arc::new(DashMap::new());
        let peer_map_clone = peer_map.clone();
        tasks.spawn(async move {
            loop {
                // Dropping a WgPeer aborts its tasks and closes its sink.
                peer_map_clone.retain(|_, peer| peer.access_time.elapsed().as_secs() < 600);
                tokio::time::sleep(Duration::from_secs(60)).await;
            }
        });

        let mut buf = vec![0u8; MAX_PACKET];
        loop {
            let Ok((n, addr)) = socket.recv_from(&mut buf).await else {
                tracing::error!("Failed to receive from UDP socket");
                break;
            };
            let data = &buf[..n];
            tracing::trace!("Received {} bytes from {}", n, addr);

            if !peer_map.contains_key(&addr) {
                tracing::info!("New peer: {}", addr);
                let mut wg = WgPeer::new(socket.clone(), config.clone(), addr.clone());
                let tunnel = Box::new(TunnelWithCustomInfo::new(
                    wg.start_and_get_tunnel(),
                    TunnelInfo {
                        tunnel_type: "wg".to_owned(),
                        local_addr: build_url_from_socket_addr(
                            &socket.local_addr().unwrap().to_string(),
                            "wg",
                        )
                        .into(),
                        remote_addr: build_url_from_socket_addr(&addr.to_string(), "wg").into(),
                    },
                ));
                if let Err(e) = conn_sender.send(tunnel) {
                    tracing::error!("Failed to send tunnel to conn_sender: {}", e);
                }
                peer_map.insert(addr, wg);
            }

            // NOTE(review): the DashMap guard is held across this await;
            // if the eviction task's retain() runs concurrently on the same
            // shard this can deadlock — confirm.
            let mut peer = peer_map.get_mut(&addr).unwrap();
            peer.handle_packet_from_peer(data).await;
        }
    }
}
#[async_trait]
impl TunnelListener for WgTunnelListener {
    /// Binds the UDP socket (optionally to the device encoded in the url)
    /// and spawns the incoming-packet task.
    async fn listen(&mut self) -> Result<(), super::TunnelError> {
        let addr = check_scheme_and_get_socket_addr::<SocketAddr>(&self.addr, "wg")?;

        let socket2_socket = socket2::Socket::new(
            socket2::Domain::for_address(addr),
            socket2::Type::DGRAM,
            Some(socket2::Protocol::UDP),
        )?;
        let tunnel_url: TunnelUrl = self.addr.clone().into();
        if let Some(bind_dev) = tunnel_url.bind_dev() {
            setup_sokcet2_ext(&socket2_socket, &addr, Some(bind_dev))?;
        } else {
            setup_sokcet2(&socket2_socket, &addr)?;
        }

        self.udp = Some(Arc::new(UdpSocket::from_std(socket2_socket.into())?));
        self.tasks.spawn(Self::handle_udp_incoming(
            self.get_udp_socket(),
            self.config.clone(),
            self.conn_send.take().unwrap(),
        ));
        Ok(())
    }

    /// Returns the next tunnel created by the incoming-packet task, or an
    /// error once the channel is closed (listener shut down).
    async fn accept(&mut self) -> Result<Box<dyn Tunnel>, super::TunnelError> {
        // Was a `while let { return }` that only ever ran one iteration;
        // a match over the single recv makes that explicit.
        match self.conn_recv.recv().await {
            Some(tunnel) => {
                tracing::info!(?tunnel, "Accepted tunnel");
                Ok(tunnel)
            }
            None => Err(TunnelError::CommonError(
                "Failed to accept tunnel".to_string(),
            )),
        }
    }

    fn local_url(&self) -> url::Url {
        self.addr.clone()
    }
}
/// Client-side wg tunnel: keeps the WgPeer alive (its tasks stop on drop)
/// while delegating the Tunnel trait to the inner ring tunnel.
pub struct WgClientTunnel {
    wg_peer: WgPeer,
    tunnel: Box<dyn Tunnel>,
    info: TunnelInfo,
}
// Pure delegation to the inner tunnel, except info() which reports the
// wg-specific endpoint urls captured at connect time.
impl Tunnel for WgClientTunnel {
    fn stream(&self) -> Box<dyn DatagramStream> {
        self.tunnel.stream()
    }

    fn sink(&self) -> Box<dyn DatagramSink> {
        self.tunnel.sink()
    }

    fn info(&self) -> Option<TunnelInfo> {
        Some(self.info.clone())
    }
}
/// Connector side of the wg tunnel: dials `addr` and performs the WireGuard
/// handshake using `config`.
#[derive(Clone)]
pub struct WgTunnelConnector {
    addr: url::Url,
    config: WgConfig,
    udp: Option<Arc<UdpSocket>>,
    // Optional local bind addresses; first successful connect wins.
    bind_addrs: Vec<SocketAddr>,
}
// Manual Debug: the key material in `config` must not be printed.
impl Debug for WgTunnelConnector {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let mut dbg = f.debug_struct("WgTunnelConnector");
        dbg.field("addr", &self.addr);
        dbg.field("udp", &self.udp);
        dbg.finish()
    }
}
impl WgTunnelConnector {
    pub fn new(addr: url::Url, config: WgConfig) -> Self {
        WgTunnelConnector {
            addr,
            config,
            udp: None,
            bind_addrs: vec![],
        }
    }

    /// Produces the first handshake-initiation message for a fresh `tun`.
    fn create_handshake_init(tun: &mut Tunn) -> Vec<u8> {
        let mut dst = vec![0u8; 2048];
        let handshake_init = tun.format_handshake_initiation(&mut dst, false);
        assert!(matches!(handshake_init, TunnResult::WriteToNetwork(_)));
        // The assert above guarantees this pattern matches.
        let TunnResult::WriteToNetwork(sent) = handshake_init else {
            unreachable!();
        };
        sent.into()
    }

    /// Decapsulates the handshake response and returns the keepalive packet
    /// that must be sent back to finish the handshake.
    fn parse_handshake_resp(tun: &mut Tunn, handshake_resp: &[u8]) -> Vec<u8> {
        let mut dst = vec![0u8; 2048];
        let keepalive = tun.decapsulate(None, handshake_resp, &mut dst);
        assert!(
            matches!(keepalive, TunnResult::WriteToNetwork(_)),
            "Failed to parse handshake response, {:?}",
            keepalive
        );
        // The assert above guarantees this pattern matches.
        let TunnResult::WriteToNetwork(sent) = keepalive else {
            unreachable!();
        };
        sent.into()
    }

    /// Runs the WireGuard handshake over a pre-bound `udp` socket and, on
    /// success, wraps the session into a tunnel. A background task keeps
    /// feeding datagrams from `addr` into the session.
    #[tracing::instrument(skip(config))]
    async fn connect_with_socket(
        addr_url: url::Url,
        config: WgConfig,
        udp: UdpSocket,
    ) -> Result<Box<dyn super::Tunnel>, super::TunnelError> {
        let addr = super::check_scheme_and_get_socket_addr::<SocketAddr>(&addr_url, "wg")?;
        tracing::warn!("wg connect: {:?}", addr);

        let local_addr = udp.local_addr()?.to_string();

        let mut my_tun = Tunn::new(
            config.my_secret_key.clone(),
            config.peer_public_key.clone(),
            None,
            None,
            rand::thread_rng().next_u32(),
            None,
        )
        .unwrap();

        let init = Self::create_handshake_init(&mut my_tun);
        udp.send_to(&init, addr).await?;
        let mut buf = vec![0u8; MAX_PACKET];
        // Propagate recv failures instead of panicking in the connect path.
        // NOTE(review): no timeout here — a silent server hangs the connect;
        // confirm callers wrap this in a timeout.
        let (n, _) = udp.recv_from(&mut buf).await?;
        let keepalive = Self::parse_handshake_resp(&mut my_tun, &buf[..n]);
        udp.send_to(&keepalive, addr).await?;

        let mut wg_peer = WgPeer::new(Arc::new(udp), config.clone(), addr);

        let tunnel = wg_peer.start_and_get_tunnel();
        let data = wg_peer.data.as_ref().unwrap().clone();
        wg_peer.tasks.spawn(async move {
            loop {
                let mut buf = vec![0u8; MAX_PACKET];
                let (n, recv_addr) = data.udp.recv_from(&mut buf).await.unwrap();
                // Drop datagrams from anyone but our peer.
                if recv_addr != addr {
                    continue;
                }
                data.handle_one_packet_from_peer(&buf[..n]).await;
            }
        });

        let ret = Box::new(WgClientTunnel {
            wg_peer,
            tunnel,
            info: TunnelInfo {
                tunnel_type: "wg".to_owned(),
                local_addr: super::build_url_from_socket_addr(&local_addr, "wg").into(),
                remote_addr: addr_url.to_string(),
            },
        });

        Ok(ret)
    }
}
#[async_trait]
impl super::TunnelConnector for WgTunnelConnector {
    /// Races one connect attempt per bind address (wildcard bind when none
    /// are configured); the first attempt to finish decides the result.
    #[tracing::instrument]
    async fn connect(&mut self) -> Result<Box<dyn super::Tunnel>, super::TunnelError> {
        let bind_addrs = if self.bind_addrs.is_empty() {
            vec!["0.0.0.0:0".parse().unwrap()]
        } else {
            self.bind_addrs.clone()
        };

        let mut futures = FuturesUnordered::new();

        for bind_addr in bind_addrs.into_iter() {
            let socket2_socket = socket2::Socket::new(
                socket2::Domain::for_address(bind_addr),
                socket2::Type::DGRAM,
                Some(socket2::Protocol::UDP),
            )?;
            setup_sokcet2(&socket2_socket, &bind_addr)?;
            let socket = UdpSocket::from_std(socket2_socket.into())?;
            tracing::info!(?bind_addr, ?self.addr, "prepare wg connect task");
            futures.push(Self::connect_with_socket(
                self.addr.clone(),
                self.config.clone(),
                socket,
            ));
        }

        // next() yields None only when no attempt was spawned at all.
        let Some(ret) = futures.next().await else {
            return Err(super::TunnelError::CommonError(
                "join connect futures failed".to_owned(),
            ));
        };

        return ret;
    }

    fn remote_url(&self) -> url::Url {
        self.addr.clone()
    }

    fn set_bind_addrs(&mut self, addrs: Vec<SocketAddr>) {
        self.bind_addrs = addrs;
    }
}
#[cfg(test)]
pub mod tests {
    use boringtun::*;

    use crate::tunnels::common::tests::{_tunnel_bench, _tunnel_pingpong};
    use crate::tunnels::{wireguard::*, TunnelConnector};

    /// Builds a matched (server, client) config pair with freshly generated
    /// random keypairs; each side holds the other's public key.
    pub fn create_wg_config() -> (WgConfig, WgConfig) {
        let my_secret_key = x25519::StaticSecret::random_from_rng(rand::thread_rng());
        let my_public_key = x25519::PublicKey::from(&my_secret_key);
        let their_secret_key = x25519::StaticSecret::random_from_rng(rand::thread_rng());
        let their_public_key = x25519::PublicKey::from(&their_secret_key);

        let server_cfg = WgConfig {
            my_secret_key: my_secret_key.clone(),
            my_public_key,
            peer_secret_key: their_secret_key.clone(),
            peer_public_key: their_public_key.clone(),
            wg_type: WgType::InternalUse,
        };

        let client_cfg = WgConfig {
            my_secret_key: their_secret_key,
            my_public_key: their_public_key,
            peer_secret_key: my_secret_key,
            peer_public_key: my_public_key,
            wg_type: WgType::InternalUse,
        };

        (server_cfg, client_cfg)
    }

    // Handshake + echo round trip.
    #[tokio::test]
    async fn wg_pingpong() {
        let (server_cfg, client_cfg) = create_wg_config();
        let listener = WgTunnelListener::new("wg://0.0.0.0:5599".parse().unwrap(), server_cfg);
        let connector = WgTunnelConnector::new("wg://127.0.0.1:5599".parse().unwrap(), client_cfg);
        _tunnel_pingpong(listener, connector).await
    }

    // Throughput smoke test over loopback.
    #[tokio::test]
    async fn wg_bench() {
        let (server_cfg, client_cfg) = create_wg_config();
        let listener = WgTunnelListener::new("wg://0.0.0.0:5598".parse().unwrap(), server_cfg);
        let connector = WgTunnelConnector::new("wg://127.0.0.1:5598".parse().unwrap(), client_cfg);
        _tunnel_bench(listener, connector).await
    }

    // Connector bound to an explicit local address.
    #[tokio::test]
    async fn wg_bench_with_bind() {
        let (server_cfg, client_cfg) = create_wg_config();
        let listener = WgTunnelListener::new("wg://127.0.0.1:5597".parse().unwrap(), server_cfg);
        let mut connector =
            WgTunnelConnector::new("wg://127.0.0.1:5597".parse().unwrap(), client_cfg);
        connector.set_bind_addrs(vec!["127.0.0.1:0".parse().unwrap()]);
        _tunnel_pingpong(listener, connector).await
    }

    // Binding to an unroutable address must make the handshake fail.
    #[tokio::test]
    #[should_panic]
    async fn wg_bench_with_bind_fail() {
        let (server_cfg, client_cfg) = create_wg_config();
        let listener = WgTunnelListener::new("wg://127.0.0.1:5596".parse().unwrap(), server_cfg);
        let mut connector =
            WgTunnelConnector::new("wg://127.0.0.1:5596".parse().unwrap(), client_cfg);
        connector.set_bind_addrs(vec!["10.0.0.1:0".parse().unwrap()]);
        _tunnel_pingpong(listener, connector).await
    }
}
+24
View File
@@ -0,0 +1,24 @@
// with vpn portal, user can use other vpn client to connect to easytier servers
// without installing easytier.
// these vpn client include:
// 1. wireguard
// 2. openvpn (TODO)
// 3. shadowsocks (TODO)
use std::sync::Arc;
use crate::{common::global_ctx::ArcGlobalCtx, peers::peer_manager::PeerManager};
pub mod wireguard;
/// A VPN portal exposes the easytier network to third-party VPN clients
/// (e.g. stock WireGuard) without installing easytier.
#[async_trait::async_trait]
pub trait VpnPortal: Send + Sync {
    /// Starts the portal service and hooks it into the peer manager.
    async fn start(
        &mut self,
        global_ctx: ArcGlobalCtx,
        peer_mgr: Arc<PeerManager>,
    ) -> anyhow::Result<()>;

    /// Renders a ready-to-use client configuration for this portal.
    async fn dump_client_config(&self, peer_mgr: Arc<PeerManager>) -> String;

    /// Short identifier of the portal implementation (e.g. "wireguard").
    fn name(&self) -> String;

    /// Lists the currently connected portal clients.
    async fn list_clients(&self) -> Vec<String>;
}
+357
View File
@@ -0,0 +1,357 @@
use std::{
net::{Ipv4Addr, SocketAddr},
pin::Pin,
sync::Arc,
};
use anyhow::Context;
use base64::{prelude::BASE64_STANDARD, Engine};
use cidr::Ipv4Inet;
use dashmap::DashMap;
use futures::{SinkExt, StreamExt};
use pnet::packet::ipv4::Ipv4Packet;
use tokio::{sync::Mutex, task::JoinSet};
use tokio_util::bytes::Bytes;
use crate::{
common::{
global_ctx::{ArcGlobalCtx, GlobalCtxEvent},
join_joinset_background,
},
peers::{
packet::{self, ArchivedPacket},
peer_manager::PeerManager,
PeerPacketFilter,
},
tunnels::{
wireguard::{WgConfig, WgTunnelListener},
DatagramSink, Tunnel, TunnelListener,
},
};
use super::VpnPortal;
/// Maps a wg client's inner ipv4 address to its entry; shared between the
/// per-connection handlers and the peer packet filter.
type WgPeerIpTable = Arc<DashMap<Ipv4Addr, Arc<ClientEntry>>>;

/// Per-client state registered when the first packet from a wg client arrives.
struct ClientEntry {
    // Remote tunnel endpoint, kept only if it parses as a url.
    endpoint_addr: Option<url::Url>,
    // Sink used to push packets from the peer network back to this client.
    sink: Mutex<Pin<Box<dyn DatagramSink + 'static>>>,
}
/// Internal state of the wireguard vpn portal.
struct WireGuardImpl {
    global_ctx: ArcGlobalCtx,
    peer_mgr: Arc<PeerManager>,
    // Keys derived from the network name + secret (see `new`).
    wg_config: WgConfig,
    // Bind address for the wg listener, from the vpn portal config.
    // NOTE(review): keeps the original "listenr" spelling; renaming would
    // touch other blocks in this file.
    listenr_addr: SocketAddr,
    // ipv4 -> client entry, filled as clients send their first packet.
    wg_peer_ip_table: WgPeerIpTable,
    // Background tasks: the accept loop and one handler per connection.
    tasks: Arc<std::sync::Mutex<JoinSet<()>>>,
}
impl WireGuardImpl {
    /// Build the portal state from the global config.
    ///
    /// The wireguard keys are derived deterministically from the network
    /// name + secret, so every node of the same network produces the same
    /// portal keys.
    fn new(global_ctx: ArcGlobalCtx, peer_mgr: Arc<PeerManager>) -> Self {
        let nid = global_ctx.get_network_identity();
        let key_seed = format!("{}{}", nid.network_name, nid.network_secret);
        let wg_config = WgConfig::new_for_portal(&key_seed, &key_seed);

        // Caller must ensure the vpn portal config is present
        // (`WireGuard::start` checks before constructing this type).
        let vpn_cfg = global_ctx.config.get_vpn_portal_config().unwrap();
        let listenr_addr = vpn_cfg.wireguard_listen;

        Self {
            global_ctx,
            peer_mgr,
            wg_config,
            listenr_addr,
            wg_peer_ip_table: Arc::new(DashMap::new()),
            tasks: Arc::new(std::sync::Mutex::new(JoinSet::new())),
        }
    }

    /// Per-connection loop: forwards every ipv4 packet received from the
    /// wg client into the peer network, and emits connect/disconnect events.
    async fn handle_incoming_conn(
        t: Box<dyn Tunnel>,
        peer_mgr: Arc<PeerManager>,
        wg_peer_ip_table: WgPeerIpTable,
    ) {
        let mut s = t.pin_stream();
        let mut ip_registered = false;
        let info = t.info().unwrap_or_default();
        let remote_addr = info.remote_addr.clone();

        peer_mgr
            .get_global_ctx()
            .issue_event(GlobalCtxEvent::VpnPortalClientConnected(
                info.local_addr,
                info.remote_addr,
            ));

        while let Some(Ok(msg)) = s.next().await {
            let Some(i) = Ipv4Packet::new(&msg) else {
                tracing::error!(?msg, "Failed to parse ipv4 packet");
                continue;
            };

            // Register the client's inner (source) ip on the first packet so
            // return traffic can be routed back to its sink (see
            // `start_pipeline_processor`).
            if !ip_registered {
                let client_entry = Arc::new(ClientEntry {
                    endpoint_addr: remote_addr.parse().ok(),
                    sink: Mutex::new(t.pin_sink()),
                });
                wg_peer_ip_table.insert(i.get_source(), client_entry.clone());
                ip_registered = true;
            }

            tracing::trace!(?i, "Received from wg client");

            // Best effort: the result is deliberately ignored.
            let _ = peer_mgr
                .send_msg_ipv4(msg.clone(), i.get_destination())
                .await;
        }

        // Stream ended: the client disconnected.
        let info = t.info().unwrap_or_default();
        peer_mgr
            .get_global_ctx()
            .issue_event(GlobalCtxEvent::VpnPortalClientDisconnected(
                info.local_addr,
                info.remote_addr,
            ));
    }

    /// Install a packet filter on the peer manager that diverts data packets
    /// addressed to a registered wg client into that client's sink.
    async fn start_pipeline_processor(&self) {
        struct PeerPacketFilterForVpnPortal {
            wg_peer_ip_table: WgPeerIpTable,
        }

        #[async_trait::async_trait]
        impl PeerPacketFilter for PeerPacketFilterForVpnPortal {
            async fn try_process_packet_from_peer(
                &self,
                packet: &ArchivedPacket,
                _: &Bytes,
            ) -> Option<()> {
                // Only data packets carry an ipv4 payload for clients.
                if packet.packet_type != packet::PacketType::Data {
                    return None;
                };

                let payload_bytes = packet.payload.as_bytes();
                let ipv4 = Ipv4Packet::new(payload_bytes)?;
                if ipv4.get_version() != 4 {
                    return None;
                }

                // Not destined to one of our wg clients — returning None
                // lets other filters in the pipeline handle it.
                let entry = self.wg_peer_ip_table.get(&ipv4.get_destination())?.clone();

                tracing::trace!(?ipv4, "Packet filter for vpn portal");

                let ret = entry
                    .sink
                    .lock()
                    .await
                    .send(Bytes::copy_from_slice(payload_bytes))
                    .await;
                ret.ok()
            }
        }

        self.peer_mgr
            .add_packet_process_pipeline(Box::new(PeerPacketFilterForVpnPortal {
                wg_peer_ip_table: self.wg_peer_ip_table.clone(),
            }))
            .await;
    }

    /// Start the wg listener, spawn the accept loop and install the
    /// packet-filter pipeline.
    async fn start(&self) -> anyhow::Result<()> {
        let mut l = WgTunnelListener::new(
            format!("wg://{}", self.listenr_addr).parse().unwrap(),
            self.wg_config.clone(),
        );
        l.listen()
            .await
            .with_context(|| "Failed to start wireguard listener for vpn portal")?;

        join_joinset_background(self.tasks.clone(), "wireguard".to_string());

        // The accept loop holds only a weak ref to the task set, so dropping
        // `tasks` (portal teardown) terminates the loop.
        let tasks = Arc::downgrade(&self.tasks.clone());
        let peer_mgr = self.peer_mgr.clone();
        let wg_peer_ip_table = self.wg_peer_ip_table.clone();
        self.tasks.lock().unwrap().spawn(async move {
            while let Ok(t) = l.accept().await {
                let Some(tasks) = tasks.upgrade() else {
                    break;
                };
                tasks.lock().unwrap().spawn(Self::handle_incoming_conn(
                    t,
                    peer_mgr.clone(),
                    wg_peer_ip_table.clone(),
                ));
            }
        });

        self.start_pipeline_processor().await;
        Ok(())
    }
}
/// Public wireguard vpn portal handle implementing [`VpnPortal`].
#[derive(Default)]
pub struct WireGuard {
    // None until `start` succeeds.
    inner: Option<WireGuardImpl>,
}
#[async_trait::async_trait]
impl VpnPortal for WireGuard {
    /// Start the wireguard portal on top of `peer_mgr`.
    ///
    /// # Errors
    /// Fails when the vpn portal config is not set or the wg listener
    /// cannot be started.
    ///
    /// # Panics
    /// Panics if called twice on the same instance.
    async fn start(
        &mut self,
        global_ctx: ArcGlobalCtx,
        peer_mgr: Arc<PeerManager>,
    ) -> anyhow::Result<()> {
        assert!(self.inner.is_none());

        if global_ctx.config.get_vpn_portal_config().is_none() {
            anyhow::bail!("vpn cfg is not set for wireguard vpn portal");
        }

        let inner = WireGuardImpl::new(global_ctx, peer_mgr);
        inner.start().await?;
        self.inner = Some(inner);
        Ok(())
    }

    /// Render a wireguard client config for this portal, with AllowedIPs
    /// covering all proxied cidrs plus the /24 of the first usable node ip.
    /// Returns an "ERROR: ..." string when the portal/config is missing.
    async fn dump_client_config(&self, peer_mgr: Arc<PeerManager>) -> String {
        let Some(inner) = self.inner.as_ref() else {
            return "ERROR: Wireguard VPN Portal Not Started".to_string();
        };
        let global_ctx = inner.global_ctx.clone();
        let Some(vpn_cfg) = global_ctx.config.get_vpn_portal_config() else {
            return "ERROR: VPN Portal Config Not Set".to_string();
        };

        let routes = peer_mgr.list_routes().await;

        // Every proxied cidr in the network is reachable through the tunnel.
        // (flat_map replaces the map().flatten() chain; the elements are
        // already Strings, so the former extra to_string() pass is dropped.)
        let mut allow_ips = routes
            .iter()
            .flat_map(|x| x.proxy_cidrs.iter().map(String::to_string))
            .collect::<Vec<_>>();

        // Also allow the virtual network itself: take the first peer or
        // local ipv4 that parses and widen it to its /24 network.
        for ipv4 in routes
            .iter()
            .map(|x| x.ipv4_addr.clone())
            .chain(global_ctx.get_ipv4().iter().map(|x| x.to_string()))
        {
            let Ok(ipv4) = ipv4.parse() else {
                continue;
            };
            let inet = Ipv4Inet::new(ipv4, 24).unwrap();
            allow_ips.push(inet.network().to_string());
            break;
        }
        let allow_ips = allow_ips.join(",");

        let client_cidr = vpn_cfg.client_cidr;
        let cfg = inner.wg_config.clone();
        format!(
            r#"
[Interface]
PrivateKey = {peer_secret_key}
Address = {client_cidr} # should assign an ip from this cidr manually
[Peer]
PublicKey = {my_public_key}
AllowedIPs = {allow_ips}
Endpoint = {listenr_addr} # should be the public ip of the vpn server
"#,
            peer_secret_key = BASE64_STANDARD.encode(cfg.peer_secret_key()),
            my_public_key = BASE64_STANDARD.encode(cfg.my_public_key()),
            listenr_addr = inner.listenr_addr,
            allow_ips = allow_ips,
            client_cidr = client_cidr,
        )
    }

    fn name(&self) -> String {
        "wireguard".to_string()
    }

    /// Endpoint addresses of the wg clients seen so far. Returns an empty
    /// list when the portal has not been started (the previous version
    /// panicked on `unwrap()` in that case).
    async fn list_clients(&self) -> Vec<String> {
        let Some(inner) = self.inner.as_ref() else {
            return Vec::new();
        };
        inner
            .wg_peer_ip_table
            .iter()
            .map(|x| {
                x.value()
                    .endpoint_addr
                    .as_ref()
                    .map(|x| x.to_string())
                    .unwrap_or_default()
            })
            .collect()
    }
}
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::*;
    use crate::{
        common::{
            config::{NetworkIdentity, VpnPortalConfig},
            global_ctx::tests::get_mock_global_ctx_with_network,
        },
        connector::udp_hole_punch::tests::replace_stun_info_collector,
        peers::{
            peer_manager::{PeerManager, RouteAlgoType},
            tests::wait_for_condition,
        },
        rpc::NatType,
        tunnels::{tcp_tunnel::TcpTunnelConnector, TunnelConnector},
    };

    // End-to-end smoke test: build a peer manager with a mock global ctx,
    // connect it over tcp to 127.0.0.1:11010 (presumably another easytier
    // instance must be listening there — verify test setup), wait until a
    // route appears, then start the wireguard portal on top of it.
    // NOTE(review): no #[tokio::test] attribute is visible on this async fn —
    // confirm it wasn't lost, otherwise this never runs as a test.
    async fn portal_test() {
        let (s, _r) = tokio::sync::mpsc::channel(1000);
        let peer_mgr = Arc::new(PeerManager::new(
            RouteAlgoType::Ospf,
            get_mock_global_ctx_with_network(Some(NetworkIdentity {
                network_name: "sijie".to_string(),
                network_secret: "1919119".to_string(),
            })),
            s,
        ));
        // Stub the stun collector so the test stays offline.
        replace_stun_info_collector(peer_mgr.clone(), NatType::Unknown);

        peer_mgr
            .get_global_ctx()
            .config
            .set_vpn_portal_config(VpnPortalConfig {
                wireguard_listen: "0.0.0.0:11021".parse().unwrap(),
                client_cidr: "10.14.14.0/24".parse().unwrap(),
            });
        peer_mgr.run().await.unwrap();

        let mut pmgr_conn = TcpTunnelConnector::new("tcp://127.0.0.1:11010".parse().unwrap());
        let tunnel = pmgr_conn.connect().await;
        peer_mgr.add_client_tunnel(tunnel.unwrap()).await.unwrap();

        // Wait (up to 10s) for at least one route before starting the portal.
        wait_for_condition(
            || async {
                let routes = peer_mgr.list_routes().await;
                println!("Routes: {:?}", routes);
                routes.len() != 0
            },
            std::time::Duration::from_secs(10),
        )
        .await;

        let mut wg = WireGuard::default();
        wg.start(peer_mgr.get_global_ctx(), peer_mgr.clone())
            .await
            .unwrap();
    }
}