use workspace, prepare for config server and gui (#48)
This commit is contained in:
@@ -0,0 +1,2 @@
|
||||
#[cfg(target_os = "windows")]
|
||||
pub mod windows;
|
||||
@@ -0,0 +1,145 @@
|
||||
use std::{
|
||||
ffi::c_void,
|
||||
io::{self, ErrorKind},
|
||||
mem,
|
||||
net::SocketAddr,
|
||||
os::windows::io::AsRawSocket,
|
||||
ptr,
|
||||
};
|
||||
|
||||
use network_interface::NetworkInterfaceConfig;
|
||||
use windows_sys::{
|
||||
core::PCSTR,
|
||||
Win32::{
|
||||
Foundation::{BOOL, FALSE},
|
||||
Networking::WinSock::{
|
||||
htonl, setsockopt, WSAGetLastError, WSAIoctl, IPPROTO_IP, IPPROTO_IPV6,
|
||||
IPV6_UNICAST_IF, IP_UNICAST_IF, SIO_UDP_CONNRESET, SOCKET, SOCKET_ERROR,
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
pub fn disable_connection_reset<S: AsRawSocket>(socket: &S) -> io::Result<()> {
|
||||
let handle = socket.as_raw_socket() as SOCKET;
|
||||
|
||||
unsafe {
|
||||
// Ignoring UdpSocket's WSAECONNRESET error
|
||||
// https://github.com/shadowsocks/shadowsocks-rust/issues/179
|
||||
// https://stackoverflow.com/questions/30749423/is-winsock-error-10054-wsaeconnreset-normal-with-udp-to-from-localhost
|
||||
//
|
||||
// This is because `UdpSocket::recv_from` may return WSAECONNRESET
|
||||
// if you called `UdpSocket::send_to` a destination that is not existed (may be closed).
|
||||
//
|
||||
// It is not an error. Could be ignored completely.
|
||||
// We have to ignore it here because it will crash the server.
|
||||
|
||||
let mut bytes_returned: u32 = 0;
|
||||
let enable: BOOL = FALSE;
|
||||
|
||||
let ret = WSAIoctl(
|
||||
handle,
|
||||
SIO_UDP_CONNRESET,
|
||||
&enable as *const _ as *const c_void,
|
||||
mem::size_of_val(&enable) as u32,
|
||||
ptr::null_mut(),
|
||||
0,
|
||||
&mut bytes_returned as *mut _,
|
||||
ptr::null_mut(),
|
||||
None,
|
||||
);
|
||||
|
||||
if ret == SOCKET_ERROR {
|
||||
use std::io::Error;
|
||||
|
||||
// Error occurs
|
||||
let err_code = WSAGetLastError();
|
||||
return Err(Error::from_raw_os_error(err_code));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn find_interface_index(iface_name: &str) -> io::Result<u32> {
|
||||
let ifaces = network_interface::NetworkInterface::show().map_err(|e| {
|
||||
io::Error::new(
|
||||
ErrorKind::NotFound,
|
||||
format!("Failed to get interfaces. {}, error: {}", iface_name, e),
|
||||
)
|
||||
})?;
|
||||
if let Some(iface) = ifaces.iter().find(|iface| iface.name == iface_name) {
|
||||
return Ok(iface.index);
|
||||
}
|
||||
tracing::error!("Failed to find interface index for {}", iface_name);
|
||||
Err(io::Error::new(
|
||||
ErrorKind::NotFound,
|
||||
format!("{}", iface_name),
|
||||
))
|
||||
}
|
||||
|
||||
pub fn set_ip_unicast_if<S: AsRawSocket>(
|
||||
socket: &S,
|
||||
addr: &SocketAddr,
|
||||
iface: &str,
|
||||
) -> io::Result<()> {
|
||||
let handle = socket.as_raw_socket() as SOCKET;
|
||||
|
||||
let if_index = find_interface_index(iface)?;
|
||||
|
||||
unsafe {
|
||||
// https://docs.microsoft.com/en-us/windows/win32/winsock/ipproto-ip-socket-options
|
||||
let ret = match addr {
|
||||
SocketAddr::V4(..) => {
|
||||
// Interface index is in network byte order for IPPROTO_IP.
|
||||
let if_index = htonl(if_index);
|
||||
setsockopt(
|
||||
handle,
|
||||
IPPROTO_IP as i32,
|
||||
IP_UNICAST_IF as i32,
|
||||
&if_index as *const _ as PCSTR,
|
||||
mem::size_of_val(&if_index) as i32,
|
||||
)
|
||||
}
|
||||
SocketAddr::V6(..) => {
|
||||
// Interface index is in host byte order for IPPROTO_IPV6.
|
||||
setsockopt(
|
||||
handle,
|
||||
IPPROTO_IPV6 as i32,
|
||||
IPV6_UNICAST_IF as i32,
|
||||
&if_index as *const _ as PCSTR,
|
||||
mem::size_of_val(&if_index) as i32,
|
||||
)
|
||||
}
|
||||
};
|
||||
|
||||
if ret == SOCKET_ERROR {
|
||||
let err = io::Error::from_raw_os_error(WSAGetLastError());
|
||||
tracing::error!(
|
||||
"set IP_UNICAST_IF / IPV6_UNICAST_IF interface: {}, index: {}, error: {}",
|
||||
iface,
|
||||
if_index,
|
||||
err
|
||||
);
|
||||
return Err(err);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn setup_socket_for_win<S: AsRawSocket>(
|
||||
socket: &S,
|
||||
bind_addr: &SocketAddr,
|
||||
bind_dev: Option<String>,
|
||||
is_udp: bool,
|
||||
) -> io::Result<()> {
|
||||
if is_udp {
|
||||
disable_connection_reset(socket)?;
|
||||
}
|
||||
|
||||
if let Some(iface) = bind_dev {
|
||||
set_ip_unicast_if(socket, bind_addr, iface.as_str())?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -0,0 +1,425 @@
|
||||
use std::{
|
||||
net::SocketAddr,
|
||||
sync::{Arc, Mutex},
|
||||
};
|
||||
|
||||
use anyhow::Context;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[auto_impl::auto_impl(Box, &)]
|
||||
pub trait ConfigLoader: Send + Sync {
|
||||
fn get_id(&self) -> uuid::Uuid;
|
||||
|
||||
fn get_inst_name(&self) -> String;
|
||||
fn set_inst_name(&self, name: String);
|
||||
|
||||
fn get_netns(&self) -> Option<String>;
|
||||
fn set_netns(&self, ns: Option<String>);
|
||||
|
||||
fn get_ipv4(&self) -> Option<std::net::Ipv4Addr>;
|
||||
fn set_ipv4(&self, addr: std::net::Ipv4Addr);
|
||||
|
||||
fn add_proxy_cidr(&self, cidr: cidr::IpCidr);
|
||||
fn remove_proxy_cidr(&self, cidr: cidr::IpCidr);
|
||||
fn get_proxy_cidrs(&self) -> Vec<cidr::IpCidr>;
|
||||
|
||||
fn get_network_identity(&self) -> NetworkIdentity;
|
||||
fn set_network_identity(&self, identity: NetworkIdentity);
|
||||
|
||||
fn get_listener_uris(&self) -> Vec<url::Url>;
|
||||
|
||||
fn get_file_logger_config(&self) -> FileLoggerConfig;
|
||||
fn set_file_logger_config(&self, config: FileLoggerConfig);
|
||||
fn get_console_logger_config(&self) -> ConsoleLoggerConfig;
|
||||
fn set_console_logger_config(&self, config: ConsoleLoggerConfig);
|
||||
|
||||
fn get_peers(&self) -> Vec<PeerConfig>;
|
||||
fn set_peers(&self, peers: Vec<PeerConfig>);
|
||||
|
||||
fn get_listeners(&self) -> Vec<url::Url>;
|
||||
fn set_listeners(&self, listeners: Vec<url::Url>);
|
||||
|
||||
fn get_rpc_portal(&self) -> Option<SocketAddr>;
|
||||
fn set_rpc_portal(&self, addr: SocketAddr);
|
||||
|
||||
fn get_vpn_portal_config(&self) -> Option<VpnPortalConfig>;
|
||||
fn set_vpn_portal_config(&self, config: VpnPortalConfig);
|
||||
|
||||
fn get_flags(&self) -> Flags;
|
||||
fn set_flags(&self, flags: Flags);
|
||||
|
||||
fn dump(&self) -> String;
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
|
||||
pub struct NetworkIdentity {
|
||||
pub network_name: String,
|
||||
pub network_secret: String,
|
||||
}
|
||||
|
||||
impl NetworkIdentity {
|
||||
pub fn new(network_name: String, network_secret: String) -> Self {
|
||||
NetworkIdentity {
|
||||
network_name,
|
||||
network_secret,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn default() -> Self {
|
||||
Self::new("default".to_string(), "".to_string())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
|
||||
pub struct PeerConfig {
|
||||
pub uri: url::Url,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
|
||||
pub struct NetworkConfig {
|
||||
pub cidr: String,
|
||||
pub allow: Option<Vec<String>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Default)]
|
||||
pub struct FileLoggerConfig {
|
||||
pub level: Option<String>,
|
||||
pub file: Option<String>,
|
||||
pub dir: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Default)]
|
||||
pub struct ConsoleLoggerConfig {
|
||||
pub level: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
|
||||
pub struct VpnPortalConfig {
|
||||
pub client_cidr: cidr::Ipv4Cidr,
|
||||
pub wireguard_listen: SocketAddr,
|
||||
}
|
||||
|
||||
// Flags is used to control the behavior of the program
|
||||
#[derive(derivative::Derivative, Deserialize, Serialize)]
|
||||
#[derivative(Debug, Clone, PartialEq, Default)]
|
||||
pub struct Flags {
|
||||
#[derivative(Default(value = "\"tcp\".to_string()"))]
|
||||
pub default_protocol: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
|
||||
struct Config {
|
||||
netns: Option<String>,
|
||||
instance_name: Option<String>,
|
||||
instance_id: Option<String>,
|
||||
ipv4: Option<String>,
|
||||
network_identity: Option<NetworkIdentity>,
|
||||
listeners: Option<Vec<url::Url>>,
|
||||
|
||||
peer: Option<Vec<PeerConfig>>,
|
||||
proxy_network: Option<Vec<NetworkConfig>>,
|
||||
|
||||
file_logger: Option<FileLoggerConfig>,
|
||||
console_logger: Option<ConsoleLoggerConfig>,
|
||||
|
||||
rpc_portal: Option<SocketAddr>,
|
||||
|
||||
vpn_portal_config: Option<VpnPortalConfig>,
|
||||
|
||||
flags: Option<Flags>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct TomlConfigLoader {
|
||||
config: Arc<Mutex<Config>>,
|
||||
}
|
||||
|
||||
impl Default for TomlConfigLoader {
|
||||
fn default() -> Self {
|
||||
TomlConfigLoader::new_from_str("").unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
impl TomlConfigLoader {
|
||||
pub fn new_from_str(config_str: &str) -> Result<Self, anyhow::Error> {
|
||||
let config = toml::de::from_str::<Config>(config_str).with_context(|| {
|
||||
format!(
|
||||
"failed to parse config file: {}\n{}",
|
||||
config_str, config_str
|
||||
)
|
||||
})?;
|
||||
Ok(TomlConfigLoader {
|
||||
config: Arc::new(Mutex::new(config)),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn new(config_path: &str) -> Result<Self, anyhow::Error> {
|
||||
let config_str = std::fs::read_to_string(config_path)
|
||||
.with_context(|| format!("failed to read config file: {}", config_path))?;
|
||||
Self::new_from_str(&config_str)
|
||||
}
|
||||
}
|
||||
|
||||
impl ConfigLoader for TomlConfigLoader {
|
||||
fn get_inst_name(&self) -> String {
|
||||
self.config
|
||||
.lock()
|
||||
.unwrap()
|
||||
.instance_name
|
||||
.clone()
|
||||
.unwrap_or("default".to_string())
|
||||
}
|
||||
|
||||
fn set_inst_name(&self, name: String) {
|
||||
self.config.lock().unwrap().instance_name = Some(name);
|
||||
}
|
||||
|
||||
fn get_netns(&self) -> Option<String> {
|
||||
self.config.lock().unwrap().netns.clone()
|
||||
}
|
||||
|
||||
fn set_netns(&self, ns: Option<String>) {
|
||||
self.config.lock().unwrap().netns = ns;
|
||||
}
|
||||
|
||||
fn get_ipv4(&self) -> Option<std::net::Ipv4Addr> {
|
||||
let locked_config = self.config.lock().unwrap();
|
||||
locked_config
|
||||
.ipv4
|
||||
.as_ref()
|
||||
.map(|s| s.parse().ok())
|
||||
.flatten()
|
||||
}
|
||||
|
||||
fn set_ipv4(&self, addr: std::net::Ipv4Addr) {
|
||||
self.config.lock().unwrap().ipv4 = Some(addr.to_string());
|
||||
}
|
||||
|
||||
fn add_proxy_cidr(&self, cidr: cidr::IpCidr) {
|
||||
let mut locked_config = self.config.lock().unwrap();
|
||||
if locked_config.proxy_network.is_none() {
|
||||
locked_config.proxy_network = Some(vec![]);
|
||||
}
|
||||
let cidr_str = cidr.to_string();
|
||||
// insert if no duplicate
|
||||
if !locked_config
|
||||
.proxy_network
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.iter()
|
||||
.any(|c| c.cidr == cidr_str)
|
||||
{
|
||||
locked_config
|
||||
.proxy_network
|
||||
.as_mut()
|
||||
.unwrap()
|
||||
.push(NetworkConfig {
|
||||
cidr: cidr_str,
|
||||
allow: None,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
fn remove_proxy_cidr(&self, cidr: cidr::IpCidr) {
|
||||
let mut locked_config = self.config.lock().unwrap();
|
||||
if let Some(proxy_cidrs) = &mut locked_config.proxy_network {
|
||||
let cidr_str = cidr.to_string();
|
||||
proxy_cidrs.retain(|c| c.cidr != cidr_str);
|
||||
}
|
||||
}
|
||||
|
||||
fn get_proxy_cidrs(&self) -> Vec<cidr::IpCidr> {
|
||||
self.config
|
||||
.lock()
|
||||
.unwrap()
|
||||
.proxy_network
|
||||
.as_ref()
|
||||
.map(|v| {
|
||||
v.iter()
|
||||
.map(|c| c.cidr.parse().unwrap())
|
||||
.collect::<Vec<cidr::IpCidr>>()
|
||||
})
|
||||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
fn get_id(&self) -> uuid::Uuid {
|
||||
let mut locked_config = self.config.lock().unwrap();
|
||||
if locked_config.instance_id.is_none() {
|
||||
let id = uuid::Uuid::new_v4();
|
||||
locked_config.instance_id = Some(id.to_string());
|
||||
id
|
||||
} else {
|
||||
uuid::Uuid::parse_str(locked_config.instance_id.as_ref().unwrap())
|
||||
.with_context(|| {
|
||||
format!(
|
||||
"failed to parse instance id as uuid: {}, you can use this id: {}",
|
||||
locked_config.instance_id.as_ref().unwrap(),
|
||||
uuid::Uuid::new_v4()
|
||||
)
|
||||
})
|
||||
.unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
fn get_network_identity(&self) -> NetworkIdentity {
|
||||
self.config
|
||||
.lock()
|
||||
.unwrap()
|
||||
.network_identity
|
||||
.clone()
|
||||
.unwrap_or_else(NetworkIdentity::default)
|
||||
}
|
||||
|
||||
fn set_network_identity(&self, identity: NetworkIdentity) {
|
||||
self.config.lock().unwrap().network_identity = Some(identity);
|
||||
}
|
||||
|
||||
fn get_listener_uris(&self) -> Vec<url::Url> {
|
||||
self.config
|
||||
.lock()
|
||||
.unwrap()
|
||||
.listeners
|
||||
.clone()
|
||||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
fn get_file_logger_config(&self) -> FileLoggerConfig {
|
||||
self.config
|
||||
.lock()
|
||||
.unwrap()
|
||||
.file_logger
|
||||
.clone()
|
||||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
fn set_file_logger_config(&self, config: FileLoggerConfig) {
|
||||
self.config.lock().unwrap().file_logger = Some(config);
|
||||
}
|
||||
|
||||
fn get_console_logger_config(&self) -> ConsoleLoggerConfig {
|
||||
self.config
|
||||
.lock()
|
||||
.unwrap()
|
||||
.console_logger
|
||||
.clone()
|
||||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
fn set_console_logger_config(&self, config: ConsoleLoggerConfig) {
|
||||
self.config.lock().unwrap().console_logger = Some(config);
|
||||
}
|
||||
|
||||
fn get_peers(&self) -> Vec<PeerConfig> {
|
||||
self.config.lock().unwrap().peer.clone().unwrap_or_default()
|
||||
}
|
||||
|
||||
fn set_peers(&self, peers: Vec<PeerConfig>) {
|
||||
self.config.lock().unwrap().peer = Some(peers);
|
||||
}
|
||||
|
||||
fn get_listeners(&self) -> Vec<url::Url> {
|
||||
self.config
|
||||
.lock()
|
||||
.unwrap()
|
||||
.listeners
|
||||
.clone()
|
||||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
fn set_listeners(&self, listeners: Vec<url::Url>) {
|
||||
self.config.lock().unwrap().listeners = Some(listeners);
|
||||
}
|
||||
|
||||
fn get_rpc_portal(&self) -> Option<SocketAddr> {
|
||||
self.config.lock().unwrap().rpc_portal
|
||||
}
|
||||
|
||||
fn set_rpc_portal(&self, addr: SocketAddr) {
|
||||
self.config.lock().unwrap().rpc_portal = Some(addr);
|
||||
}
|
||||
|
||||
fn get_vpn_portal_config(&self) -> Option<VpnPortalConfig> {
|
||||
self.config.lock().unwrap().vpn_portal_config.clone()
|
||||
}
|
||||
fn set_vpn_portal_config(&self, config: VpnPortalConfig) {
|
||||
self.config.lock().unwrap().vpn_portal_config = Some(config);
|
||||
}
|
||||
|
||||
fn get_flags(&self) -> Flags {
|
||||
self.config
|
||||
.lock()
|
||||
.unwrap()
|
||||
.flags
|
||||
.clone()
|
||||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
fn set_flags(&self, flags: Flags) {
|
||||
self.config.lock().unwrap().flags = Some(flags);
|
||||
}
|
||||
|
||||
fn dump(&self) -> String {
|
||||
toml::to_string_pretty(&*self.config.lock().unwrap()).unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub mod tests {
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn full_example_test() {
|
||||
let config_str = r#"
|
||||
instance_name = "default"
|
||||
instance_id = "87ede5a2-9c3d-492d-9bbe-989b9d07e742"
|
||||
ipv4 = "10.144.144.10"
|
||||
listeners = [ "tcp://0.0.0.0:11010", "udp://0.0.0.0:11010" ]
|
||||
|
||||
[network_identity]
|
||||
network_name = "default"
|
||||
network_secret = ""
|
||||
|
||||
[[peer]]
|
||||
uri = "tcp://public.kkrainbow.top:11010"
|
||||
|
||||
[[peer]]
|
||||
uri = "udp://192.168.94.33:11010"
|
||||
|
||||
[[proxy_network]]
|
||||
cidr = "10.147.223.0/24"
|
||||
allow = ["tcp", "udp", "icmp"]
|
||||
|
||||
[[proxy_network]]
|
||||
cidr = "10.1.1.0/24"
|
||||
allow = ["tcp", "icmp"]
|
||||
|
||||
[file_logger]
|
||||
level = "info"
|
||||
file = "easytier"
|
||||
dir = "/tmp/easytier"
|
||||
|
||||
[console_logger]
|
||||
level = "warn"
|
||||
"#;
|
||||
let ret = TomlConfigLoader::new_from_str(config_str);
|
||||
if let Err(e) = &ret {
|
||||
println!("{}", e);
|
||||
} else {
|
||||
println!("{:?}", ret.as_ref().unwrap());
|
||||
}
|
||||
assert!(ret.is_ok());
|
||||
|
||||
let ret = ret.unwrap();
|
||||
assert_eq!("10.144.144.10", ret.get_ipv4().unwrap().to_string());
|
||||
|
||||
assert_eq!(
|
||||
vec!["tcp://0.0.0.0:11010", "udp://0.0.0.0:11010"],
|
||||
ret.get_listener_uris()
|
||||
.iter()
|
||||
.map(|u| u.to_string())
|
||||
.collect::<Vec<String>>()
|
||||
);
|
||||
|
||||
println!("{}", ret.dump());
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,24 @@
|
||||
macro_rules! define_global_var {
|
||||
($name:ident, $type:ty, $init:expr) => {
|
||||
pub static $name: once_cell::sync::Lazy<tokio::sync::Mutex<$type>> =
|
||||
once_cell::sync::Lazy::new(|| tokio::sync::Mutex::new($init));
|
||||
};
|
||||
}
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! use_global_var {
|
||||
($name:ident) => {
|
||||
crate::common::constants::$name.lock().await.to_owned()
|
||||
};
|
||||
}
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! set_global_var {
|
||||
($name:ident, $val:expr) => {
|
||||
*crate::common::constants::$name.lock().await = $val
|
||||
};
|
||||
}
|
||||
|
||||
define_global_var!(MANUAL_CONNECTOR_RECONNECT_INTERVAL_MS, u64, 1000);
|
||||
|
||||
pub const UDP_HOLE_PUNCH_CONNECTOR_SERVICE_ID: u32 = 2;
|
||||
@@ -0,0 +1,45 @@
|
||||
use std::{io, result};
|
||||
|
||||
use thiserror::Error;
|
||||
|
||||
use crate::tunnels;
|
||||
|
||||
use super::PeerId;
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
pub enum Error {
|
||||
#[error("io error")]
|
||||
IOError(#[from] io::Error),
|
||||
#[error("rust tun error {0}")]
|
||||
TunError(#[from] tun::Error),
|
||||
#[error("tunnel error {0}")]
|
||||
TunnelError(#[from] tunnels::TunnelError),
|
||||
#[error("Peer has no conn, PeerId: {0}")]
|
||||
PeerNoConnectionError(PeerId),
|
||||
#[error("RouteError: {0:?}")]
|
||||
RouteError(Option<String>),
|
||||
#[error("Not found")]
|
||||
NotFound,
|
||||
#[error("Invalid Url: {0}")]
|
||||
InvalidUrl(String),
|
||||
#[error("Shell Command error: {0}")]
|
||||
ShellCommandError(String),
|
||||
// #[error("Rpc listen error: {0}")]
|
||||
// RpcListenError(String),
|
||||
#[error("Rpc connect error: {0}")]
|
||||
RpcConnectError(String),
|
||||
#[error("Rpc error: {0}")]
|
||||
RpcClientError(#[from] tarpc::client::RpcError),
|
||||
#[error("Timeout error: {0}")]
|
||||
Timeout(#[from] tokio::time::error::Elapsed),
|
||||
#[error("url in blacklist")]
|
||||
UrlInBlacklist,
|
||||
#[error("unknown data store error")]
|
||||
Unknown,
|
||||
#[error("anyhow error: {0}")]
|
||||
AnyhowError(#[from] anyhow::Error),
|
||||
}
|
||||
|
||||
pub type Result<T> = result::Result<T, Error>;
|
||||
|
||||
// impl From for std::
|
||||
@@ -0,0 +1,256 @@
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
use crate::rpc::PeerConnInfo;
|
||||
use crossbeam::atomic::AtomicCell;
|
||||
|
||||
use super::{
|
||||
config::{ConfigLoader, Flags},
|
||||
netns::NetNS,
|
||||
network::IPCollector,
|
||||
stun::{StunInfoCollector, StunInfoCollectorTrait},
|
||||
PeerId,
|
||||
};
|
||||
|
||||
pub type NetworkIdentity = crate::common::config::NetworkIdentity;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub enum GlobalCtxEvent {
|
||||
TunDeviceReady(String),
|
||||
|
||||
PeerAdded(PeerId),
|
||||
PeerRemoved(PeerId),
|
||||
PeerConnAdded(PeerConnInfo),
|
||||
PeerConnRemoved(PeerConnInfo),
|
||||
|
||||
ListenerAdded(url::Url),
|
||||
ConnectionAccepted(String, String), // (local url, remote url)
|
||||
ConnectionError(String, String, String), // (local url, remote url, error message)
|
||||
|
||||
Connecting(url::Url),
|
||||
ConnectError(String, String), // (dst, error message)
|
||||
|
||||
VpnPortalClientConnected(String, String), // (portal, client ip)
|
||||
VpnPortalClientDisconnected(String, String), // (portal, client ip)
|
||||
}
|
||||
|
||||
type EventBus = tokio::sync::broadcast::Sender<GlobalCtxEvent>;
|
||||
type EventBusSubscriber = tokio::sync::broadcast::Receiver<GlobalCtxEvent>;
|
||||
|
||||
pub struct GlobalCtx {
|
||||
pub inst_name: String,
|
||||
pub id: uuid::Uuid,
|
||||
pub config: Box<dyn ConfigLoader>,
|
||||
pub net_ns: NetNS,
|
||||
pub network: NetworkIdentity,
|
||||
|
||||
event_bus: EventBus,
|
||||
|
||||
cached_ipv4: AtomicCell<Option<std::net::Ipv4Addr>>,
|
||||
cached_proxy_cidrs: AtomicCell<Option<Vec<cidr::IpCidr>>>,
|
||||
|
||||
ip_collector: Arc<IPCollector>,
|
||||
|
||||
hotname: AtomicCell<Option<String>>,
|
||||
|
||||
stun_info_collection: Box<dyn StunInfoCollectorTrait>,
|
||||
|
||||
running_listeners: Mutex<Vec<url::Url>>,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for GlobalCtx {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("GlobalCtx")
|
||||
.field("inst_name", &self.inst_name)
|
||||
.field("id", &self.id)
|
||||
.field("net_ns", &self.net_ns.name())
|
||||
.field("event_bus", &"EventBus")
|
||||
.field("ipv4", &self.cached_ipv4)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
pub type ArcGlobalCtx = std::sync::Arc<GlobalCtx>;
|
||||
|
||||
impl GlobalCtx {
|
||||
pub fn new(config_fs: impl ConfigLoader + 'static + Send + Sync) -> Self {
|
||||
let id = config_fs.get_id();
|
||||
let network = config_fs.get_network_identity();
|
||||
let net_ns = NetNS::new(config_fs.get_netns());
|
||||
|
||||
let (event_bus, _) = tokio::sync::broadcast::channel(100);
|
||||
|
||||
GlobalCtx {
|
||||
inst_name: config_fs.get_inst_name(),
|
||||
id,
|
||||
config: Box::new(config_fs),
|
||||
net_ns: net_ns.clone(),
|
||||
network,
|
||||
|
||||
event_bus,
|
||||
cached_ipv4: AtomicCell::new(None),
|
||||
cached_proxy_cidrs: AtomicCell::new(None),
|
||||
|
||||
ip_collector: Arc::new(IPCollector::new(net_ns)),
|
||||
|
||||
hotname: AtomicCell::new(None),
|
||||
|
||||
stun_info_collection: Box::new(StunInfoCollector::new_with_default_servers()),
|
||||
|
||||
running_listeners: Mutex::new(Vec::new()),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn subscribe(&self) -> EventBusSubscriber {
|
||||
self.event_bus.subscribe()
|
||||
}
|
||||
|
||||
pub fn issue_event(&self, event: GlobalCtxEvent) {
|
||||
if self.event_bus.receiver_count() != 0 {
|
||||
self.event_bus.send(event).unwrap();
|
||||
} else {
|
||||
log::warn!("No subscriber for event: {:?}", event);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_ipv4(&self) -> Option<std::net::Ipv4Addr> {
|
||||
if let Some(ret) = self.cached_ipv4.load() {
|
||||
return Some(ret);
|
||||
}
|
||||
let addr = self.config.get_ipv4();
|
||||
self.cached_ipv4.store(addr.clone());
|
||||
return addr;
|
||||
}
|
||||
|
||||
pub fn set_ipv4(&mut self, addr: std::net::Ipv4Addr) {
|
||||
self.config.set_ipv4(addr);
|
||||
self.cached_ipv4.store(None);
|
||||
}
|
||||
|
||||
pub fn add_proxy_cidr(&self, cidr: cidr::IpCidr) -> Result<(), std::io::Error> {
|
||||
self.config.add_proxy_cidr(cidr);
|
||||
self.cached_proxy_cidrs.store(None);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn remove_proxy_cidr(&self, cidr: cidr::IpCidr) -> Result<(), std::io::Error> {
|
||||
self.config.remove_proxy_cidr(cidr);
|
||||
self.cached_proxy_cidrs.store(None);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn get_proxy_cidrs(&self) -> Vec<cidr::IpCidr> {
|
||||
if let Some(proxy_cidrs) = self.cached_proxy_cidrs.take() {
|
||||
self.cached_proxy_cidrs.store(Some(proxy_cidrs.clone()));
|
||||
return proxy_cidrs;
|
||||
}
|
||||
|
||||
let ret = self.config.get_proxy_cidrs();
|
||||
self.cached_proxy_cidrs.store(Some(ret.clone()));
|
||||
ret
|
||||
}
|
||||
|
||||
pub fn get_id(&self) -> uuid::Uuid {
|
||||
self.config.get_id()
|
||||
}
|
||||
|
||||
pub fn get_network_identity(&self) -> NetworkIdentity {
|
||||
self.config.get_network_identity()
|
||||
}
|
||||
|
||||
pub fn get_ip_collector(&self) -> Arc<IPCollector> {
|
||||
self.ip_collector.clone()
|
||||
}
|
||||
|
||||
pub fn get_hostname(&self) -> Option<String> {
|
||||
if let Some(hostname) = self.hotname.take() {
|
||||
self.hotname.store(Some(hostname.clone()));
|
||||
return Some(hostname);
|
||||
}
|
||||
|
||||
let hostname = gethostname::gethostname().to_string_lossy().to_string();
|
||||
self.hotname.store(Some(hostname.clone()));
|
||||
return Some(hostname);
|
||||
}
|
||||
|
||||
pub fn get_stun_info_collector(&self) -> impl StunInfoCollectorTrait + '_ {
|
||||
self.stun_info_collection.as_ref()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub fn replace_stun_info_collector(&self, collector: Box<dyn StunInfoCollectorTrait>) {
|
||||
// force replace the stun_info_collection without mut and drop the old one
|
||||
let ptr = &self.stun_info_collection as *const Box<dyn StunInfoCollectorTrait>;
|
||||
let ptr = ptr as *mut Box<dyn StunInfoCollectorTrait>;
|
||||
unsafe {
|
||||
std::ptr::drop_in_place(ptr);
|
||||
#[allow(invalid_reference_casting)]
|
||||
std::ptr::write(ptr, collector);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_running_listeners(&self) -> Vec<url::Url> {
|
||||
self.running_listeners.lock().unwrap().clone()
|
||||
}
|
||||
|
||||
pub fn add_running_listener(&self, url: url::Url) {
|
||||
self.running_listeners.lock().unwrap().push(url);
|
||||
}
|
||||
|
||||
pub fn get_vpn_portal_cidr(&self) -> Option<cidr::Ipv4Cidr> {
|
||||
self.config.get_vpn_portal_config().map(|x| x.client_cidr)
|
||||
}
|
||||
|
||||
pub fn get_flags(&self) -> Flags {
|
||||
self.config.get_flags()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub mod tests {
|
||||
use crate::common::{config::TomlConfigLoader, new_peer_id};
|
||||
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_global_ctx() {
|
||||
let config = TomlConfigLoader::default();
|
||||
let global_ctx = GlobalCtx::new(config);
|
||||
|
||||
let mut subscriber = global_ctx.subscribe();
|
||||
let peer_id = new_peer_id();
|
||||
global_ctx.issue_event(GlobalCtxEvent::PeerAdded(peer_id.clone()));
|
||||
global_ctx.issue_event(GlobalCtxEvent::PeerRemoved(peer_id.clone()));
|
||||
global_ctx.issue_event(GlobalCtxEvent::PeerConnAdded(PeerConnInfo::default()));
|
||||
global_ctx.issue_event(GlobalCtxEvent::PeerConnRemoved(PeerConnInfo::default()));
|
||||
|
||||
assert_eq!(
|
||||
subscriber.recv().await.unwrap(),
|
||||
GlobalCtxEvent::PeerAdded(peer_id.clone())
|
||||
);
|
||||
assert_eq!(
|
||||
subscriber.recv().await.unwrap(),
|
||||
GlobalCtxEvent::PeerRemoved(peer_id.clone())
|
||||
);
|
||||
assert_eq!(
|
||||
subscriber.recv().await.unwrap(),
|
||||
GlobalCtxEvent::PeerConnAdded(PeerConnInfo::default())
|
||||
);
|
||||
assert_eq!(
|
||||
subscriber.recv().await.unwrap(),
|
||||
GlobalCtxEvent::PeerConnRemoved(PeerConnInfo::default())
|
||||
);
|
||||
}
|
||||
|
||||
pub fn get_mock_global_ctx_with_network(
|
||||
network_identy: Option<NetworkIdentity>,
|
||||
) -> ArcGlobalCtx {
|
||||
let config_fs = TomlConfigLoader::default();
|
||||
config_fs.set_inst_name(format!("test_{}", config_fs.get_id()));
|
||||
config_fs.set_network_identity(network_identy.unwrap_or(NetworkIdentity::default()));
|
||||
std::sync::Arc::new(GlobalCtx::new(config_fs))
|
||||
}
|
||||
|
||||
pub fn get_mock_global_ctx() -> ArcGlobalCtx {
|
||||
get_mock_global_ctx_with_network(None)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,358 @@
|
||||
use std::net::Ipv4Addr;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use tokio::process::Command;
|
||||
|
||||
use super::error::Error;
|
||||
|
||||
#[async_trait]
|
||||
pub trait IfConfiguerTrait {
|
||||
async fn add_ipv4_route(
|
||||
&self,
|
||||
name: &str,
|
||||
address: Ipv4Addr,
|
||||
cidr_prefix: u8,
|
||||
) -> Result<(), Error>;
|
||||
async fn remove_ipv4_route(
|
||||
&self,
|
||||
name: &str,
|
||||
address: Ipv4Addr,
|
||||
cidr_prefix: u8,
|
||||
) -> Result<(), Error>;
|
||||
async fn add_ipv4_ip(
|
||||
&self,
|
||||
name: &str,
|
||||
address: Ipv4Addr,
|
||||
cidr_prefix: u8,
|
||||
) -> Result<(), Error>;
|
||||
async fn set_link_status(&self, name: &str, up: bool) -> Result<(), Error>;
|
||||
async fn remove_ip(&self, name: &str, ip: Option<Ipv4Addr>) -> Result<(), Error>;
|
||||
async fn wait_interface_show(&self, _name: &str) -> Result<(), Error> {
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
|
||||
fn cidr_to_subnet_mask(prefix_length: u8) -> Ipv4Addr {
|
||||
if prefix_length > 32 {
|
||||
panic!("Invalid CIDR prefix length");
|
||||
}
|
||||
|
||||
let subnet_mask: u32 = (!0u32)
|
||||
.checked_shl(32 - u32::from(prefix_length))
|
||||
.unwrap_or(0);
|
||||
Ipv4Addr::new(
|
||||
((subnet_mask >> 24) & 0xFF) as u8,
|
||||
((subnet_mask >> 16) & 0xFF) as u8,
|
||||
((subnet_mask >> 8) & 0xFF) as u8,
|
||||
(subnet_mask & 0xFF) as u8,
|
||||
)
|
||||
}
|
||||
|
||||
async fn run_shell_cmd(cmd: &str) -> Result<(), Error> {
|
||||
let cmd_out = if cfg!(target_os = "windows") {
|
||||
Command::new("cmd").arg("/C").arg(cmd).output().await?
|
||||
} else {
|
||||
Command::new("sh").arg("-c").arg(cmd).output().await?
|
||||
};
|
||||
let stdout = String::from_utf8_lossy(cmd_out.stdout.as_slice());
|
||||
let stderr = String::from_utf8_lossy(cmd_out.stderr.as_slice());
|
||||
let ec = cmd_out.status.code();
|
||||
let succ = cmd_out.status.success();
|
||||
tracing::info!(?cmd, ?ec, ?succ, ?stdout, ?stderr, "run shell cmd");
|
||||
|
||||
if !cmd_out.status.success() {
|
||||
return Err(Error::ShellCommandError(
|
||||
stdout.to_string() + &stderr.to_string(),
|
||||
));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub struct MacIfConfiger {}
|
||||
#[async_trait]
|
||||
impl IfConfiguerTrait for MacIfConfiger {
|
||||
async fn add_ipv4_route(
|
||||
&self,
|
||||
name: &str,
|
||||
address: Ipv4Addr,
|
||||
cidr_prefix: u8,
|
||||
) -> Result<(), Error> {
|
||||
run_shell_cmd(
|
||||
format!(
|
||||
"route -n add {} -netmask {} -interface {} -hopcount 7",
|
||||
address,
|
||||
cidr_to_subnet_mask(cidr_prefix),
|
||||
name
|
||||
)
|
||||
.as_str(),
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn remove_ipv4_route(
|
||||
&self,
|
||||
name: &str,
|
||||
address: Ipv4Addr,
|
||||
cidr_prefix: u8,
|
||||
) -> Result<(), Error> {
|
||||
run_shell_cmd(
|
||||
format!(
|
||||
"route -n delete {} -netmask {} -interface {}",
|
||||
address,
|
||||
cidr_to_subnet_mask(cidr_prefix),
|
||||
name
|
||||
)
|
||||
.as_str(),
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn add_ipv4_ip(
|
||||
&self,
|
||||
name: &str,
|
||||
address: Ipv4Addr,
|
||||
cidr_prefix: u8,
|
||||
) -> Result<(), Error> {
|
||||
run_shell_cmd(
|
||||
format!(
|
||||
"ifconfig {} {:?}/{:?} 10.8.8.8 up",
|
||||
name, address, cidr_prefix,
|
||||
)
|
||||
.as_str(),
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn set_link_status(&self, name: &str, up: bool) -> Result<(), Error> {
|
||||
run_shell_cmd(format!("ifconfig {} {}", name, if up { "up" } else { "down" }).as_str())
|
||||
.await
|
||||
}
|
||||
|
||||
async fn remove_ip(&self, name: &str, ip: Option<Ipv4Addr>) -> Result<(), Error> {
|
||||
if ip.is_none() {
|
||||
run_shell_cmd(format!("ifconfig {} inet delete", name).as_str()).await
|
||||
} else {
|
||||
run_shell_cmd(
|
||||
format!("ifconfig {} inet {} delete", name, ip.unwrap().to_string()).as_str(),
|
||||
)
|
||||
.await
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct LinuxIfConfiger {}
|
||||
#[async_trait]
|
||||
impl IfConfiguerTrait for LinuxIfConfiger {
|
||||
async fn add_ipv4_route(
|
||||
&self,
|
||||
name: &str,
|
||||
address: Ipv4Addr,
|
||||
cidr_prefix: u8,
|
||||
) -> Result<(), Error> {
|
||||
run_shell_cmd(
|
||||
format!(
|
||||
"ip route add {}/{} dev {} metric 65535",
|
||||
address, cidr_prefix, name
|
||||
)
|
||||
.as_str(),
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn remove_ipv4_route(
|
||||
&self,
|
||||
name: &str,
|
||||
address: Ipv4Addr,
|
||||
cidr_prefix: u8,
|
||||
) -> Result<(), Error> {
|
||||
run_shell_cmd(format!("ip route del {}/{} dev {}", address, cidr_prefix, name).as_str())
|
||||
.await
|
||||
}
|
||||
|
||||
async fn add_ipv4_ip(
|
||||
&self,
|
||||
name: &str,
|
||||
address: Ipv4Addr,
|
||||
cidr_prefix: u8,
|
||||
) -> Result<(), Error> {
|
||||
run_shell_cmd(format!("ip addr add {:?}/{:?} dev {}", address, cidr_prefix, name).as_str())
|
||||
.await
|
||||
}
|
||||
|
||||
async fn set_link_status(&self, name: &str, up: bool) -> Result<(), Error> {
|
||||
run_shell_cmd(format!("ip link set {} {}", name, if up { "up" } else { "down" }).as_str())
|
||||
.await
|
||||
}
|
||||
|
||||
async fn remove_ip(&self, name: &str, ip: Option<Ipv4Addr>) -> Result<(), Error> {
|
||||
if ip.is_none() {
|
||||
run_shell_cmd(format!("ip addr flush dev {}", name).as_str()).await
|
||||
} else {
|
||||
run_shell_cmd(
|
||||
format!("ip addr del {:?} dev {}", ip.unwrap().to_string(), name).as_str(),
|
||||
)
|
||||
.await
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(target_os = "windows")]
|
||||
pub struct WindowsIfConfiger {}
|
||||
|
||||
#[cfg(target_os = "windows")]
|
||||
impl WindowsIfConfiger {
|
||||
pub fn get_interface_index(name: &str) -> Option<u32> {
|
||||
crate::arch::windows::find_interface_index(name).ok()
|
||||
}
|
||||
|
||||
async fn list_ipv4(name: &str) -> Result<Vec<Ipv4Addr>, Error> {
|
||||
use anyhow::Context;
|
||||
use network_interface::NetworkInterfaceConfig;
|
||||
use std::net::IpAddr;
|
||||
let ret = network_interface::NetworkInterface::show().with_context(|| "show interface")?;
|
||||
let addrs = ret
|
||||
.iter()
|
||||
.filter_map(|x| {
|
||||
if x.name != name {
|
||||
return None;
|
||||
}
|
||||
Some(x.addr.clone())
|
||||
})
|
||||
.flat_map(|x| x)
|
||||
.map(|x| x.ip())
|
||||
.filter_map(|x| {
|
||||
if let IpAddr::V4(ipv4) = x {
|
||||
Some(ipv4)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
Ok(addrs)
|
||||
}
|
||||
|
||||
async fn remove_one_ipv4(name: &str, ip: Ipv4Addr) -> Result<(), Error> {
|
||||
run_shell_cmd(
|
||||
format!(
|
||||
"netsh interface ipv4 delete address {} address={}",
|
||||
name,
|
||||
ip.to_string()
|
||||
)
|
||||
.as_str(),
|
||||
)
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(target_os = "windows")]
|
||||
#[async_trait]
|
||||
impl IfConfiguerTrait for WindowsIfConfiger {
|
||||
async fn add_ipv4_route(
|
||||
&self,
|
||||
name: &str,
|
||||
address: Ipv4Addr,
|
||||
cidr_prefix: u8,
|
||||
) -> Result<(), Error> {
|
||||
let Some(idx) = Self::get_interface_index(name) else {
|
||||
return Err(Error::NotFound);
|
||||
};
|
||||
run_shell_cmd(
|
||||
format!(
|
||||
"route ADD {} MASK {} 10.1.1.1 IF {} METRIC 255",
|
||||
address,
|
||||
cidr_to_subnet_mask(cidr_prefix),
|
||||
idx
|
||||
)
|
||||
.as_str(),
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn remove_ipv4_route(
|
||||
&self,
|
||||
name: &str,
|
||||
address: Ipv4Addr,
|
||||
cidr_prefix: u8,
|
||||
) -> Result<(), Error> {
|
||||
let Some(idx) = Self::get_interface_index(name) else {
|
||||
return Err(Error::NotFound);
|
||||
};
|
||||
run_shell_cmd(
|
||||
format!(
|
||||
"route DELETE {} MASK {} IF {}",
|
||||
address,
|
||||
cidr_to_subnet_mask(cidr_prefix),
|
||||
idx
|
||||
)
|
||||
.as_str(),
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn add_ipv4_ip(
|
||||
&self,
|
||||
name: &str,
|
||||
address: Ipv4Addr,
|
||||
cidr_prefix: u8,
|
||||
) -> Result<(), Error> {
|
||||
run_shell_cmd(
|
||||
format!(
|
||||
"netsh interface ipv4 add address {} address={} mask={}",
|
||||
name,
|
||||
address,
|
||||
cidr_to_subnet_mask(cidr_prefix)
|
||||
)
|
||||
.as_str(),
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn set_link_status(&self, name: &str, up: bool) -> Result<(), Error> {
|
||||
run_shell_cmd(
|
||||
format!(
|
||||
"netsh interface set interface {} {}",
|
||||
name,
|
||||
if up { "enable" } else { "disable" }
|
||||
)
|
||||
.as_str(),
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn remove_ip(&self, name: &str, ip: Option<Ipv4Addr>) -> Result<(), Error> {
|
||||
if ip.is_none() {
|
||||
for ip in Self::list_ipv4(name).await?.iter() {
|
||||
Self::remove_one_ipv4(name, *ip).await?;
|
||||
}
|
||||
Ok(())
|
||||
} else {
|
||||
Self::remove_one_ipv4(name, ip.unwrap()).await
|
||||
}
|
||||
}
|
||||
|
||||
async fn wait_interface_show(&self, name: &str) -> Result<(), Error> {
|
||||
Ok(
|
||||
tokio::time::timeout(std::time::Duration::from_secs(10), async move {
|
||||
loop {
|
||||
if let Some(idx) = Self::get_interface_index(name) {
|
||||
tracing::info!(?name, ?idx, "Interface found");
|
||||
break;
|
||||
}
|
||||
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
|
||||
}
|
||||
Ok::<(), Error>(())
|
||||
})
|
||||
.await??,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(target_os = "macos")]
|
||||
pub type IfConfiger = MacIfConfiger;
|
||||
|
||||
#[cfg(target_os = "linux")]
|
||||
pub type IfConfiger = LinuxIfConfiger;
|
||||
|
||||
#[cfg(target_os = "windows")]
|
||||
pub type IfConfiger = WindowsIfConfiger;
|
||||
@@ -0,0 +1,120 @@
|
||||
use std::{
|
||||
fmt::Debug,
|
||||
future,
|
||||
sync::{Arc, Mutex},
|
||||
};
|
||||
use tokio::task::JoinSet;
|
||||
use tracing::Instrument;
|
||||
|
||||
pub mod config;
|
||||
pub mod constants;
|
||||
pub mod error;
|
||||
pub mod global_ctx;
|
||||
pub mod ifcfg;
|
||||
pub mod netns;
|
||||
pub mod network;
|
||||
pub mod rkyv_util;
|
||||
pub mod stun;
|
||||
pub mod stun_codec_ext;
|
||||
|
||||
pub fn get_logger_timer<F: time::formatting::Formattable>(
|
||||
format: F,
|
||||
) -> tracing_subscriber::fmt::time::OffsetTime<F> {
|
||||
unsafe {
|
||||
time::util::local_offset::set_soundness(time::util::local_offset::Soundness::Unsound)
|
||||
};
|
||||
let local_offset = time::UtcOffset::current_local_offset()
|
||||
.unwrap_or(time::UtcOffset::from_whole_seconds(0).unwrap());
|
||||
tracing_subscriber::fmt::time::OffsetTime::new(local_offset, format)
|
||||
}
|
||||
|
||||
pub fn get_logger_timer_rfc3339(
|
||||
) -> tracing_subscriber::fmt::time::OffsetTime<time::format_description::well_known::Rfc3339> {
|
||||
get_logger_timer(time::format_description::well_known::Rfc3339)
|
||||
}
|
||||
|
||||
pub type PeerId = u32;
|
||||
|
||||
pub fn new_peer_id() -> PeerId {
|
||||
rand::random()
|
||||
}
|
||||
|
||||
pub fn join_joinset_background<T: Debug + Send + Sync + 'static>(
|
||||
js: Arc<Mutex<JoinSet<T>>>,
|
||||
origin: String,
|
||||
) {
|
||||
let js = Arc::downgrade(&js);
|
||||
tokio::spawn(
|
||||
async move {
|
||||
loop {
|
||||
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
|
||||
if js.weak_count() == 0 {
|
||||
tracing::info!("joinset task exit");
|
||||
break;
|
||||
}
|
||||
|
||||
future::poll_fn(|cx| {
|
||||
tracing::debug!("try join joinset tasks");
|
||||
let Some(js) = js.upgrade() else {
|
||||
return std::task::Poll::Ready(());
|
||||
};
|
||||
|
||||
let mut js = js.lock().unwrap();
|
||||
while !js.is_empty() {
|
||||
let ret = js.poll_join_next(cx);
|
||||
if ret.is_pending() {
|
||||
return std::task::Poll::Pending;
|
||||
}
|
||||
}
|
||||
|
||||
std::task::Poll::Ready(())
|
||||
})
|
||||
.await;
|
||||
}
|
||||
}
|
||||
.instrument(tracing::info_span!(
|
||||
"join_joinset_background",
|
||||
origin = origin
|
||||
)),
|
||||
);
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_join_joinset_backgroud() {
|
||||
let js = Arc::new(Mutex::new(JoinSet::<()>::new()));
|
||||
join_joinset_background(js.clone(), "TEST".to_owned());
|
||||
js.try_lock().unwrap().spawn(async {
|
||||
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
|
||||
});
|
||||
tokio::time::sleep(std::time::Duration::from_secs(2)).await;
|
||||
assert!(js.try_lock().unwrap().is_empty());
|
||||
|
||||
for _ in 0..5 {
|
||||
js.try_lock().unwrap().spawn(async {
|
||||
tokio::time::sleep(std::time::Duration::from_secs(3)).await;
|
||||
});
|
||||
tokio::task::yield_now().await;
|
||||
}
|
||||
|
||||
tokio::time::sleep(std::time::Duration::from_secs(2)).await;
|
||||
|
||||
for _ in 0..5 {
|
||||
js.try_lock().unwrap().spawn(async {
|
||||
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
|
||||
});
|
||||
tokio::task::yield_now().await;
|
||||
}
|
||||
|
||||
tokio::time::sleep(std::time::Duration::from_secs(2)).await;
|
||||
assert!(js.try_lock().unwrap().is_empty());
|
||||
|
||||
let weak_js = Arc::downgrade(&js);
|
||||
drop(js);
|
||||
tokio::time::sleep(std::time::Duration::from_secs(2)).await;
|
||||
assert_eq!(weak_js.weak_count(), 0);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,114 @@
|
||||
use futures::Future;
|
||||
|
||||
#[cfg(target_os = "linux")]
|
||||
use nix::sched::{setns, CloneFlags};
|
||||
#[cfg(target_os = "linux")]
|
||||
use std::os::fd::AsFd;
|
||||
|
||||
pub struct NetNSGuard {
|
||||
#[cfg(target_os = "linux")]
|
||||
old_ns: Option<std::fs::File>,
|
||||
}
|
||||
|
||||
pub static ROOT_NETNS_NAME: &str = "_root_ns";
|
||||
|
||||
#[cfg(target_os = "linux")]
|
||||
impl NetNSGuard {
|
||||
pub fn new(ns: Option<String>) -> Box<Self> {
|
||||
let old_ns = if ns.is_some() {
|
||||
let old_ns = if cfg!(target_os = "linux") {
|
||||
Some(std::fs::File::open("/proc/self/ns/net").unwrap())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
Self::switch_ns(ns);
|
||||
old_ns
|
||||
} else {
|
||||
None
|
||||
};
|
||||
Box::new(NetNSGuard { old_ns })
|
||||
}
|
||||
|
||||
fn switch_ns(name: Option<String>) {
|
||||
if name.is_none() {
|
||||
return;
|
||||
}
|
||||
|
||||
let ns_path: String;
|
||||
let name = name.unwrap();
|
||||
if name == ROOT_NETNS_NAME {
|
||||
ns_path = "/proc/1/ns/net".to_string();
|
||||
} else {
|
||||
ns_path = format!("/var/run/netns/{}", name);
|
||||
}
|
||||
|
||||
let ns = std::fs::File::open(ns_path).unwrap();
|
||||
log::info!(
|
||||
"[INIT NS] switching to new ns_name: {:?}, ns_file: {:?}",
|
||||
name,
|
||||
ns
|
||||
);
|
||||
|
||||
setns(ns.as_fd(), CloneFlags::CLONE_NEWNET).unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(target_os = "linux")]
|
||||
impl Drop for NetNSGuard {
|
||||
fn drop(&mut self) {
|
||||
if self.old_ns.is_none() {
|
||||
return;
|
||||
}
|
||||
log::info!("[INIT NS] switching back to old ns, ns: {:?}", self.old_ns);
|
||||
setns(
|
||||
self.old_ns.as_ref().unwrap().as_fd(),
|
||||
CloneFlags::CLONE_NEWNET,
|
||||
)
|
||||
.unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(target_os = "linux"))]
|
||||
impl NetNSGuard {
|
||||
pub fn new(_ns: Option<String>) -> Box<Self> {
|
||||
Box::new(NetNSGuard {})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct NetNS {
|
||||
name: Option<String>,
|
||||
}
|
||||
|
||||
impl NetNS {
|
||||
pub fn new(name: Option<String>) -> Self {
|
||||
NetNS { name }
|
||||
}
|
||||
|
||||
pub async fn run_async<F, Fut, Ret>(&self, f: F) -> Ret
|
||||
where
|
||||
F: FnOnce() -> Fut,
|
||||
Fut: Future<Output = Ret>,
|
||||
{
|
||||
// TODO: do we really need this lock
|
||||
// let _lock = LOCK.lock().await;
|
||||
let _guard = NetNSGuard::new(self.name.clone());
|
||||
f().await
|
||||
}
|
||||
|
||||
pub fn run<F, Ret>(&self, f: F) -> Ret
|
||||
where
|
||||
F: FnOnce() -> Ret,
|
||||
{
|
||||
let _guard = NetNSGuard::new(self.name.clone());
|
||||
f()
|
||||
}
|
||||
|
||||
pub fn guard(&self) -> Box<NetNSGuard> {
|
||||
NetNSGuard::new(self.name.clone())
|
||||
}
|
||||
|
||||
pub fn name(&self) -> Option<String> {
|
||||
self.name.clone()
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,243 @@
|
||||
use std::{ops::Deref, sync::Arc};
|
||||
|
||||
use crate::rpc::peer::GetIpListResponse;
|
||||
use pnet::datalink::NetworkInterface;
|
||||
use tokio::{
|
||||
sync::{Mutex, RwLock},
|
||||
task::JoinSet,
|
||||
};
|
||||
|
||||
use super::netns::NetNS;
|
||||
|
||||
pub const CACHED_IP_LIST_TIMEOUT_SEC: u64 = 60;
|
||||
|
||||
struct InterfaceFilter {
|
||||
iface: NetworkInterface,
|
||||
}
|
||||
|
||||
#[cfg(target_os = "linux")]
|
||||
impl InterfaceFilter {
|
||||
async fn is_tun_tap_device(&self) -> bool {
|
||||
let path = format!("/sys/class/net/{}/tun_flags", self.iface.name);
|
||||
tokio::fs::metadata(&path).await.is_ok()
|
||||
}
|
||||
|
||||
async fn has_valid_ip(&self) -> bool {
|
||||
self.iface
|
||||
.ips
|
||||
.iter()
|
||||
.map(|ip| ip.ip())
|
||||
.any(|ip| !ip.is_loopback() && !ip.is_unspecified() && !ip.is_multicast())
|
||||
}
|
||||
|
||||
async fn filter_iface(&self) -> bool {
|
||||
tracing::trace!(
|
||||
"filter linux iface: {:?}, is_point_to_point: {}, is_loopback: {}, is_up: {}, is_lower_up: {}, is_tun: {}, has_valid_ip: {}",
|
||||
self.iface,
|
||||
self.iface.is_point_to_point(),
|
||||
self.iface.is_loopback(),
|
||||
self.iface.is_up(),
|
||||
self.iface.is_lower_up(),
|
||||
self.is_tun_tap_device().await,
|
||||
self.has_valid_ip().await
|
||||
);
|
||||
|
||||
!self.iface.is_point_to_point()
|
||||
&& !self.iface.is_loopback()
|
||||
&& self.iface.is_up()
|
||||
&& self.iface.is_lower_up()
|
||||
&& !self.is_tun_tap_device().await
|
||||
&& self.has_valid_ip().await
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(target_os = "macos")]
|
||||
impl InterfaceFilter {
|
||||
async fn is_interface_physical(interface_name: &str) -> bool {
|
||||
let output = tokio::process::Command::new("networksetup")
|
||||
.args(&["-listallhardwareports"])
|
||||
.output()
|
||||
.await
|
||||
.expect("Failed to execute command");
|
||||
|
||||
let stdout = std::str::from_utf8(&output.stdout).expect("Invalid UTF-8");
|
||||
|
||||
let lines: Vec<&str> = stdout.lines().collect();
|
||||
|
||||
for i in 0..lines.len() {
|
||||
let line = lines[i];
|
||||
|
||||
if line.contains("Device:") && line.contains(interface_name) {
|
||||
let next_line = lines[i + 1];
|
||||
if next_line.contains("Virtual Interface") {
|
||||
return false;
|
||||
} else {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
false
|
||||
}
|
||||
|
||||
async fn filter_iface(&self) -> bool {
|
||||
!self.iface.is_point_to_point()
|
||||
&& !self.iface.is_loopback()
|
||||
&& self.iface.is_up()
|
||||
&& Self::is_interface_physical(&self.iface.name).await
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(target_os = "windows")]
|
||||
impl InterfaceFilter {
|
||||
async fn filter_iface(&self) -> bool {
|
||||
tracing::debug!(
|
||||
"iface_name: {:?}, p2p: {:?}, is_up: {:?}, iface: {:?}",
|
||||
self.iface.name,
|
||||
self.iface.is_point_to_point(),
|
||||
self.iface.is_up(),
|
||||
self.iface
|
||||
);
|
||||
!self.iface.is_point_to_point()
|
||||
&& !self.iface.is_loopback()
|
||||
&& self
|
||||
.iface
|
||||
.ips
|
||||
.iter()
|
||||
.map(|ip| ip.ip())
|
||||
.any(|ip| !ip.is_loopback() && !ip.is_unspecified() && !ip.is_multicast())
|
||||
&& self.iface.mac.map(|mac| !mac.is_zero()).unwrap_or(false)
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn local_ipv4() -> std::io::Result<std::net::Ipv4Addr> {
|
||||
let socket = tokio::net::UdpSocket::bind("0.0.0.0:0").await?;
|
||||
socket.connect("8.8.8.8:80").await?;
|
||||
let addr = socket.local_addr()?;
|
||||
match addr.ip() {
|
||||
std::net::IpAddr::V4(ip) => Ok(ip),
|
||||
std::net::IpAddr::V6(_) => Err(std::io::Error::new(
|
||||
std::io::ErrorKind::AddrNotAvailable,
|
||||
"no ipv4 address",
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn local_ipv6() -> std::io::Result<std::net::Ipv6Addr> {
|
||||
let socket = tokio::net::UdpSocket::bind("[::]:0").await?;
|
||||
socket
|
||||
.connect("[2001:4860:4860:0000:0000:0000:0000:8888]:80")
|
||||
.await?;
|
||||
let addr = socket.local_addr()?;
|
||||
match addr.ip() {
|
||||
std::net::IpAddr::V6(ip) => Ok(ip),
|
||||
std::net::IpAddr::V4(_) => Err(std::io::Error::new(
|
||||
std::io::ErrorKind::AddrNotAvailable,
|
||||
"no ipv4 address",
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
pub struct IPCollector {
|
||||
cached_ip_list: Arc<RwLock<GetIpListResponse>>,
|
||||
collect_ip_task: Mutex<JoinSet<()>>,
|
||||
net_ns: NetNS,
|
||||
}
|
||||
|
||||
impl IPCollector {
|
||||
pub fn new(net_ns: NetNS) -> Self {
|
||||
Self {
|
||||
cached_ip_list: Arc::new(RwLock::new(GetIpListResponse::new())),
|
||||
collect_ip_task: Mutex::new(JoinSet::new()),
|
||||
net_ns,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn collect_ip_addrs(&self) -> GetIpListResponse {
|
||||
let mut task = self.collect_ip_task.lock().await;
|
||||
if task.is_empty() {
|
||||
let cached_ip_list = self.cached_ip_list.clone();
|
||||
*cached_ip_list.write().await =
|
||||
Self::do_collect_ip_addrs(false, self.net_ns.clone()).await;
|
||||
let net_ns = self.net_ns.clone();
|
||||
task.spawn(async move {
|
||||
loop {
|
||||
let ip_addrs = Self::do_collect_ip_addrs(true, net_ns.clone()).await;
|
||||
*cached_ip_list.write().await = ip_addrs;
|
||||
tokio::time::sleep(std::time::Duration::from_secs(CACHED_IP_LIST_TIMEOUT_SEC))
|
||||
.await;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
return self.cached_ip_list.read().await.deref().clone();
|
||||
}
|
||||
|
||||
pub async fn collect_interfaces(net_ns: NetNS) -> Vec<NetworkInterface> {
|
||||
let _g = net_ns.guard();
|
||||
let ifaces = pnet::datalink::interfaces();
|
||||
let mut ret = vec![];
|
||||
for iface in ifaces {
|
||||
let f = InterfaceFilter {
|
||||
iface: iface.clone(),
|
||||
};
|
||||
|
||||
if !f.filter_iface().await {
|
||||
continue;
|
||||
}
|
||||
|
||||
ret.push(iface);
|
||||
}
|
||||
|
||||
ret
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(net_ns))]
|
||||
async fn do_collect_ip_addrs(with_public: bool, net_ns: NetNS) -> GetIpListResponse {
|
||||
let mut ret = crate::rpc::peer::GetIpListResponse::new();
|
||||
|
||||
if with_public {
|
||||
if let Some(v4_addr) =
|
||||
public_ip::addr_with(public_ip::http::ALL, public_ip::Version::V4).await
|
||||
{
|
||||
ret.public_ipv4 = v4_addr.to_string();
|
||||
}
|
||||
|
||||
if let Some(v6_addr) = public_ip::addr_v6().await {
|
||||
ret.public_ipv6 = v6_addr.to_string();
|
||||
}
|
||||
}
|
||||
|
||||
let ifaces = Self::collect_interfaces(net_ns.clone()).await;
|
||||
let _g = net_ns.guard();
|
||||
for iface in ifaces {
|
||||
for ip in iface.ips {
|
||||
let ip: std::net::IpAddr = ip.ip();
|
||||
if ip.is_loopback() || ip.is_multicast() {
|
||||
continue;
|
||||
}
|
||||
if ip.is_ipv4() {
|
||||
ret.interface_ipv4s.push(ip.to_string());
|
||||
} else if ip.is_ipv6() {
|
||||
ret.interface_ipv6s.push(ip.to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if let Ok(v4_addr) = local_ipv4().await {
|
||||
tracing::trace!("got local ipv4: {}", v4_addr);
|
||||
if !ret.interface_ipv4s.contains(&v4_addr.to_string()) {
|
||||
ret.interface_ipv4s.push(v4_addr.to_string());
|
||||
}
|
||||
}
|
||||
|
||||
if let Ok(v6_addr) = local_ipv6().await {
|
||||
tracing::trace!("got local ipv6: {}", v6_addr);
|
||||
if !ret.interface_ipv6s.contains(&v6_addr.to_string()) {
|
||||
ret.interface_ipv6s.push(v6_addr.to_string());
|
||||
}
|
||||
}
|
||||
|
||||
ret
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,72 @@
|
||||
use rkyv::{
|
||||
string::ArchivedString,
|
||||
validation::{validators::DefaultValidator, CheckTypeError},
|
||||
vec::ArchivedVec,
|
||||
Archive, CheckBytes, Serialize,
|
||||
};
|
||||
use tokio_util::bytes::{Bytes, BytesMut};
|
||||
|
||||
pub fn decode_from_bytes_checked<'a, T: Archive>(
|
||||
bytes: &'a [u8],
|
||||
) -> Result<&'a T::Archived, CheckTypeError<T::Archived, DefaultValidator<'a>>>
|
||||
where
|
||||
T::Archived: CheckBytes<DefaultValidator<'a>>,
|
||||
{
|
||||
rkyv::check_archived_root::<T>(bytes)
|
||||
}
|
||||
|
||||
pub fn decode_from_bytes<'a, T: Archive>(
|
||||
bytes: &'a [u8],
|
||||
) -> Result<&'a T::Archived, CheckTypeError<T::Archived, DefaultValidator<'a>>>
|
||||
where
|
||||
T::Archived: CheckBytes<DefaultValidator<'a>>,
|
||||
{
|
||||
// rkyv::check_archived_root::<T>(bytes)
|
||||
unsafe { Ok(rkyv::archived_root::<T>(bytes)) }
|
||||
}
|
||||
|
||||
// allow deseraial T to Bytes
|
||||
pub fn encode_to_bytes<T, const N: usize>(val: &T) -> Bytes
|
||||
where
|
||||
T: Serialize<rkyv::ser::serializers::AllocSerializer<N>>,
|
||||
{
|
||||
let ret = rkyv::to_bytes::<_, N>(val).unwrap();
|
||||
// let mut r = BytesMut::new();
|
||||
// r.extend_from_slice(&ret);
|
||||
// r.freeze()
|
||||
ret.into_boxed_slice().into()
|
||||
}
|
||||
|
||||
pub fn extract_bytes_from_archived_vec(raw_data: &Bytes, archived_data: &ArchivedVec<u8>) -> Bytes {
|
||||
let ptr_range = archived_data.as_ptr_range();
|
||||
let offset = ptr_range.start as usize - raw_data.as_ptr() as usize;
|
||||
let len = ptr_range.end as usize - ptr_range.start as usize;
|
||||
return raw_data.slice(offset..offset + len);
|
||||
}
|
||||
|
||||
pub fn extract_bytes_from_archived_string(
|
||||
raw_data: &Bytes,
|
||||
archived_data: &ArchivedString,
|
||||
) -> Bytes {
|
||||
let offset = archived_data.as_ptr() as usize - raw_data.as_ptr() as usize;
|
||||
let len = archived_data.len();
|
||||
if offset + len > raw_data.len() {
|
||||
return Bytes::new();
|
||||
}
|
||||
|
||||
return raw_data.slice(offset..offset + archived_data.len());
|
||||
}
|
||||
|
||||
pub fn extract_bytes_mut_from_archived_vec(
|
||||
raw_data: &mut BytesMut,
|
||||
archived_data: &ArchivedVec<u8>,
|
||||
) -> BytesMut {
|
||||
let ptr_range = archived_data.as_ptr_range();
|
||||
let offset = ptr_range.start as usize - raw_data.as_ptr() as usize;
|
||||
let len = ptr_range.end as usize - ptr_range.start as usize;
|
||||
raw_data.split_off(offset).split_to(len)
|
||||
}
|
||||
|
||||
pub fn vec_to_string(vec: Vec<u8>) -> String {
|
||||
unsafe { String::from_utf8_unchecked(vec) }
|
||||
}
|
||||
@@ -0,0 +1,550 @@
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use crate::rpc::{NatType, StunInfo};
|
||||
use anyhow::Context;
|
||||
use crossbeam::atomic::AtomicCell;
|
||||
use tokio::net::{lookup_host, UdpSocket};
|
||||
use tokio::sync::RwLock;
|
||||
use tokio::task::JoinSet;
|
||||
use tracing::Level;
|
||||
|
||||
use bytecodec::{DecodeExt, EncodeExt};
|
||||
use stun_codec::rfc5389::methods::BINDING;
|
||||
use stun_codec::rfc5780::attributes::ChangeRequest;
|
||||
use stun_codec::{Message, MessageClass, MessageDecoder, MessageEncoder};
|
||||
|
||||
use crate::common::error::Error;
|
||||
|
||||
use super::stun_codec_ext::*;
|
||||
|
||||
struct HostResolverIter {
|
||||
hostnames: Vec<String>,
|
||||
ips: Vec<SocketAddr>,
|
||||
}
|
||||
|
||||
impl HostResolverIter {
|
||||
fn new(hostnames: Vec<String>) -> Self {
|
||||
Self {
|
||||
hostnames,
|
||||
ips: vec![],
|
||||
}
|
||||
}
|
||||
|
||||
#[async_recursion::async_recursion]
|
||||
async fn next(&mut self) -> Option<SocketAddr> {
|
||||
if self.ips.is_empty() {
|
||||
if self.hostnames.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let host = self.hostnames.remove(0);
|
||||
match lookup_host(&host).await {
|
||||
Ok(ips) => {
|
||||
self.ips = ips.collect();
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!(?host, ?e, "lookup host for stun failed");
|
||||
return self.next().await;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
Some(self.ips.remove(0))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
struct BindRequestResponse {
|
||||
source_addr: SocketAddr,
|
||||
send_to_addr: SocketAddr,
|
||||
recv_from_addr: SocketAddr,
|
||||
mapped_socket_addr: Option<SocketAddr>,
|
||||
changed_socket_addr: Option<SocketAddr>,
|
||||
|
||||
ip_changed: bool,
|
||||
port_changed: bool,
|
||||
|
||||
real_ip_changed: bool,
|
||||
real_port_changed: bool,
|
||||
}
|
||||
|
||||
impl BindRequestResponse {
|
||||
pub fn get_mapped_addr_no_check(&self) -> &SocketAddr {
|
||||
self.mapped_socket_addr.as_ref().unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct Stun {
|
||||
stun_server: SocketAddr,
|
||||
req_repeat: u8,
|
||||
resp_timeout: Duration,
|
||||
}
|
||||
|
||||
impl Stun {
|
||||
pub fn new(stun_server: SocketAddr) -> Self {
|
||||
Self {
|
||||
stun_server,
|
||||
req_repeat: 5,
|
||||
resp_timeout: Duration::from_millis(3000),
|
||||
}
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self, buf))]
|
||||
async fn wait_stun_response<'a, const N: usize>(
|
||||
&self,
|
||||
buf: &'a mut [u8; N],
|
||||
udp: &UdpSocket,
|
||||
tids: &Vec<u128>,
|
||||
expected_ip_changed: bool,
|
||||
expected_port_changed: bool,
|
||||
stun_host: &SocketAddr,
|
||||
) -> Result<(Message<Attribute>, SocketAddr), Error> {
|
||||
let mut now = tokio::time::Instant::now();
|
||||
let deadline = now + self.resp_timeout;
|
||||
|
||||
while now < deadline {
|
||||
let mut udp_buf = [0u8; 1500];
|
||||
let (len, remote_addr) =
|
||||
tokio::time::timeout(deadline - now, udp.recv_from(udp_buf.as_mut_slice()))
|
||||
.await??;
|
||||
now = tokio::time::Instant::now();
|
||||
|
||||
if len < 20 {
|
||||
continue;
|
||||
}
|
||||
|
||||
// TODO:: we cannot borrow `buf` directly in udp recv_from, so we copy it here
|
||||
unsafe { std::ptr::copy(udp_buf.as_ptr(), buf.as_ptr() as *mut u8, len) };
|
||||
|
||||
let mut decoder = MessageDecoder::<Attribute>::new();
|
||||
let Ok(msg) = decoder
|
||||
.decode_from_bytes(&buf[..len])
|
||||
.with_context(|| format!("decode stun msg {:?}", buf))?
|
||||
else {
|
||||
continue;
|
||||
};
|
||||
|
||||
tracing::debug!(b = ?&udp_buf[..len], ?tids, ?remote_addr, ?stun_host, "recv stun response, msg: {:#?}", msg);
|
||||
|
||||
if msg.class() != MessageClass::SuccessResponse
|
||||
|| msg.method() != BINDING
|
||||
|| !tids.contains(&tid_to_u128(&msg.transaction_id()))
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
// some stun server use changed socket even we don't ask for.
|
||||
if expected_ip_changed && stun_host.ip() == remote_addr.ip() {
|
||||
continue;
|
||||
}
|
||||
|
||||
if expected_port_changed
|
||||
&& stun_host.ip() == remote_addr.ip()
|
||||
&& stun_host.port() == remote_addr.port()
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
return Ok((msg, remote_addr));
|
||||
}
|
||||
|
||||
Err(Error::Unknown)
|
||||
}
|
||||
|
||||
fn extrace_mapped_addr(msg: &Message<Attribute>) -> Option<SocketAddr> {
|
||||
let mut mapped_addr = None;
|
||||
for x in msg.attributes() {
|
||||
match x {
|
||||
Attribute::MappedAddress(addr) => {
|
||||
if mapped_addr.is_none() {
|
||||
let _ = mapped_addr.insert(addr.address());
|
||||
}
|
||||
}
|
||||
Attribute::XorMappedAddress(addr) => {
|
||||
if mapped_addr.is_none() {
|
||||
let _ = mapped_addr.insert(addr.address());
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
mapped_addr
|
||||
}
|
||||
|
||||
fn extract_changed_addr(msg: &Message<Attribute>) -> Option<SocketAddr> {
|
||||
let mut changed_addr = None;
|
||||
for x in msg.attributes() {
|
||||
match x {
|
||||
Attribute::OtherAddress(m) => {
|
||||
if changed_addr.is_none() {
|
||||
let _ = changed_addr.insert(m.address());
|
||||
}
|
||||
}
|
||||
Attribute::ChangedAddress(m) => {
|
||||
if changed_addr.is_none() {
|
||||
let _ = changed_addr.insert(m.address());
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
changed_addr
|
||||
}
|
||||
|
||||
#[tracing::instrument(ret, err, level = Level::DEBUG)]
|
||||
pub async fn bind_request(
|
||||
&self,
|
||||
source_port: u16,
|
||||
change_ip: bool,
|
||||
change_port: bool,
|
||||
) -> Result<BindRequestResponse, Error> {
|
||||
let stun_host = self.stun_server;
|
||||
let udp = UdpSocket::bind(format!("0.0.0.0:{}", source_port)).await?;
|
||||
|
||||
// repeat req in case of packet loss
|
||||
let mut tids = vec![];
|
||||
for _ in 0..self.req_repeat {
|
||||
let tid = rand::random::<u32>();
|
||||
let mut buf = [0u8; 28];
|
||||
// memset buf
|
||||
unsafe { std::ptr::write_bytes(buf.as_mut_ptr(), 0, buf.len()) };
|
||||
|
||||
let mut message =
|
||||
Message::<Attribute>::new(MessageClass::Request, BINDING, u128_to_tid(tid as u128));
|
||||
message.add_attribute(ChangeRequest::new(change_ip, change_port));
|
||||
|
||||
// Encodes the message
|
||||
let mut encoder = MessageEncoder::new();
|
||||
let msg = encoder
|
||||
.encode_into_bytes(message.clone())
|
||||
.with_context(|| "encode stun message")?;
|
||||
|
||||
tids.push(tid as u128);
|
||||
tracing::trace!(?message, ?msg, tid, "send stun request");
|
||||
udp.send_to(msg.as_slice().into(), &stun_host).await?;
|
||||
}
|
||||
|
||||
tracing::trace!("waiting stun response");
|
||||
let mut buf = [0; 1620];
|
||||
let (msg, recv_addr) = self
|
||||
.wait_stun_response(&mut buf, &udp, &tids, change_ip, change_port, &stun_host)
|
||||
.await?;
|
||||
|
||||
let changed_socket_addr = Self::extract_changed_addr(&msg);
|
||||
let real_ip_changed = stun_host.ip() != recv_addr.ip();
|
||||
let real_port_changed = stun_host.port() != recv_addr.port();
|
||||
|
||||
let resp = BindRequestResponse {
|
||||
source_addr: udp.local_addr()?,
|
||||
send_to_addr: stun_host,
|
||||
recv_from_addr: recv_addr,
|
||||
mapped_socket_addr: Self::extrace_mapped_addr(&msg),
|
||||
changed_socket_addr,
|
||||
ip_changed: change_ip,
|
||||
port_changed: change_port,
|
||||
|
||||
real_ip_changed,
|
||||
real_port_changed,
|
||||
};
|
||||
|
||||
tracing::debug!(
|
||||
?stun_host,
|
||||
?recv_addr,
|
||||
?changed_socket_addr,
|
||||
"finish stun bind request"
|
||||
);
|
||||
|
||||
Ok(resp)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct UdpNatTypeDetector {
|
||||
stun_servers: Vec<String>,
|
||||
}
|
||||
|
||||
impl UdpNatTypeDetector {
|
||||
pub fn new(stun_servers: Vec<String>) -> Self {
|
||||
Self { stun_servers }
|
||||
}
|
||||
|
||||
pub async fn get_udp_nat_type(&self, mut source_port: u16) -> NatType {
|
||||
// Like classic STUN (rfc3489). Detect NAT behavior for UDP.
|
||||
// Modified from rfc3489. Requires at least two STUN servers.
|
||||
let mut ret_test1_1 = None;
|
||||
let mut ret_test1_2 = None;
|
||||
let mut ret_test2 = None;
|
||||
let mut ret_test3 = None;
|
||||
|
||||
if source_port == 0 {
|
||||
let udp = UdpSocket::bind("0.0.0.0:0").await.unwrap();
|
||||
source_port = udp.local_addr().unwrap().port();
|
||||
}
|
||||
|
||||
let mut succ = false;
|
||||
let mut ips = HostResolverIter::new(self.stun_servers.clone());
|
||||
while let Some(server_ip) = ips.next().await {
|
||||
let stun = Stun::new(server_ip.clone());
|
||||
let ret = stun.bind_request(source_port, false, false).await;
|
||||
if ret.is_err() {
|
||||
// Try another STUN server
|
||||
continue;
|
||||
}
|
||||
if ret_test1_1.is_none() {
|
||||
ret_test1_1 = ret.ok();
|
||||
continue;
|
||||
}
|
||||
ret_test1_2 = ret.ok();
|
||||
let ret = stun.bind_request(source_port, true, true).await;
|
||||
if let Ok(resp) = ret {
|
||||
if !resp.real_ip_changed || !resp.real_port_changed {
|
||||
tracing::debug!(
|
||||
?server_ip,
|
||||
?ret,
|
||||
"stun bind request return with unchanged ip and port"
|
||||
);
|
||||
// Try another STUN server
|
||||
continue;
|
||||
}
|
||||
}
|
||||
ret_test2 = ret.ok();
|
||||
ret_test3 = stun.bind_request(source_port, false, true).await.ok();
|
||||
tracing::debug!(?ret_test3, "stun bind request with changed port");
|
||||
succ = true;
|
||||
break;
|
||||
}
|
||||
|
||||
if !succ {
|
||||
return NatType::Unknown;
|
||||
}
|
||||
|
||||
tracing::debug!(
|
||||
?ret_test1_1,
|
||||
?ret_test1_2,
|
||||
?ret_test2,
|
||||
?ret_test3,
|
||||
"finish stun test, try to detect nat type"
|
||||
);
|
||||
|
||||
let ret_test1_1 = ret_test1_1.unwrap();
|
||||
let ret_test1_2 = ret_test1_2.unwrap();
|
||||
|
||||
if ret_test1_1.mapped_socket_addr != ret_test1_2.mapped_socket_addr {
|
||||
return NatType::Symmetric;
|
||||
}
|
||||
|
||||
if ret_test1_1.mapped_socket_addr.is_some()
|
||||
&& ret_test1_1.source_addr == ret_test1_1.mapped_socket_addr.unwrap()
|
||||
{
|
||||
if !ret_test2.is_none() {
|
||||
return NatType::OpenInternet;
|
||||
} else {
|
||||
return NatType::SymUdpFirewall;
|
||||
}
|
||||
} else {
|
||||
if let Some(ret_test2) = ret_test2 {
|
||||
if source_port == ret_test2.get_mapped_addr_no_check().port()
|
||||
&& source_port == ret_test1_1.get_mapped_addr_no_check().port()
|
||||
{
|
||||
return NatType::NoPat;
|
||||
} else {
|
||||
return NatType::FullCone;
|
||||
}
|
||||
} else {
|
||||
if !ret_test3.is_none() {
|
||||
return NatType::Restricted;
|
||||
} else {
|
||||
return NatType::PortRestricted;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
#[auto_impl::auto_impl(&, Arc, Box)]
|
||||
pub trait StunInfoCollectorTrait: Send + Sync {
|
||||
fn get_stun_info(&self) -> StunInfo;
|
||||
async fn get_udp_port_mapping(&self, local_port: u16) -> Result<SocketAddr, Error>;
|
||||
}
|
||||
|
||||
pub struct StunInfoCollector {
|
||||
stun_servers: Arc<RwLock<Vec<String>>>,
|
||||
udp_nat_type: Arc<AtomicCell<(NatType, std::time::Instant)>>,
|
||||
redetect_notify: Arc<tokio::sync::Notify>,
|
||||
tasks: JoinSet<()>,
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl StunInfoCollectorTrait for StunInfoCollector {
|
||||
fn get_stun_info(&self) -> StunInfo {
|
||||
let (typ, time) = self.udp_nat_type.load();
|
||||
StunInfo {
|
||||
udp_nat_type: typ as i32,
|
||||
tcp_nat_type: 0,
|
||||
last_update_time: time.elapsed().as_secs() as i64,
|
||||
}
|
||||
}
|
||||
|
||||
async fn get_udp_port_mapping(&self, local_port: u16) -> Result<SocketAddr, Error> {
|
||||
let stun_servers = self.stun_servers.read().await.clone();
|
||||
let mut ips = HostResolverIter::new(stun_servers.clone());
|
||||
while let Some(server) = ips.next().await {
|
||||
let stun = Stun::new(server.clone());
|
||||
let Ok(ret) = stun.bind_request(local_port, false, false).await else {
|
||||
tracing::warn!(?server, "stun bind request failed");
|
||||
continue;
|
||||
};
|
||||
if let Some(mapped_addr) = ret.mapped_socket_addr {
|
||||
return Ok(mapped_addr);
|
||||
}
|
||||
}
|
||||
Err(Error::NotFound)
|
||||
}
|
||||
}
|
||||
|
||||
impl StunInfoCollector {
|
||||
pub fn new(stun_servers: Vec<String>) -> Self {
|
||||
let mut ret = Self {
|
||||
stun_servers: Arc::new(RwLock::new(stun_servers)),
|
||||
udp_nat_type: Arc::new(AtomicCell::new((
|
||||
NatType::Unknown,
|
||||
std::time::Instant::now(),
|
||||
))),
|
||||
redetect_notify: Arc::new(tokio::sync::Notify::new()),
|
||||
tasks: JoinSet::new(),
|
||||
};
|
||||
|
||||
ret.start_stun_routine();
|
||||
|
||||
ret
|
||||
}
|
||||
|
||||
pub fn new_with_default_servers() -> Self {
|
||||
Self::new(Self::get_default_servers())
|
||||
}
|
||||
|
||||
pub fn get_default_servers() -> Vec<String> {
|
||||
// NOTICE: we may need to choose stun stun server based on geo location
|
||||
// stun server cross nation may return a external ip address with high latency and loss rate
|
||||
vec![
|
||||
"stun.miwifi.com:3478".to_string(),
|
||||
"stun.qq.com:3478".to_string(),
|
||||
// "stun.chat.bilibili.com:3478".to_string(), // bilibili's stun server doesn't repond to change_ip and change_port
|
||||
"fwa.lifesizecloud.com:3478".to_string(),
|
||||
"stun.isp.net.au:3478".to_string(),
|
||||
"stun.nextcloud.com:3478".to_string(),
|
||||
"stun.freeswitch.org:3478".to_string(),
|
||||
"stun.voip.blackberry.com:3478".to_string(),
|
||||
"stunserver.stunprotocol.org:3478".to_string(),
|
||||
"stun.sipnet.com:3478".to_string(),
|
||||
"stun.radiojar.com:3478".to_string(),
|
||||
"stun.sonetel.com:3478".to_string(),
|
||||
"stun.voipgate.com:3478".to_string(),
|
||||
"stun.counterpath.com:3478".to_string(),
|
||||
"180.235.108.91:3478".to_string(),
|
||||
"193.22.2.248:3478".to_string(),
|
||||
]
|
||||
}
|
||||
|
||||
fn start_stun_routine(&mut self) {
|
||||
let stun_servers = self.stun_servers.clone();
|
||||
let udp_nat_type = self.udp_nat_type.clone();
|
||||
let redetect_notify = self.redetect_notify.clone();
|
||||
self.tasks.spawn(async move {
|
||||
loop {
|
||||
let detector = UdpNatTypeDetector::new(stun_servers.read().await.clone());
|
||||
let old_nat_type = udp_nat_type.load().0;
|
||||
let mut ret = NatType::Unknown;
|
||||
for _ in 1..5 {
|
||||
// if nat type degrade, sleep and retry. so result can be relatively stable.
|
||||
ret = detector.get_udp_nat_type(0).await;
|
||||
if ret == NatType::Unknown || ret <= old_nat_type {
|
||||
break;
|
||||
}
|
||||
tokio::time::sleep(Duration::from_secs(5)).await;
|
||||
}
|
||||
udp_nat_type.store((ret, std::time::Instant::now()));
|
||||
|
||||
let sleep_sec = match ret {
|
||||
NatType::Unknown => 15,
|
||||
_ => 60,
|
||||
};
|
||||
tracing::info!(?ret, ?sleep_sec, "finish udp nat type detect");
|
||||
|
||||
tokio::select! {
|
||||
_ = redetect_notify.notified() => {}
|
||||
_ = tokio::time::sleep(Duration::from_secs(sleep_sec)) => {}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
pub fn update_stun_info(&self) {
|
||||
self.redetect_notify.notify_one();
|
||||
}
|
||||
|
||||
pub async fn set_stun_servers(&self, stun_servers: Vec<String>) {
|
||||
*self.stun_servers.write().await = stun_servers;
|
||||
self.update_stun_info();
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
pub fn enable_log() {
|
||||
let filter = tracing_subscriber::EnvFilter::builder()
|
||||
.with_default_directive(tracing::level_filters::LevelFilter::TRACE.into())
|
||||
.from_env()
|
||||
.unwrap()
|
||||
.add_directive("tarpc=error".parse().unwrap());
|
||||
tracing_subscriber::fmt::fmt()
|
||||
.pretty()
|
||||
.with_env_filter(filter)
|
||||
.init();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_stun_bind_request() {
|
||||
// miwifi / qq seems not correctly responde to change_ip and change_port, they always try to change the src ip and port.
|
||||
let mut ips = HostResolverIter::new(vec!["stun1.l.google.com:19302".to_string()]);
|
||||
let stun = Stun::new(ips.next().await.unwrap());
|
||||
// let stun = Stun::new("180.235.108.91:3478".to_string());
|
||||
// let stun = Stun::new("193.22.2.248:3478".to_string());
|
||||
// let stun = Stun::new("stun.chat.bilibili.com:3478".to_string());
|
||||
// let stun = Stun::new("stun.miwifi.com:3478".to_string());
|
||||
|
||||
// github actions are port restricted nat, so we only test last one.
|
||||
|
||||
// let rs = stun.bind_request(12345, true, true).await.unwrap();
|
||||
// assert!(rs.ip_changed);
|
||||
// assert!(rs.port_changed);
|
||||
|
||||
// let rs = stun.bind_request(12345, true, false).await.unwrap();
|
||||
// assert!(rs.ip_changed);
|
||||
// assert!(!rs.port_changed);
|
||||
|
||||
// let rs = stun.bind_request(12345, false, true).await.unwrap();
|
||||
// assert!(!rs.ip_changed);
|
||||
// assert!(rs.port_changed);
|
||||
|
||||
let rs = stun.bind_request(12345, false, false).await.unwrap();
|
||||
assert!(!rs.ip_changed);
|
||||
assert!(!rs.port_changed);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_udp_nat_type_detect() {
|
||||
let detector = UdpNatTypeDetector::new(vec![
|
||||
"stun.counterpath.com:3478".to_string(),
|
||||
"180.235.108.91:3478".to_string(),
|
||||
]);
|
||||
let ret = detector.get_udp_nat_type(0).await;
|
||||
|
||||
assert_ne!(ret, NatType::Unknown);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,229 @@
|
||||
use std::net::SocketAddr;
|
||||
|
||||
use stun_codec::net::{socket_addr_xor, SocketAddrDecoder, SocketAddrEncoder};
|
||||
|
||||
use stun_codec::rfc5389::attributes::{
|
||||
MappedAddress, Software, XorMappedAddress, XorMappedAddress2,
|
||||
};
|
||||
use stun_codec::rfc5780::attributes::{ChangeRequest, OtherAddress, ResponseOrigin};
|
||||
use stun_codec::{define_attribute_enums, AttributeType, Message, TransactionId};
|
||||
|
||||
use bytecodec::{ByteCount, Decode, Encode, Eos, Result, SizedEncode, TryTaggedDecode};
|
||||
|
||||
use stun_codec::macros::track;
|
||||
|
||||
macro_rules! impl_decode {
|
||||
($decoder:ty, $item:ident, $and_then:expr) => {
|
||||
impl Decode for $decoder {
|
||||
type Item = $item;
|
||||
|
||||
fn decode(&mut self, buf: &[u8], eos: Eos) -> Result<usize> {
|
||||
track!(self.0.decode(buf, eos))
|
||||
}
|
||||
|
||||
fn finish_decoding(&mut self) -> Result<Self::Item> {
|
||||
track!(self.0.finish_decoding()).and_then($and_then)
|
||||
}
|
||||
|
||||
fn requiring_bytes(&self) -> ByteCount {
|
||||
self.0.requiring_bytes()
|
||||
}
|
||||
|
||||
fn is_idle(&self) -> bool {
|
||||
self.0.is_idle()
|
||||
}
|
||||
}
|
||||
impl TryTaggedDecode for $decoder {
|
||||
type Tag = AttributeType;
|
||||
|
||||
fn try_start_decoding(&mut self, attr_type: Self::Tag) -> Result<bool> {
|
||||
Ok(attr_type.as_u16() == $item::CODEPOINT)
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
macro_rules! impl_encode {
|
||||
($encoder:ty, $item:ty, $map_from:expr) => {
|
||||
impl Encode for $encoder {
|
||||
type Item = $item;
|
||||
|
||||
fn encode(&mut self, buf: &mut [u8], eos: Eos) -> Result<usize> {
|
||||
track!(self.0.encode(buf, eos))
|
||||
}
|
||||
|
||||
#[allow(clippy::redundant_closure_call)]
|
||||
fn start_encoding(&mut self, item: Self::Item) -> Result<()> {
|
||||
track!(self.0.start_encoding($map_from(item)))
|
||||
}
|
||||
|
||||
fn requiring_bytes(&self) -> ByteCount {
|
||||
self.0.requiring_bytes()
|
||||
}
|
||||
|
||||
fn is_idle(&self) -> bool {
|
||||
self.0.is_idle()
|
||||
}
|
||||
}
|
||||
impl SizedEncode for $encoder {
|
||||
fn exact_requiring_bytes(&self) -> u64 {
|
||||
self.0.exact_requiring_bytes()
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
||||
pub struct ChangedAddress(SocketAddr);
|
||||
impl ChangedAddress {
|
||||
/// The codepoint of the type of the attribute.
|
||||
pub const CODEPOINT: u16 = 0x0005;
|
||||
|
||||
pub fn new(addr: SocketAddr) -> Self {
|
||||
ChangedAddress(addr)
|
||||
}
|
||||
|
||||
/// Returns the address of this instance.
|
||||
pub fn address(&self) -> SocketAddr {
|
||||
self.0
|
||||
}
|
||||
}
|
||||
impl stun_codec::Attribute for ChangedAddress {
|
||||
type Decoder = ChangedAddressDecoder;
|
||||
type Encoder = ChangedAddressEncoder;
|
||||
|
||||
fn get_type(&self) -> AttributeType {
|
||||
AttributeType::new(Self::CODEPOINT)
|
||||
}
|
||||
|
||||
fn before_encode<A: stun_codec::Attribute>(
|
||||
&mut self,
|
||||
message: &Message<A>,
|
||||
) -> bytecodec::Result<()> {
|
||||
self.0 = socket_addr_xor(self.0, message.transaction_id());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn after_decode<A: stun_codec::Attribute>(
|
||||
&mut self,
|
||||
message: &Message<A>,
|
||||
) -> bytecodec::Result<()> {
|
||||
self.0 = socket_addr_xor(self.0, message.transaction_id());
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
pub struct ChangedAddressDecoder(SocketAddrDecoder);
|
||||
impl ChangedAddressDecoder {
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
}
|
||||
impl_decode!(ChangedAddressDecoder, ChangedAddress, |item| Ok(
|
||||
ChangedAddress(item)
|
||||
));
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
pub struct ChangedAddressEncoder(SocketAddrEncoder);
|
||||
impl ChangedAddressEncoder {
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
}
|
||||
impl_encode!(ChangedAddressEncoder, ChangedAddress, |item: Self::Item| {
|
||||
item.0
|
||||
});
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
||||
pub struct SourceAddress(SocketAddr);
|
||||
impl SourceAddress {
|
||||
/// The codepoint of the type of the attribute.
|
||||
pub const CODEPOINT: u16 = 0x0004;
|
||||
|
||||
pub fn new(addr: SocketAddr) -> Self {
|
||||
SourceAddress(addr)
|
||||
}
|
||||
|
||||
/// Returns the address of this instance.
|
||||
pub fn address(&self) -> SocketAddr {
|
||||
self.0
|
||||
}
|
||||
}
|
||||
impl stun_codec::Attribute for SourceAddress {
|
||||
type Decoder = SourceAddressDecoder;
|
||||
type Encoder = SourceAddressEncoder;
|
||||
|
||||
fn get_type(&self) -> AttributeType {
|
||||
AttributeType::new(Self::CODEPOINT)
|
||||
}
|
||||
|
||||
fn before_encode<A: stun_codec::Attribute>(
|
||||
&mut self,
|
||||
message: &Message<A>,
|
||||
) -> bytecodec::Result<()> {
|
||||
self.0 = socket_addr_xor(self.0, message.transaction_id());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn after_decode<A: stun_codec::Attribute>(
|
||||
&mut self,
|
||||
message: &Message<A>,
|
||||
) -> bytecodec::Result<()> {
|
||||
self.0 = socket_addr_xor(self.0, message.transaction_id());
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
pub struct SourceAddressDecoder(SocketAddrDecoder);
|
||||
impl SourceAddressDecoder {
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
}
|
||||
impl_decode!(SourceAddressDecoder, SourceAddress, |item| Ok(
|
||||
SourceAddress(item)
|
||||
));
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
pub struct SourceAddressEncoder(SocketAddrEncoder);
|
||||
impl SourceAddressEncoder {
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
}
|
||||
impl_encode!(SourceAddressEncoder, SourceAddress, |item: Self::Item| {
|
||||
item.0
|
||||
});
|
||||
|
||||
pub fn tid_to_u128(tid: &TransactionId) -> u128 {
|
||||
let mut tid_buf = [0u8; 16];
|
||||
// copy bytes from msg_tid to tid_buf
|
||||
tid_buf[..tid.as_bytes().len()].copy_from_slice(tid.as_bytes());
|
||||
u128::from_le_bytes(tid_buf)
|
||||
}
|
||||
|
||||
pub fn u128_to_tid(tid: u128) -> TransactionId {
|
||||
let tid_buf = tid.to_le_bytes();
|
||||
let mut tid_arr = [0u8; 12];
|
||||
tid_arr.copy_from_slice(&tid_buf[..12]);
|
||||
TransactionId::new(tid_arr)
|
||||
}
|
||||
|
||||
define_attribute_enums!(
|
||||
Attribute,
|
||||
AttributeDecoder,
|
||||
AttributeEncoder,
|
||||
[
|
||||
Software,
|
||||
MappedAddress,
|
||||
XorMappedAddress,
|
||||
XorMappedAddress2,
|
||||
OtherAddress,
|
||||
ChangeRequest,
|
||||
ChangedAddress,
|
||||
SourceAddress,
|
||||
ResponseOrigin
|
||||
]
|
||||
);
|
||||
@@ -0,0 +1,411 @@
|
||||
// try connect peers directly, with either its public ip or lan ip
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::{
|
||||
common::{error::Error, global_ctx::ArcGlobalCtx, PeerId},
|
||||
peers::{peer_manager::PeerManager, peer_rpc::PeerRpcManager},
|
||||
};
|
||||
|
||||
use crate::rpc::{peer::GetIpListResponse, PeerConnInfo};
|
||||
use tokio::{task::JoinSet, time::timeout};
|
||||
use tracing::Instrument;
|
||||
|
||||
use super::create_connector_by_url;
|
||||
|
||||
pub const DIRECT_CONNECTOR_SERVICE_ID: u32 = 1;
|
||||
pub const DIRECT_CONNECTOR_BLACKLIST_TIMEOUT_SEC: u64 = 300;
|
||||
|
||||
#[tarpc::service]
|
||||
pub trait DirectConnectorRpc {
|
||||
async fn get_ip_list() -> GetIpListResponse;
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
pub trait PeerManagerForDirectConnector {
|
||||
async fn list_peers(&self) -> Vec<PeerId>;
|
||||
async fn list_peer_conns(&self, peer_id: PeerId) -> Option<Vec<PeerConnInfo>>;
|
||||
fn get_peer_rpc_mgr(&self) -> Arc<PeerRpcManager>;
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl PeerManagerForDirectConnector for PeerManager {
|
||||
async fn list_peers(&self) -> Vec<PeerId> {
|
||||
let mut ret = vec![];
|
||||
|
||||
let routes = self.list_routes().await;
|
||||
for r in routes.iter() {
|
||||
ret.push(r.peer_id);
|
||||
}
|
||||
|
||||
ret
|
||||
}
|
||||
|
||||
async fn list_peer_conns(&self, peer_id: PeerId) -> Option<Vec<PeerConnInfo>> {
|
||||
self.get_peer_map().list_peer_conns(peer_id).await
|
||||
}
|
||||
|
||||
fn get_peer_rpc_mgr(&self) -> Arc<PeerRpcManager> {
|
||||
self.get_peer_rpc_mgr()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct DirectConnectorManagerRpcServer {
|
||||
// TODO: this only cache for one src peer, should make it global
|
||||
global_ctx: ArcGlobalCtx,
|
||||
}
|
||||
|
||||
#[tarpc::server]
|
||||
impl DirectConnectorRpc for DirectConnectorManagerRpcServer {
|
||||
async fn get_ip_list(self, _: tarpc::context::Context) -> GetIpListResponse {
|
||||
let mut ret = self.global_ctx.get_ip_collector().collect_ip_addrs().await;
|
||||
ret.listeners = self.global_ctx.get_running_listeners();
|
||||
ret
|
||||
}
|
||||
}
|
||||
|
||||
impl DirectConnectorManagerRpcServer {
|
||||
pub fn new(global_ctx: ArcGlobalCtx) -> Self {
|
||||
Self { global_ctx }
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Hash, Eq, PartialEq, Clone)]
|
||||
struct DstBlackListItem(PeerId, String);
|
||||
|
||||
#[derive(Hash, Eq, PartialEq, Clone)]
|
||||
struct DstSchemeBlackListItem(PeerId, String);
|
||||
|
||||
struct DirectConnectorManagerData {
|
||||
global_ctx: ArcGlobalCtx,
|
||||
peer_manager: Arc<PeerManager>,
|
||||
dst_blacklist: timedmap::TimedMap<DstBlackListItem, ()>,
|
||||
dst_sceme_blacklist: timedmap::TimedMap<DstSchemeBlackListItem, ()>,
|
||||
}
|
||||
|
||||
impl DirectConnectorManagerData {
|
||||
pub fn new(global_ctx: ArcGlobalCtx, peer_manager: Arc<PeerManager>) -> Self {
|
||||
Self {
|
||||
global_ctx,
|
||||
peer_manager,
|
||||
dst_blacklist: timedmap::TimedMap::new(),
|
||||
dst_sceme_blacklist: timedmap::TimedMap::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for DirectConnectorManagerData {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("DirectConnectorManagerData")
|
||||
.field("peer_manager", &self.peer_manager)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
pub struct DirectConnectorManager {
|
||||
global_ctx: ArcGlobalCtx,
|
||||
data: Arc<DirectConnectorManagerData>,
|
||||
|
||||
tasks: JoinSet<()>,
|
||||
}
|
||||
|
||||
impl DirectConnectorManager {
|
||||
pub fn new(global_ctx: ArcGlobalCtx, peer_manager: Arc<PeerManager>) -> Self {
|
||||
Self {
|
||||
global_ctx: global_ctx.clone(),
|
||||
data: Arc::new(DirectConnectorManagerData::new(global_ctx, peer_manager)),
|
||||
tasks: JoinSet::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn run(&mut self) {
|
||||
self.run_as_server();
|
||||
self.run_as_client();
|
||||
}
|
||||
|
||||
pub fn run_as_server(&mut self) {
|
||||
self.data.peer_manager.get_peer_rpc_mgr().run_service(
|
||||
DIRECT_CONNECTOR_SERVICE_ID,
|
||||
DirectConnectorManagerRpcServer::new(self.global_ctx.clone()).serve(),
|
||||
);
|
||||
}
|
||||
|
||||
pub fn run_as_client(&mut self) {
|
||||
let data = self.data.clone();
|
||||
let my_peer_id = self.data.peer_manager.my_peer_id();
|
||||
self.tasks.spawn(
|
||||
async move {
|
||||
loop {
|
||||
let peers = data.peer_manager.list_peers().await;
|
||||
let mut tasks = JoinSet::new();
|
||||
for peer_id in peers {
|
||||
if peer_id == my_peer_id {
|
||||
continue;
|
||||
}
|
||||
tasks.spawn(Self::do_try_direct_connect(data.clone(), peer_id));
|
||||
}
|
||||
|
||||
while let Some(task_ret) = tasks.join_next().await {
|
||||
tracing::trace!(?task_ret, "direct connect task ret");
|
||||
}
|
||||
tokio::time::sleep(std::time::Duration::from_secs(5)).await;
|
||||
}
|
||||
}
|
||||
.instrument(
|
||||
tracing::info_span!("direct_connector_client", my_id = ?self.global_ctx.id),
|
||||
),
|
||||
);
|
||||
}
|
||||
|
||||
async fn do_try_connect_to_ip(
|
||||
data: Arc<DirectConnectorManagerData>,
|
||||
dst_peer_id: PeerId,
|
||||
addr: String,
|
||||
) -> Result<(), Error> {
|
||||
data.dst_blacklist.cleanup();
|
||||
if data
|
||||
.dst_blacklist
|
||||
.contains(&DstBlackListItem(dst_peer_id.clone(), addr.clone()))
|
||||
{
|
||||
tracing::trace!("try_connect_to_ip failed, addr in blacklist: {}", addr);
|
||||
return Err(Error::UrlInBlacklist);
|
||||
}
|
||||
|
||||
let connector = create_connector_by_url(&addr, &data.global_ctx).await?;
|
||||
let (peer_id, conn_id) = timeout(
|
||||
std::time::Duration::from_secs(5),
|
||||
data.peer_manager.try_connect(connector),
|
||||
)
|
||||
.await??;
|
||||
|
||||
// let (peer_id, conn_id) = data.peer_manager.try_connect(connector).await?;
|
||||
|
||||
if peer_id != dst_peer_id {
|
||||
tracing::info!(
|
||||
"connect to ip succ: {}, but peer id mismatch, expect: {}, actual: {}",
|
||||
addr,
|
||||
dst_peer_id,
|
||||
peer_id
|
||||
);
|
||||
data.peer_manager
|
||||
.get_peer_map()
|
||||
.close_peer_conn(peer_id, &conn_id)
|
||||
.await?;
|
||||
return Err(Error::InvalidUrl(addr));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tracing::instrument]
|
||||
async fn try_connect_to_ip(
|
||||
data: Arc<DirectConnectorManagerData>,
|
||||
dst_peer_id: PeerId,
|
||||
addr: String,
|
||||
) -> Result<(), Error> {
|
||||
let ret = Self::do_try_connect_to_ip(data.clone(), dst_peer_id, addr.clone()).await;
|
||||
if let Err(e) = ret {
|
||||
if !matches!(e, Error::UrlInBlacklist) {
|
||||
tracing::info!(
|
||||
"try_connect_to_ip failed: {:?}, peer_id: {}",
|
||||
e,
|
||||
dst_peer_id
|
||||
);
|
||||
data.dst_blacklist.insert(
|
||||
DstBlackListItem(dst_peer_id.clone(), addr.clone()),
|
||||
(),
|
||||
std::time::Duration::from_secs(DIRECT_CONNECTOR_BLACKLIST_TIMEOUT_SEC),
|
||||
);
|
||||
}
|
||||
return Err(e);
|
||||
} else {
|
||||
log::info!("try_connect_to_ip success, peer_id: {}", dst_peer_id);
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
|
||||
#[tracing::instrument]
|
||||
async fn do_try_direct_connect_internal(
|
||||
data: Arc<DirectConnectorManagerData>,
|
||||
dst_peer_id: PeerId,
|
||||
ip_list: GetIpListResponse,
|
||||
) -> Result<(), Error> {
|
||||
let available_listeners = ip_list
|
||||
.listeners
|
||||
.iter()
|
||||
.filter_map(|l| if l.scheme() != "ring" { Some(l) } else { None })
|
||||
.filter(|l| l.port().is_some())
|
||||
.filter(|l| {
|
||||
!data.dst_sceme_blacklist.contains(&DstSchemeBlackListItem(
|
||||
dst_peer_id.clone(),
|
||||
l.scheme().to_string(),
|
||||
))
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let mut listener = available_listeners.get(0).ok_or(anyhow::anyhow!(
|
||||
"peer {} have no valid listener",
|
||||
dst_peer_id
|
||||
))?;
|
||||
|
||||
// if have default listener, use it first
|
||||
listener = available_listeners
|
||||
.iter()
|
||||
.find(|l| l.scheme() == data.global_ctx.get_flags().default_protocol)
|
||||
.unwrap_or(listener);
|
||||
|
||||
let mut tasks = JoinSet::new();
|
||||
ip_list.interface_ipv4s.iter().for_each(|ip| {
|
||||
let addr = format!(
|
||||
"{}://{}:{}",
|
||||
listener.scheme(),
|
||||
ip,
|
||||
listener.port().unwrap_or(11010)
|
||||
);
|
||||
tasks.spawn(Self::try_connect_to_ip(
|
||||
data.clone(),
|
||||
dst_peer_id.clone(),
|
||||
addr,
|
||||
));
|
||||
});
|
||||
|
||||
let addr = format!(
|
||||
"{}://{}:{}",
|
||||
listener.scheme(),
|
||||
ip_list.public_ipv4.clone(),
|
||||
listener.port().unwrap_or(11010)
|
||||
);
|
||||
tasks.spawn(Self::try_connect_to_ip(
|
||||
data.clone(),
|
||||
dst_peer_id.clone(),
|
||||
addr,
|
||||
));
|
||||
|
||||
let mut has_succ = false;
|
||||
while let Some(ret) = tasks.join_next().await {
|
||||
if let Err(e) = ret {
|
||||
log::error!("join direct connect task failed: {:?}", e);
|
||||
} else if let Ok(Ok(_)) = ret {
|
||||
has_succ = true;
|
||||
}
|
||||
}
|
||||
|
||||
if !has_succ {
|
||||
data.dst_sceme_blacklist.insert(
|
||||
DstSchemeBlackListItem(dst_peer_id.clone(), listener.scheme().to_string()),
|
||||
(),
|
||||
std::time::Duration::from_secs(DIRECT_CONNECTOR_BLACKLIST_TIMEOUT_SEC),
|
||||
);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tracing::instrument]
|
||||
async fn do_try_direct_connect(
|
||||
data: Arc<DirectConnectorManagerData>,
|
||||
dst_peer_id: PeerId,
|
||||
) -> Result<(), Error> {
|
||||
let peer_manager = data.peer_manager.clone();
|
||||
// check if we have direct connection with dst_peer_id
|
||||
if let Some(c) = peer_manager.list_peer_conns(dst_peer_id).await {
|
||||
// currently if we have any type of direct connection (udp or tcp), we will not try to connect
|
||||
if !c.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
|
||||
log::trace!("try direct connect to peer: {}", dst_peer_id);
|
||||
|
||||
let ip_list = peer_manager
|
||||
.get_peer_rpc_mgr()
|
||||
.do_client_rpc_scoped(1, dst_peer_id, |c| async {
|
||||
let client =
|
||||
DirectConnectorRpcClient::new(tarpc::client::Config::default(), c).spawn();
|
||||
let ip_list = client.get_ip_list(tarpc::context::current()).await;
|
||||
tracing::info!(ip_list = ?ip_list, dst_peer_id = ?dst_peer_id, "got ip list");
|
||||
ip_list
|
||||
})
|
||||
.await?;
|
||||
|
||||
Self::do_try_direct_connect_internal(data, dst_peer_id, ip_list).await
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::{
|
||||
connector::direct::{
|
||||
DirectConnectorManager, DirectConnectorManagerData, DstBlackListItem,
|
||||
DstSchemeBlackListItem,
|
||||
},
|
||||
instance::listeners::ListenerManager,
|
||||
peers::tests::{
|
||||
connect_peer_manager, create_mock_peer_manager, wait_route_appear,
|
||||
wait_route_appear_with_cost,
|
||||
},
|
||||
rpc::peer::GetIpListResponse,
|
||||
};
|
||||
|
||||
#[rstest::rstest]
|
||||
#[tokio::test]
|
||||
async fn direct_connector_basic_test(#[values("tcp", "udp", "wg")] proto: &str) {
|
||||
let p_a = create_mock_peer_manager().await;
|
||||
let p_b = create_mock_peer_manager().await;
|
||||
let p_c = create_mock_peer_manager().await;
|
||||
connect_peer_manager(p_a.clone(), p_b.clone()).await;
|
||||
connect_peer_manager(p_b.clone(), p_c.clone()).await;
|
||||
|
||||
wait_route_appear(p_a.clone(), p_c.clone()).await.unwrap();
|
||||
|
||||
let mut dm_a = DirectConnectorManager::new(p_a.get_global_ctx(), p_a.clone());
|
||||
let mut dm_c = DirectConnectorManager::new(p_c.get_global_ctx(), p_c.clone());
|
||||
|
||||
dm_a.run_as_client();
|
||||
dm_c.run_as_server();
|
||||
|
||||
let port = if proto == "wg" { 11040 } else { 11041 };
|
||||
p_c.get_global_ctx()
|
||||
.config
|
||||
.set_listeners(vec![format!("{}://0.0.0.0:{}", proto, port)
|
||||
.parse()
|
||||
.unwrap()]);
|
||||
let mut lis_c = ListenerManager::new(p_c.get_global_ctx(), p_c.clone());
|
||||
lis_c.prepare_listeners().await.unwrap();
|
||||
|
||||
lis_c.run().await.unwrap();
|
||||
|
||||
wait_route_appear_with_cost(p_a.clone(), p_c.my_peer_id(), Some(1))
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn direct_connector_scheme_blacklist() {
|
||||
let p_a = create_mock_peer_manager().await;
|
||||
let data = Arc::new(DirectConnectorManagerData::new(
|
||||
p_a.get_global_ctx(),
|
||||
p_a.clone(),
|
||||
));
|
||||
let mut ip_list = GetIpListResponse::new();
|
||||
ip_list
|
||||
.listeners
|
||||
.push("tcp://127.0.0.1:10222".parse().unwrap());
|
||||
|
||||
ip_list.interface_ipv4s.push("127.0.0.1".to_string());
|
||||
|
||||
DirectConnectorManager::do_try_direct_connect_internal(data.clone(), 1, ip_list.clone())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(data
|
||||
.dst_sceme_blacklist
|
||||
.contains(&DstSchemeBlackListItem(1, "tcp".into())));
|
||||
|
||||
assert!(data
|
||||
.dst_blacklist
|
||||
.contains(&DstBlackListItem(1, ip_list.listeners[0].to_string())));
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,390 @@
|
||||
use std::{collections::BTreeSet, sync::Arc};
|
||||
|
||||
use dashmap::{DashMap, DashSet};
|
||||
use tokio::{
|
||||
sync::{broadcast::Receiver, mpsc, Mutex},
|
||||
task::JoinSet,
|
||||
time::timeout,
|
||||
};
|
||||
|
||||
use crate::{common::PeerId, peers::peer_conn::PeerConnId, rpc as easytier_rpc};
|
||||
|
||||
use crate::{
|
||||
common::{
|
||||
error::Error,
|
||||
global_ctx::{ArcGlobalCtx, GlobalCtxEvent},
|
||||
netns::NetNS,
|
||||
},
|
||||
connector::set_bind_addr_for_peer_connector,
|
||||
peers::peer_manager::PeerManager,
|
||||
rpc::{
|
||||
connector_manage_rpc_server::ConnectorManageRpc, Connector, ConnectorStatus,
|
||||
ListConnectorRequest, ManageConnectorRequest,
|
||||
},
|
||||
tunnels::{Tunnel, TunnelConnector},
|
||||
use_global_var,
|
||||
};
|
||||
|
||||
use super::create_connector_by_url;
|
||||
|
||||
type ConnectorMap = Arc<DashMap<String, Box<dyn TunnelConnector + Send + Sync>>>;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct ReconnResult {
|
||||
dead_url: String,
|
||||
peer_id: PeerId,
|
||||
conn_id: PeerConnId,
|
||||
}
|
||||
|
||||
struct ConnectorManagerData {
|
||||
connectors: ConnectorMap,
|
||||
reconnecting: DashSet<String>,
|
||||
peer_manager: Arc<PeerManager>,
|
||||
alive_conn_urls: Arc<Mutex<BTreeSet<String>>>,
|
||||
// user removed connector urls
|
||||
removed_conn_urls: Arc<DashSet<String>>,
|
||||
net_ns: NetNS,
|
||||
global_ctx: ArcGlobalCtx,
|
||||
}
|
||||
|
||||
pub struct ManualConnectorManager {
|
||||
global_ctx: ArcGlobalCtx,
|
||||
data: Arc<ConnectorManagerData>,
|
||||
tasks: JoinSet<()>,
|
||||
}
|
||||
|
||||
impl ManualConnectorManager {
|
||||
pub fn new(global_ctx: ArcGlobalCtx, peer_manager: Arc<PeerManager>) -> Self {
|
||||
let connectors = Arc::new(DashMap::new());
|
||||
let tasks = JoinSet::new();
|
||||
let event_subscriber = global_ctx.subscribe();
|
||||
|
||||
let mut ret = Self {
|
||||
global_ctx: global_ctx.clone(),
|
||||
data: Arc::new(ConnectorManagerData {
|
||||
connectors,
|
||||
reconnecting: DashSet::new(),
|
||||
peer_manager,
|
||||
alive_conn_urls: Arc::new(Mutex::new(BTreeSet::new())),
|
||||
removed_conn_urls: Arc::new(DashSet::new()),
|
||||
net_ns: global_ctx.net_ns.clone(),
|
||||
global_ctx,
|
||||
}),
|
||||
tasks,
|
||||
};
|
||||
|
||||
ret.tasks
|
||||
.spawn(Self::conn_mgr_routine(ret.data.clone(), event_subscriber));
|
||||
|
||||
ret
|
||||
}
|
||||
|
||||
pub fn add_connector<T>(&self, connector: T)
|
||||
where
|
||||
T: TunnelConnector + Send + Sync + 'static,
|
||||
{
|
||||
log::info!("add_connector: {}", connector.remote_url());
|
||||
self.data
|
||||
.connectors
|
||||
.insert(connector.remote_url().into(), Box::new(connector));
|
||||
}
|
||||
|
||||
pub async fn add_connector_by_url(&self, url: &str) -> Result<(), Error> {
|
||||
self.add_connector(create_connector_by_url(url, &self.global_ctx).await?);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn remove_connector(&self, url: &str) -> Result<(), Error> {
|
||||
log::info!("remove_connector: {}", url);
|
||||
if !self.list_connectors().await.iter().any(|x| x.url == url) {
|
||||
return Err(Error::NotFound);
|
||||
}
|
||||
self.data.removed_conn_urls.insert(url.into());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn list_connectors(&self) -> Vec<Connector> {
|
||||
let conn_urls: BTreeSet<String> = self
|
||||
.data
|
||||
.connectors
|
||||
.iter()
|
||||
.map(|x| x.key().clone().into())
|
||||
.collect();
|
||||
|
||||
let dead_urls: BTreeSet<String> = Self::collect_dead_conns(self.data.clone())
|
||||
.await
|
||||
.into_iter()
|
||||
.collect();
|
||||
|
||||
let mut ret = Vec::new();
|
||||
|
||||
for conn_url in conn_urls {
|
||||
let mut status = ConnectorStatus::Connected;
|
||||
if dead_urls.contains(&conn_url) {
|
||||
status = ConnectorStatus::Disconnected;
|
||||
}
|
||||
ret.insert(
|
||||
0,
|
||||
Connector {
|
||||
url: conn_url,
|
||||
status: status.into(),
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
let reconnecting_urls: BTreeSet<String> = self
|
||||
.data
|
||||
.reconnecting
|
||||
.iter()
|
||||
.map(|x| x.clone().into())
|
||||
.collect();
|
||||
|
||||
for conn_url in reconnecting_urls {
|
||||
ret.insert(
|
||||
0,
|
||||
Connector {
|
||||
url: conn_url,
|
||||
status: ConnectorStatus::Connecting.into(),
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
ret
|
||||
}
|
||||
|
||||
async fn conn_mgr_routine(
|
||||
data: Arc<ConnectorManagerData>,
|
||||
mut event_recv: Receiver<GlobalCtxEvent>,
|
||||
) {
|
||||
log::warn!("conn_mgr_routine started");
|
||||
let mut reconn_interval = tokio::time::interval(std::time::Duration::from_millis(
|
||||
use_global_var!(MANUAL_CONNECTOR_RECONNECT_INTERVAL_MS),
|
||||
));
|
||||
let mut reconn_tasks = JoinSet::new();
|
||||
let (reconn_result_send, mut reconn_result_recv) = mpsc::channel(100);
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
event = event_recv.recv() => {
|
||||
if let Ok(event) = event {
|
||||
Self::handle_event(&event, data.clone()).await;
|
||||
} else {
|
||||
log::warn!("event_recv closed");
|
||||
panic!("event_recv closed");
|
||||
}
|
||||
}
|
||||
|
||||
_ = reconn_interval.tick() => {
|
||||
let dead_urls = Self::collect_dead_conns(data.clone()).await;
|
||||
if dead_urls.is_empty() {
|
||||
continue;
|
||||
}
|
||||
for dead_url in dead_urls {
|
||||
let data_clone = data.clone();
|
||||
let sender = reconn_result_send.clone();
|
||||
let (_, connector) = data.connectors.remove(&dead_url).unwrap();
|
||||
let insert_succ = data.reconnecting.insert(dead_url.clone());
|
||||
assert!(insert_succ);
|
||||
reconn_tasks.spawn(async move {
|
||||
sender.send(Self::conn_reconnect(data_clone.clone(), dead_url, connector).await).await.unwrap();
|
||||
});
|
||||
}
|
||||
log::info!("reconn_interval tick, done");
|
||||
}
|
||||
|
||||
ret = reconn_result_recv.recv() => {
|
||||
log::warn!("reconn_tasks done, out: {:?}", ret);
|
||||
let _ = reconn_tasks.join_next().await.unwrap();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_event(event: &GlobalCtxEvent, data: Arc<ConnectorManagerData>) {
|
||||
match event {
|
||||
GlobalCtxEvent::PeerConnAdded(conn_info) => {
|
||||
let addr = conn_info.tunnel.as_ref().unwrap().remote_addr.clone();
|
||||
data.alive_conn_urls.lock().await.insert(addr);
|
||||
log::warn!("peer conn added: {:?}", conn_info);
|
||||
}
|
||||
|
||||
GlobalCtxEvent::PeerConnRemoved(conn_info) => {
|
||||
let addr = conn_info.tunnel.as_ref().unwrap().remote_addr.clone();
|
||||
data.alive_conn_urls.lock().await.remove(&addr);
|
||||
log::warn!("peer conn removed: {:?}", conn_info);
|
||||
}
|
||||
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_remove_connector(data: Arc<ConnectorManagerData>) {
|
||||
let remove_later = DashSet::new();
|
||||
for it in data.removed_conn_urls.iter() {
|
||||
let url = it.key();
|
||||
if let Some(_) = data.connectors.remove(url) {
|
||||
log::warn!("connector: {}, removed", url);
|
||||
continue;
|
||||
} else if data.reconnecting.contains(url) {
|
||||
log::warn!("connector: {}, reconnecting, remove later.", url);
|
||||
remove_later.insert(url.clone());
|
||||
continue;
|
||||
} else {
|
||||
log::warn!("connector: {}, not found", url);
|
||||
}
|
||||
}
|
||||
data.removed_conn_urls.clear();
|
||||
for it in remove_later.iter() {
|
||||
data.removed_conn_urls.insert(it.key().clone());
|
||||
}
|
||||
}
|
||||
|
||||
async fn collect_dead_conns(data: Arc<ConnectorManagerData>) -> BTreeSet<String> {
|
||||
Self::handle_remove_connector(data.clone());
|
||||
|
||||
let curr_alive = data.alive_conn_urls.lock().await.clone();
|
||||
let all_urls: BTreeSet<String> = data
|
||||
.connectors
|
||||
.iter()
|
||||
.map(|x| x.key().clone().into())
|
||||
.collect();
|
||||
&all_urls - &curr_alive
|
||||
}
|
||||
|
||||
async fn conn_reconnect(
|
||||
data: Arc<ConnectorManagerData>,
|
||||
dead_url: String,
|
||||
connector: Box<dyn TunnelConnector + Send + Sync>,
|
||||
) -> Result<ReconnResult, Error> {
|
||||
let connector = Arc::new(Mutex::new(Some(connector)));
|
||||
let net_ns = data.net_ns.clone();
|
||||
|
||||
log::info!("reconnect: {}", dead_url);
|
||||
|
||||
let connector_clone = connector.clone();
|
||||
let data_clone = data.clone();
|
||||
let url_clone = dead_url.clone();
|
||||
let ip_collector = data.global_ctx.get_ip_collector();
|
||||
let reconn_task = async move {
|
||||
let mut locked = connector_clone.lock().await;
|
||||
let conn = locked.as_mut().unwrap();
|
||||
// TODO: should support set v6 here, use url in connector array
|
||||
set_bind_addr_for_peer_connector(conn, true, &ip_collector).await;
|
||||
|
||||
data_clone
|
||||
.global_ctx
|
||||
.issue_event(GlobalCtxEvent::Connecting(conn.remote_url().clone()));
|
||||
|
||||
let _g = net_ns.guard();
|
||||
log::info!("reconnect try connect... conn: {:?}", conn);
|
||||
let tunnel = conn.connect().await?;
|
||||
log::info!("reconnect get tunnel succ: {:?}", tunnel);
|
||||
assert_eq!(
|
||||
url_clone,
|
||||
tunnel.info().unwrap().remote_addr,
|
||||
"info: {:?}",
|
||||
tunnel.info()
|
||||
);
|
||||
let (peer_id, conn_id) = data_clone.peer_manager.add_client_tunnel(tunnel).await?;
|
||||
log::info!("reconnect succ: {} {} {}", peer_id, conn_id, url_clone);
|
||||
Ok(ReconnResult {
|
||||
dead_url: url_clone,
|
||||
peer_id,
|
||||
conn_id,
|
||||
})
|
||||
};
|
||||
|
||||
let ret = timeout(std::time::Duration::from_secs(1), reconn_task).await;
|
||||
log::info!("reconnect: {} done, ret: {:?}", dead_url, ret);
|
||||
|
||||
if ret.is_err() || ret.as_ref().unwrap().is_err() {
|
||||
data.global_ctx.issue_event(GlobalCtxEvent::ConnectError(
|
||||
dead_url.clone(),
|
||||
format!("{:?}", ret),
|
||||
));
|
||||
}
|
||||
|
||||
let conn = connector.lock().await.take().unwrap();
|
||||
data.reconnecting.remove(&dead_url).unwrap();
|
||||
data.connectors.insert(dead_url.clone(), conn);
|
||||
|
||||
ret?
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ConnectorManagerRpcService(pub Arc<ManualConnectorManager>);
|
||||
|
||||
#[tonic::async_trait]
|
||||
impl ConnectorManageRpc for ConnectorManagerRpcService {
|
||||
async fn list_connector(
|
||||
&self,
|
||||
_request: tonic::Request<ListConnectorRequest>,
|
||||
) -> Result<tonic::Response<easytier_rpc::ListConnectorResponse>, tonic::Status> {
|
||||
let mut ret = easytier_rpc::ListConnectorResponse::default();
|
||||
let connectors = self.0.list_connectors().await;
|
||||
ret.connectors = connectors;
|
||||
Ok(tonic::Response::new(ret))
|
||||
}
|
||||
|
||||
async fn manage_connector(
|
||||
&self,
|
||||
request: tonic::Request<ManageConnectorRequest>,
|
||||
) -> Result<tonic::Response<easytier_rpc::ManageConnectorResponse>, tonic::Status> {
|
||||
let req = request.into_inner();
|
||||
let url = url::Url::parse(&req.url)
|
||||
.map_err(|_| tonic::Status::invalid_argument("invalid url"))?;
|
||||
if req.action == easytier_rpc::ConnectorManageAction::Remove as i32 {
|
||||
self.0.remove_connector(url.path()).await.map_err(|e| {
|
||||
tonic::Status::invalid_argument(format!("remove connector failed: {:?}", e))
|
||||
})?;
|
||||
return Ok(tonic::Response::new(
|
||||
easytier_rpc::ManageConnectorResponse::default(),
|
||||
));
|
||||
} else {
|
||||
self.0
|
||||
.add_connector_by_url(url.as_str())
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tonic::Status::invalid_argument(format!("add connector failed: {:?}", e))
|
||||
})?;
|
||||
}
|
||||
Ok(tonic::Response::new(
|
||||
easytier_rpc::ManageConnectorResponse::default(),
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::{
|
||||
peers::tests::create_mock_peer_manager,
|
||||
set_global_var,
|
||||
tunnels::{Tunnel, TunnelError},
|
||||
};
|
||||
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_reconnect_with_connecting_addr() {
|
||||
set_global_var!(MANUAL_CONNECTOR_RECONNECT_INTERVAL_MS, 1);
|
||||
|
||||
let peer_mgr = create_mock_peer_manager().await;
|
||||
let mgr = ManualConnectorManager::new(peer_mgr.get_global_ctx(), peer_mgr);
|
||||
|
||||
struct MockConnector {}
|
||||
#[async_trait::async_trait]
|
||||
impl TunnelConnector for MockConnector {
|
||||
fn remote_url(&self) -> url::Url {
|
||||
url::Url::parse("tcp://aa.com").unwrap()
|
||||
}
|
||||
async fn connect(&mut self) -> Result<Box<dyn Tunnel>, TunnelError> {
|
||||
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
|
||||
Err(TunnelError::CommonError("fake error".into()))
|
||||
}
|
||||
}
|
||||
|
||||
mgr.add_connector(MockConnector {});
|
||||
|
||||
tokio::time::sleep(std::time::Duration::from_secs(5)).await;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,99 @@
|
||||
use std::{
|
||||
net::{SocketAddr, SocketAddrV4, SocketAddrV6},
|
||||
sync::Arc,
|
||||
};
|
||||
|
||||
use crate::{
|
||||
common::{error::Error, global_ctx::ArcGlobalCtx, network::IPCollector},
|
||||
tunnels::{
|
||||
ring_tunnel::RingTunnelConnector,
|
||||
tcp_tunnel::TcpTunnelConnector,
|
||||
udp_tunnel::UdpTunnelConnector,
|
||||
wireguard::{WgConfig, WgTunnelConnector},
|
||||
TunnelConnector,
|
||||
},
|
||||
};
|
||||
|
||||
pub mod direct;
|
||||
pub mod manual;
|
||||
pub mod udp_hole_punch;
|
||||
|
||||
async fn set_bind_addr_for_peer_connector(
|
||||
connector: &mut impl TunnelConnector,
|
||||
is_ipv4: bool,
|
||||
ip_collector: &Arc<IPCollector>,
|
||||
) {
|
||||
let ips = ip_collector.collect_ip_addrs().await;
|
||||
if is_ipv4 {
|
||||
let mut bind_addrs = vec![];
|
||||
for ipv4 in ips.interface_ipv4s {
|
||||
let socket_addr = SocketAddrV4::new(ipv4.parse().unwrap(), 0).into();
|
||||
bind_addrs.push(socket_addr);
|
||||
}
|
||||
connector.set_bind_addrs(bind_addrs);
|
||||
} else {
|
||||
let mut bind_addrs = vec![];
|
||||
for ipv6 in ips.interface_ipv6s {
|
||||
let socket_addr = SocketAddrV6::new(ipv6.parse().unwrap(), 0, 0, 0).into();
|
||||
bind_addrs.push(socket_addr);
|
||||
}
|
||||
connector.set_bind_addrs(bind_addrs);
|
||||
}
|
||||
let _ = connector;
|
||||
}
|
||||
|
||||
pub async fn create_connector_by_url(
|
||||
url: &str,
|
||||
global_ctx: &ArcGlobalCtx,
|
||||
) -> Result<Box<dyn TunnelConnector + Send + Sync + 'static>, Error> {
|
||||
let url = url::Url::parse(url).map_err(|_| Error::InvalidUrl(url.to_owned()))?;
|
||||
match url.scheme() {
|
||||
"tcp" => {
|
||||
let dst_addr =
|
||||
crate::tunnels::check_scheme_and_get_socket_addr::<SocketAddr>(&url, "tcp")?;
|
||||
let mut connector = TcpTunnelConnector::new(url);
|
||||
set_bind_addr_for_peer_connector(
|
||||
&mut connector,
|
||||
dst_addr.is_ipv4(),
|
||||
&global_ctx.get_ip_collector(),
|
||||
)
|
||||
.await;
|
||||
return Ok(Box::new(connector));
|
||||
}
|
||||
"udp" => {
|
||||
let dst_addr =
|
||||
crate::tunnels::check_scheme_and_get_socket_addr::<SocketAddr>(&url, "udp")?;
|
||||
let mut connector = UdpTunnelConnector::new(url);
|
||||
set_bind_addr_for_peer_connector(
|
||||
&mut connector,
|
||||
dst_addr.is_ipv4(),
|
||||
&global_ctx.get_ip_collector(),
|
||||
)
|
||||
.await;
|
||||
return Ok(Box::new(connector));
|
||||
}
|
||||
"ring" => {
|
||||
crate::tunnels::check_scheme_and_get_socket_addr::<uuid::Uuid>(&url, "ring")?;
|
||||
let connector = RingTunnelConnector::new(url);
|
||||
return Ok(Box::new(connector));
|
||||
}
|
||||
"wg" => {
|
||||
let dst_addr =
|
||||
crate::tunnels::check_scheme_and_get_socket_addr::<SocketAddr>(&url, "wg")?;
|
||||
let nid = global_ctx.get_network_identity();
|
||||
let wg_config =
|
||||
WgConfig::new_from_network_identity(&nid.network_name, &nid.network_secret);
|
||||
let mut connector = WgTunnelConnector::new(url, wg_config);
|
||||
set_bind_addr_for_peer_connector(
|
||||
&mut connector,
|
||||
dst_addr.is_ipv4(),
|
||||
&global_ctx.get_ip_collector(),
|
||||
)
|
||||
.await;
|
||||
return Ok(Box::new(connector));
|
||||
}
|
||||
_ => {
|
||||
return Err(Error::InvalidUrl(url.into()));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,545 @@
|
||||
use std::{net::SocketAddr, sync::Arc};
|
||||
|
||||
use anyhow::Context;
|
||||
use crossbeam::atomic::AtomicCell;
|
||||
use rand::{seq::SliceRandom, Rng, SeedableRng};
|
||||
use tokio::{net::UdpSocket, sync::Mutex, task::JoinSet};
|
||||
use tracing::Instrument;
|
||||
|
||||
use crate::{
|
||||
common::{
|
||||
constants, error::Error, global_ctx::ArcGlobalCtx, join_joinset_background,
|
||||
rkyv_util::encode_to_bytes, stun::StunInfoCollectorTrait, PeerId,
|
||||
},
|
||||
peers::peer_manager::PeerManager,
|
||||
rpc::NatType,
|
||||
tunnels::{
|
||||
common::setup_sokcet2,
|
||||
udp_tunnel::{UdpPacket, UdpTunnelConnector, UdpTunnelListener},
|
||||
Tunnel, TunnelConnCounter, TunnelListener,
|
||||
},
|
||||
};
|
||||
|
||||
use super::direct::PeerManagerForDirectConnector;
|
||||
|
||||
#[tarpc::service]
|
||||
pub trait UdpHolePunchService {
|
||||
async fn try_punch_hole(local_mapped_addr: SocketAddr) -> Option<SocketAddr>;
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct UdpHolePunchListener {
|
||||
socket: Arc<UdpSocket>,
|
||||
tasks: JoinSet<()>,
|
||||
running: Arc<AtomicCell<bool>>,
|
||||
mapped_addr: SocketAddr,
|
||||
conn_counter: Arc<Box<dyn TunnelConnCounter>>,
|
||||
|
||||
listen_time: std::time::Instant,
|
||||
last_select_time: AtomicCell<std::time::Instant>,
|
||||
last_connected_time: Arc<AtomicCell<std::time::Instant>>,
|
||||
}
|
||||
|
||||
impl UdpHolePunchListener {
|
||||
async fn get_avail_port() -> Result<u16, Error> {
|
||||
let socket = UdpSocket::bind("0.0.0.0:0").await?;
|
||||
Ok(socket.local_addr()?.port())
|
||||
}
|
||||
|
||||
pub async fn new(peer_mgr: Arc<PeerManager>) -> Result<Self, Error> {
|
||||
let port = Self::get_avail_port().await?;
|
||||
let listen_url = format!("udp://0.0.0.0:{}", port);
|
||||
|
||||
let gctx = peer_mgr.get_global_ctx();
|
||||
let stun_info_collect = gctx.get_stun_info_collector();
|
||||
let mapped_addr = stun_info_collect.get_udp_port_mapping(port).await?;
|
||||
|
||||
let mut listener = UdpTunnelListener::new(listen_url.parse().unwrap());
|
||||
|
||||
{
|
||||
let _g = peer_mgr.get_global_ctx().net_ns.guard();
|
||||
listener.listen().await?;
|
||||
}
|
||||
let socket = listener.get_socket().unwrap();
|
||||
|
||||
let running = Arc::new(AtomicCell::new(true));
|
||||
let running_clone = running.clone();
|
||||
|
||||
let last_connected_time = Arc::new(AtomicCell::new(std::time::Instant::now()));
|
||||
let last_connected_time_clone = last_connected_time.clone();
|
||||
|
||||
let conn_counter = listener.get_conn_counter();
|
||||
let mut tasks = JoinSet::new();
|
||||
|
||||
tasks.spawn(async move {
|
||||
while let Ok(conn) = listener.accept().await {
|
||||
last_connected_time_clone.store(std::time::Instant::now());
|
||||
tracing::warn!(?conn, "udp hole punching listener got peer connection");
|
||||
let peer_mgr = peer_mgr.clone();
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = peer_mgr.add_tunnel_as_server(conn).await {
|
||||
tracing::error!(
|
||||
?e,
|
||||
"failed to add tunnel as server in hole punch listener"
|
||||
);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
running_clone.store(false);
|
||||
});
|
||||
|
||||
tracing::warn!(?mapped_addr, ?socket, "udp hole punching listener started");
|
||||
|
||||
Ok(Self {
|
||||
tasks,
|
||||
socket,
|
||||
running,
|
||||
mapped_addr,
|
||||
conn_counter,
|
||||
|
||||
listen_time: std::time::Instant::now(),
|
||||
last_select_time: AtomicCell::new(std::time::Instant::now()),
|
||||
last_connected_time,
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn get_socket(&self) -> Arc<UdpSocket> {
|
||||
self.last_select_time.store(std::time::Instant::now());
|
||||
self.socket.clone()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct UdpHolePunchConnectorData {
|
||||
global_ctx: ArcGlobalCtx,
|
||||
peer_mgr: Arc<PeerManager>,
|
||||
listeners: Arc<Mutex<Vec<UdpHolePunchListener>>>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct UdpHolePunchRpcServer {
|
||||
data: Arc<UdpHolePunchConnectorData>,
|
||||
|
||||
tasks: Arc<std::sync::Mutex<JoinSet<()>>>,
|
||||
}
|
||||
|
||||
#[tarpc::server]
|
||||
impl UdpHolePunchService for UdpHolePunchRpcServer {
|
||||
async fn try_punch_hole(
|
||||
self,
|
||||
_: tarpc::context::Context,
|
||||
local_mapped_addr: SocketAddr,
|
||||
) -> Option<SocketAddr> {
|
||||
let (socket, mapped_addr) = self.select_listener().await?;
|
||||
tracing::warn!(?local_mapped_addr, ?mapped_addr, "start hole punching");
|
||||
|
||||
let my_udp_nat_type = self
|
||||
.data
|
||||
.global_ctx
|
||||
.get_stun_info_collector()
|
||||
.get_stun_info()
|
||||
.udp_nat_type;
|
||||
|
||||
// if we are restricted, we need to send hole punching resp to client
|
||||
if my_udp_nat_type == NatType::PortRestricted as i32
|
||||
|| my_udp_nat_type == NatType::Restricted as i32
|
||||
{
|
||||
// send punch msg to local_mapped_addr for 3 seconds, 3.3 packet per second
|
||||
self.tasks.lock().unwrap().spawn(async move {
|
||||
for _ in 0..10 {
|
||||
tracing::info!(?local_mapped_addr, "sending hole punching packet");
|
||||
// generate a 128 bytes vec with random data
|
||||
let mut rng = rand::rngs::StdRng::from_entropy();
|
||||
let mut buf = vec![0u8; 128];
|
||||
rng.fill(&mut buf[..]);
|
||||
|
||||
let udp_packet = UdpPacket::new_hole_punch_packet(buf);
|
||||
let udp_packet_bytes = encode_to_bytes::<_, 256>(&udp_packet);
|
||||
let _ = socket
|
||||
.send_to(udp_packet_bytes.as_ref(), local_mapped_addr)
|
||||
.await;
|
||||
tokio::time::sleep(std::time::Duration::from_millis(300)).await;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
Some(mapped_addr)
|
||||
}
|
||||
}
|
||||
|
||||
impl UdpHolePunchRpcServer {
|
||||
pub fn new(data: Arc<UdpHolePunchConnectorData>) -> Self {
|
||||
let tasks = Arc::new(std::sync::Mutex::new(JoinSet::new()));
|
||||
join_joinset_background(tasks.clone(), "UdpHolePunchRpcServer".to_owned());
|
||||
Self { data, tasks }
|
||||
}
|
||||
|
||||
async fn select_listener(&self) -> Option<(Arc<UdpSocket>, SocketAddr)> {
|
||||
let all_listener_sockets = &self.data.listeners;
|
||||
|
||||
// remove listener that not have connection in for 20 seconds
|
||||
all_listener_sockets.lock().await.retain(|listener| {
|
||||
listener.last_connected_time.load().elapsed().as_secs() < 20
|
||||
&& listener.conn_counter.get() > 0
|
||||
});
|
||||
|
||||
let mut use_last = false;
|
||||
if all_listener_sockets.lock().await.len() < 4 {
|
||||
tracing::warn!("creating new udp hole punching listener");
|
||||
all_listener_sockets.lock().await.push(
|
||||
UdpHolePunchListener::new(self.data.peer_mgr.clone())
|
||||
.await
|
||||
.ok()?,
|
||||
);
|
||||
use_last = true;
|
||||
}
|
||||
|
||||
let locked = all_listener_sockets.lock().await;
|
||||
|
||||
let listener = if use_last {
|
||||
locked.last()?
|
||||
} else {
|
||||
locked.choose(&mut rand::rngs::StdRng::from_entropy())?
|
||||
};
|
||||
|
||||
Some((listener.get_socket().await, listener.mapped_addr))
|
||||
}
|
||||
}
|
||||
|
||||
pub struct UdpHolePunchConnector {
|
||||
data: Arc<UdpHolePunchConnectorData>,
|
||||
tasks: JoinSet<()>,
|
||||
}
|
||||
|
||||
// Currently support:
|
||||
// Symmetric -> Full Cone
|
||||
// Any Type of Full Cone -> Any Type of Full Cone
|
||||
|
||||
// if same level of full cone, node with smaller peer_id will be the initiator
|
||||
// if different level of full cone, node with more strict level will be the initiator
|
||||
|
||||
impl UdpHolePunchConnector {
|
||||
pub fn new(global_ctx: ArcGlobalCtx, peer_mgr: Arc<PeerManager>) -> Self {
|
||||
Self {
|
||||
data: Arc::new(UdpHolePunchConnectorData {
|
||||
global_ctx,
|
||||
peer_mgr,
|
||||
listeners: Arc::new(Mutex::new(Vec::new())),
|
||||
}),
|
||||
tasks: JoinSet::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn run_as_client(&mut self) -> Result<(), Error> {
|
||||
let data = self.data.clone();
|
||||
self.tasks.spawn(async move {
|
||||
Self::main_loop(data).await;
|
||||
});
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn run_as_server(&mut self) -> Result<(), Error> {
|
||||
self.data.peer_mgr.get_peer_rpc_mgr().run_service(
|
||||
constants::UDP_HOLE_PUNCH_CONNECTOR_SERVICE_ID,
|
||||
UdpHolePunchRpcServer::new(self.data.clone()).serve(),
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn run(&mut self) -> Result<(), Error> {
|
||||
self.run_as_client().await?;
|
||||
self.run_as_server().await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn collect_peer_to_connect(data: Arc<UdpHolePunchConnectorData>) -> Vec<PeerId> {
|
||||
let mut peers_to_connect = Vec::new();
|
||||
|
||||
// do not do anything if:
|
||||
// 1. our stun test has not finished
|
||||
// 2. our nat type is OpenInternet or NoPat, which means we can wait other peers to connect us
|
||||
let my_nat_type = data
|
||||
.global_ctx
|
||||
.get_stun_info_collector()
|
||||
.get_stun_info()
|
||||
.udp_nat_type;
|
||||
|
||||
let my_nat_type = NatType::try_from(my_nat_type).unwrap();
|
||||
|
||||
if my_nat_type == NatType::Unknown
|
||||
|| my_nat_type == NatType::OpenInternet
|
||||
|| my_nat_type == NatType::NoPat
|
||||
{
|
||||
return peers_to_connect;
|
||||
}
|
||||
|
||||
// collect peer list from peer manager and do some filter:
|
||||
// 1. peers without direct conns;
|
||||
// 2. peers is full cone (any restricted type);
|
||||
for route in data.peer_mgr.list_routes().await.iter() {
|
||||
let Some(peer_stun_info) = route.stun_info.as_ref() else {
|
||||
continue;
|
||||
};
|
||||
let Ok(peer_nat_type) = NatType::try_from(peer_stun_info.udp_nat_type) else {
|
||||
continue;
|
||||
};
|
||||
|
||||
let peer_id: PeerId = route.peer_id;
|
||||
let conns = data.peer_mgr.list_peer_conns(peer_id).await;
|
||||
if conns.is_some() && conns.unwrap().len() > 0 {
|
||||
continue;
|
||||
}
|
||||
|
||||
// if peer is symmetric ignore it because we cannot connect to it
|
||||
// if peer is open internet or no pat, direct connector will connecto to it
|
||||
if peer_nat_type == NatType::Unknown
|
||||
|| peer_nat_type == NatType::OpenInternet
|
||||
|| peer_nat_type == NatType::NoPat
|
||||
|| peer_nat_type == NatType::Symmetric
|
||||
|| peer_nat_type == NatType::SymUdpFirewall
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
// if we are symmetric, we can only connect to full cone
|
||||
// TODO: can also connect to restricted full cone, with some extra work
|
||||
if (my_nat_type == NatType::Symmetric || my_nat_type == NatType::SymUdpFirewall)
|
||||
&& peer_nat_type != NatType::FullCone
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
// if we have smae level of full cone, node with smaller peer_id will be the initiator
|
||||
if my_nat_type == peer_nat_type {
|
||||
if data.peer_mgr.my_peer_id() > peer_id {
|
||||
continue;
|
||||
}
|
||||
} else {
|
||||
// if we have different level of full cone
|
||||
// we will be the initiator if we have more strict level
|
||||
if my_nat_type < peer_nat_type {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
tracing::info!(
|
||||
?peer_id,
|
||||
?peer_nat_type,
|
||||
?my_nat_type,
|
||||
?data.global_ctx.id,
|
||||
"found peer to do hole punching"
|
||||
);
|
||||
|
||||
peers_to_connect.push(peer_id);
|
||||
}
|
||||
|
||||
peers_to_connect
|
||||
}
|
||||
|
||||
#[tracing::instrument]
|
||||
async fn do_hole_punching(
|
||||
data: Arc<UdpHolePunchConnectorData>,
|
||||
dst_peer_id: PeerId,
|
||||
) -> Result<Box<dyn Tunnel>, anyhow::Error> {
|
||||
tracing::info!(?dst_peer_id, "start hole punching");
|
||||
// client: choose a local udp port, and get the pubic mapped port from stun server
|
||||
let socket = {
|
||||
let _g = data.global_ctx.net_ns.guard();
|
||||
UdpSocket::bind("0.0.0.0:0").await.with_context(|| "")?
|
||||
};
|
||||
let local_socket_addr = socket.local_addr()?;
|
||||
let local_port = socket.local_addr()?.port();
|
||||
drop(socket); // drop the socket to release the port
|
||||
|
||||
let local_mapped_addr = data
|
||||
.global_ctx
|
||||
.get_stun_info_collector()
|
||||
.get_udp_port_mapping(local_port)
|
||||
.await
|
||||
.with_context(|| "failed to get udp port mapping")?;
|
||||
|
||||
// client -> server: tell server the mapped port, server will return the mapped address of listening port.
|
||||
let Some(remote_mapped_addr) = data
|
||||
.peer_mgr
|
||||
.get_peer_rpc_mgr()
|
||||
.do_client_rpc_scoped(
|
||||
constants::UDP_HOLE_PUNCH_CONNECTOR_SERVICE_ID,
|
||||
dst_peer_id,
|
||||
|c| async {
|
||||
let client =
|
||||
UdpHolePunchServiceClient::new(tarpc::client::Config::default(), c).spawn();
|
||||
let remote_mapped_addr = client
|
||||
.try_punch_hole(tarpc::context::current(), local_mapped_addr)
|
||||
.await;
|
||||
tracing::info!(?remote_mapped_addr, ?dst_peer_id, "got remote mapped addr");
|
||||
remote_mapped_addr
|
||||
},
|
||||
)
|
||||
.await?
|
||||
else {
|
||||
return Err(anyhow::anyhow!("failed to get remote mapped addr"));
|
||||
};
|
||||
|
||||
// server: will send some punching resps, total 10 packets.
|
||||
// client: use the socket to create UdpTunnel with UdpTunnelConnector
|
||||
// NOTICE: UdpTunnelConnector will ignore the punching resp packet sent by remote.
|
||||
|
||||
let connector = UdpTunnelConnector::new(
|
||||
format!(
|
||||
"udp://{}:{}",
|
||||
remote_mapped_addr.ip(),
|
||||
remote_mapped_addr.port()
|
||||
)
|
||||
.to_string()
|
||||
.parse()
|
||||
.unwrap(),
|
||||
);
|
||||
|
||||
let _g = data.global_ctx.net_ns.guard();
|
||||
let socket2_socket = socket2::Socket::new(
|
||||
socket2::Domain::for_address(local_socket_addr),
|
||||
socket2::Type::DGRAM,
|
||||
Some(socket2::Protocol::UDP),
|
||||
)?;
|
||||
setup_sokcet2(&socket2_socket, &local_socket_addr)?;
|
||||
let socket = UdpSocket::from_std(socket2_socket.into())?;
|
||||
|
||||
Ok(connector
|
||||
.try_connect_with_socket(socket)
|
||||
.await
|
||||
.with_context(|| "UdpTunnelConnector failed to connect remote")?)
|
||||
}
|
||||
|
||||
async fn main_loop(data: Arc<UdpHolePunchConnectorData>) {
|
||||
loop {
|
||||
let peers_to_connect = Self::collect_peer_to_connect(data.clone()).await;
|
||||
tracing::trace!(?peers_to_connect, "peers to connect");
|
||||
if peers_to_connect.len() == 0 {
|
||||
tokio::time::sleep(std::time::Duration::from_secs(5)).await;
|
||||
continue;
|
||||
}
|
||||
|
||||
let mut tasks: JoinSet<Result<(), anyhow::Error>> = JoinSet::new();
|
||||
for peer_id in peers_to_connect {
|
||||
let data = data.clone();
|
||||
tasks.spawn(
|
||||
async move {
|
||||
let tunnel = Self::do_hole_punching(data.clone(), peer_id)
|
||||
.await
|
||||
.with_context(|| "failed to do hole punching")?;
|
||||
|
||||
let _ =
|
||||
data.peer_mgr
|
||||
.add_client_tunnel(tunnel)
|
||||
.await
|
||||
.with_context(|| {
|
||||
"failed to add tunnel as client in hole punch connector"
|
||||
})?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
.instrument(tracing::info_span!("doing hole punching client", ?peer_id)),
|
||||
);
|
||||
}
|
||||
|
||||
while let Some(res) = tasks.join_next().await {
|
||||
if let Err(e) = res {
|
||||
tracing::error!(?e, "failed to join hole punching job");
|
||||
continue;
|
||||
}
|
||||
|
||||
match res.unwrap() {
|
||||
Err(e) => {
|
||||
tracing::error!(?e, "failed to do hole punching job");
|
||||
}
|
||||
Ok(_) => {
|
||||
tracing::info!("hole punching job succeed");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tokio::time::sleep(std::time::Duration::from_secs(10)).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub mod tests {
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::rpc::{NatType, StunInfo};
|
||||
|
||||
use crate::{
|
||||
common::{error::Error, stun::StunInfoCollectorTrait},
|
||||
connector::udp_hole_punch::UdpHolePunchConnector,
|
||||
peers::{
|
||||
peer_manager::PeerManager,
|
||||
tests::{
|
||||
connect_peer_manager, create_mock_peer_manager, wait_route_appear,
|
||||
wait_route_appear_with_cost,
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
struct MockStunInfoCollector {
|
||||
udp_nat_type: NatType,
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl StunInfoCollectorTrait for MockStunInfoCollector {
|
||||
fn get_stun_info(&self) -> StunInfo {
|
||||
StunInfo {
|
||||
udp_nat_type: self.udp_nat_type as i32,
|
||||
tcp_nat_type: NatType::Unknown as i32,
|
||||
last_update_time: std::time::Instant::now().elapsed().as_secs() as i64,
|
||||
}
|
||||
}
|
||||
|
||||
async fn get_udp_port_mapping(&self, port: u16) -> Result<std::net::SocketAddr, Error> {
|
||||
Ok(format!("127.0.0.1:{}", port).parse().unwrap())
|
||||
}
|
||||
}
|
||||
|
||||
pub fn replace_stun_info_collector(peer_mgr: Arc<PeerManager>, udp_nat_type: NatType) {
|
||||
let collector = Box::new(MockStunInfoCollector { udp_nat_type });
|
||||
peer_mgr
|
||||
.get_global_ctx()
|
||||
.replace_stun_info_collector(collector);
|
||||
}
|
||||
|
||||
pub async fn create_mock_peer_manager_with_mock_stun(
|
||||
udp_nat_type: NatType,
|
||||
) -> Arc<PeerManager> {
|
||||
let p_a = create_mock_peer_manager().await;
|
||||
replace_stun_info_collector(p_a.clone(), udp_nat_type);
|
||||
p_a
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn hole_punching() {
|
||||
let p_a = create_mock_peer_manager_with_mock_stun(NatType::PortRestricted).await;
|
||||
let p_b = create_mock_peer_manager_with_mock_stun(NatType::Symmetric).await;
|
||||
let p_c = create_mock_peer_manager_with_mock_stun(NatType::PortRestricted).await;
|
||||
connect_peer_manager(p_a.clone(), p_b.clone()).await;
|
||||
connect_peer_manager(p_b.clone(), p_c.clone()).await;
|
||||
|
||||
wait_route_appear(p_a.clone(), p_c.clone()).await.unwrap();
|
||||
|
||||
println!("{:?}", p_a.list_routes().await);
|
||||
|
||||
let mut hole_punching_a = UdpHolePunchConnector::new(p_a.get_global_ctx(), p_a.clone());
|
||||
let mut hole_punching_c = UdpHolePunchConnector::new(p_c.get_global_ctx(), p_c.clone());
|
||||
|
||||
hole_punching_a.run().await.unwrap();
|
||||
hole_punching_c.run().await.unwrap();
|
||||
|
||||
wait_route_appear_with_cost(p_a.clone(), p_c.my_peer_id(), Some(1))
|
||||
.await
|
||||
.unwrap();
|
||||
println!("{:?}", p_a.list_routes().await);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,490 @@
|
||||
#![allow(dead_code)]
|
||||
|
||||
use std::{net::SocketAddr, vec};
|
||||
|
||||
use clap::{command, Args, Parser, Subcommand};
|
||||
use rpc::vpn_portal_rpc_client::VpnPortalRpcClient;
|
||||
|
||||
mod arch;
|
||||
mod common;
|
||||
mod rpc;
|
||||
mod tunnels;
|
||||
|
||||
use crate::{
|
||||
common::stun::{StunInfoCollector, UdpNatTypeDetector},
|
||||
rpc::{
|
||||
connector_manage_rpc_client::ConnectorManageRpcClient,
|
||||
peer_center_rpc_client::PeerCenterRpcClient, peer_manage_rpc_client::PeerManageRpcClient,
|
||||
*,
|
||||
},
|
||||
};
|
||||
use humansize::format_size;
|
||||
use tabled::settings::Style;
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(author, version, about, long_about = None)]
|
||||
struct Cli {
|
||||
/// the instance name
|
||||
#[arg(short = 'p', long, default_value = "127.0.0.1:15888")]
|
||||
rpc_portal: SocketAddr,
|
||||
|
||||
#[command(subcommand)]
|
||||
sub_command: SubCommand,
|
||||
}
|
||||
|
||||
#[derive(Subcommand, Debug)]
|
||||
enum SubCommand {
|
||||
Peer(PeerArgs),
|
||||
Connector(ConnectorArgs),
|
||||
Stun,
|
||||
Route,
|
||||
PeerCenter,
|
||||
VpnPortal,
|
||||
}
|
||||
|
||||
#[derive(Args, Debug)]
|
||||
struct PeerArgs {
|
||||
#[arg(short, long)]
|
||||
ipv4: Option<String>,
|
||||
|
||||
#[arg(short, long)]
|
||||
peers: Vec<String>,
|
||||
|
||||
#[command(subcommand)]
|
||||
sub_command: Option<PeerSubCommand>,
|
||||
}
|
||||
|
||||
#[derive(Args, Debug)]
|
||||
struct PeerListArgs {
|
||||
#[arg(short, long)]
|
||||
verbose: bool,
|
||||
}
|
||||
|
||||
#[derive(Subcommand, Debug)]
|
||||
enum PeerSubCommand {
|
||||
Add,
|
||||
Remove,
|
||||
List(PeerListArgs),
|
||||
}
|
||||
|
||||
#[derive(Args, Debug)]
|
||||
struct ConnectorArgs {
|
||||
#[arg(short, long)]
|
||||
ipv4: Option<String>,
|
||||
|
||||
#[arg(short, long)]
|
||||
peers: Vec<String>,
|
||||
|
||||
#[command(subcommand)]
|
||||
sub_command: Option<ConnectorSubCommand>,
|
||||
}
|
||||
|
||||
#[derive(Subcommand, Debug)]
|
||||
enum ConnectorSubCommand {
|
||||
Add,
|
||||
Remove,
|
||||
List,
|
||||
}
|
||||
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
enum Error {
|
||||
#[error("tonic transport error")]
|
||||
TonicTransportError(#[from] tonic::transport::Error),
|
||||
#[error("tonic rpc error")]
|
||||
TonicRpcError(#[from] tonic::Status),
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct PeerRoutePair {
|
||||
route: Route,
|
||||
peer: Option<PeerInfo>,
|
||||
}
|
||||
|
||||
impl PeerRoutePair {
|
||||
fn get_latency_ms(&self) -> Option<f64> {
|
||||
let mut ret = u64::MAX;
|
||||
let p = self.peer.as_ref()?;
|
||||
for conn in p.conns.iter() {
|
||||
let Some(stats) = &conn.stats else {
|
||||
continue;
|
||||
};
|
||||
ret = ret.min(stats.latency_us);
|
||||
}
|
||||
|
||||
if ret == u64::MAX {
|
||||
None
|
||||
} else {
|
||||
Some(f64::from(ret as u32) / 1000.0)
|
||||
}
|
||||
}
|
||||
|
||||
fn get_rx_bytes(&self) -> Option<u64> {
|
||||
let mut ret = 0;
|
||||
let p = self.peer.as_ref()?;
|
||||
for conn in p.conns.iter() {
|
||||
let Some(stats) = &conn.stats else {
|
||||
continue;
|
||||
};
|
||||
ret += stats.rx_bytes;
|
||||
}
|
||||
|
||||
if ret == 0 {
|
||||
None
|
||||
} else {
|
||||
Some(ret)
|
||||
}
|
||||
}
|
||||
|
||||
fn get_tx_bytes(&self) -> Option<u64> {
|
||||
let mut ret = 0;
|
||||
let p = self.peer.as_ref()?;
|
||||
for conn in p.conns.iter() {
|
||||
let Some(stats) = &conn.stats else {
|
||||
continue;
|
||||
};
|
||||
ret += stats.tx_bytes;
|
||||
}
|
||||
|
||||
if ret == 0 {
|
||||
None
|
||||
} else {
|
||||
Some(ret)
|
||||
}
|
||||
}
|
||||
|
||||
fn get_loss_rate(&self) -> Option<f64> {
|
||||
let mut ret = 0.0;
|
||||
let p = self.peer.as_ref()?;
|
||||
for conn in p.conns.iter() {
|
||||
ret += conn.loss_rate;
|
||||
}
|
||||
|
||||
if ret == 0.0 {
|
||||
None
|
||||
} else {
|
||||
Some(ret as f64)
|
||||
}
|
||||
}
|
||||
|
||||
fn get_conn_protos(&self) -> Option<Vec<String>> {
|
||||
let mut ret = vec![];
|
||||
let p = self.peer.as_ref()?;
|
||||
for conn in p.conns.iter() {
|
||||
let Some(tunnel_info) = &conn.tunnel else {
|
||||
continue;
|
||||
};
|
||||
// insert if not exists
|
||||
if !ret.contains(&tunnel_info.tunnel_type) {
|
||||
ret.push(tunnel_info.tunnel_type.clone());
|
||||
}
|
||||
}
|
||||
|
||||
if ret.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(ret)
|
||||
}
|
||||
}
|
||||
|
||||
fn get_udp_nat_type(self: &Self) -> String {
|
||||
let mut ret = NatType::Unknown;
|
||||
if let Some(r) = &self.route.stun_info {
|
||||
ret = NatType::try_from(r.udp_nat_type).unwrap();
|
||||
}
|
||||
format!("{:?}", ret)
|
||||
}
|
||||
}
|
||||
|
||||
struct CommandHandler {
|
||||
addr: String,
|
||||
}
|
||||
|
||||
impl CommandHandler {
|
||||
async fn get_peer_manager_client(
|
||||
&self,
|
||||
) -> Result<PeerManageRpcClient<tonic::transport::Channel>, Error> {
|
||||
Ok(PeerManageRpcClient::connect(self.addr.clone()).await?)
|
||||
}
|
||||
|
||||
async fn get_connector_manager_client(
|
||||
&self,
|
||||
) -> Result<ConnectorManageRpcClient<tonic::transport::Channel>, Error> {
|
||||
Ok(ConnectorManageRpcClient::connect(self.addr.clone()).await?)
|
||||
}
|
||||
|
||||
async fn get_peer_center_client(
|
||||
&self,
|
||||
) -> Result<PeerCenterRpcClient<tonic::transport::Channel>, Error> {
|
||||
Ok(PeerCenterRpcClient::connect(self.addr.clone()).await?)
|
||||
}
|
||||
|
||||
async fn get_vpn_portal_client(
|
||||
&self,
|
||||
) -> Result<VpnPortalRpcClient<tonic::transport::Channel>, Error> {
|
||||
Ok(VpnPortalRpcClient::connect(self.addr.clone()).await?)
|
||||
}
|
||||
|
||||
async fn list_peers(&self) -> Result<ListPeerResponse, Error> {
|
||||
let mut client = self.get_peer_manager_client().await?;
|
||||
let request = tonic::Request::new(ListPeerRequest::default());
|
||||
let response = client.list_peer(request).await?;
|
||||
Ok(response.into_inner())
|
||||
}
|
||||
|
||||
async fn list_routes(&self) -> Result<ListRouteResponse, Error> {
|
||||
let mut client = self.get_peer_manager_client().await?;
|
||||
let request = tonic::Request::new(ListRouteRequest::default());
|
||||
let response = client.list_route(request).await?;
|
||||
Ok(response.into_inner())
|
||||
}
|
||||
|
||||
async fn list_peer_route_pair(&self) -> Result<Vec<PeerRoutePair>, Error> {
|
||||
let mut peers = self.list_peers().await?.peer_infos;
|
||||
let mut routes = self.list_routes().await?.routes;
|
||||
let mut pairs: Vec<PeerRoutePair> = vec![];
|
||||
|
||||
for route in routes.iter_mut() {
|
||||
let peer = peers.iter_mut().find(|peer| peer.peer_id == route.peer_id);
|
||||
pairs.push(PeerRoutePair {
|
||||
route: route.clone(),
|
||||
peer: peer.cloned(),
|
||||
});
|
||||
}
|
||||
|
||||
Ok(pairs)
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
fn handle_peer_add(&self, _args: PeerArgs) {
|
||||
println!("add peer");
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
fn handle_peer_remove(&self, _args: PeerArgs) {
|
||||
println!("remove peer");
|
||||
}
|
||||
|
||||
async fn handle_peer_list(&self, _args: &PeerArgs) -> Result<(), Error> {
|
||||
#[derive(tabled::Tabled)]
|
||||
struct PeerTableItem {
|
||||
ipv4: String,
|
||||
hostname: String,
|
||||
cost: String,
|
||||
lat_ms: String,
|
||||
loss_rate: String,
|
||||
rx_bytes: String,
|
||||
tx_bytes: String,
|
||||
tunnel_proto: String,
|
||||
nat_type: String,
|
||||
id: String,
|
||||
}
|
||||
|
||||
fn cost_to_str(cost: i32) -> String {
|
||||
if cost == 1 {
|
||||
"p2p".to_string()
|
||||
} else {
|
||||
format!("relay({})", cost)
|
||||
}
|
||||
}
|
||||
|
||||
fn float_to_str(f: f64, precision: usize) -> String {
|
||||
format!("{:.1$}", f, precision)
|
||||
}
|
||||
|
||||
impl From<PeerRoutePair> for PeerTableItem {
|
||||
fn from(p: PeerRoutePair) -> Self {
|
||||
PeerTableItem {
|
||||
ipv4: p.route.ipv4_addr.clone(),
|
||||
hostname: p.route.hostname.clone(),
|
||||
cost: cost_to_str(p.route.cost),
|
||||
lat_ms: float_to_str(p.get_latency_ms().unwrap_or(0.0), 3),
|
||||
loss_rate: float_to_str(p.get_loss_rate().unwrap_or(0.0), 3),
|
||||
rx_bytes: format_size(p.get_rx_bytes().unwrap_or(0), humansize::DECIMAL),
|
||||
tx_bytes: format_size(p.get_tx_bytes().unwrap_or(0), humansize::DECIMAL),
|
||||
tunnel_proto: p.get_conn_protos().unwrap_or(vec![]).join(",").to_string(),
|
||||
nat_type: p.get_udp_nat_type(),
|
||||
id: p.route.peer_id.to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut items: Vec<PeerTableItem> = vec![];
|
||||
let peer_routes = self.list_peer_route_pair().await?;
|
||||
for p in peer_routes {
|
||||
items.push(p.into());
|
||||
}
|
||||
|
||||
println!(
|
||||
"{}",
|
||||
tabled::Table::new(items).with(Style::modern()).to_string()
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn handle_route_list(&self) -> Result<(), Error> {
|
||||
#[derive(tabled::Tabled)]
|
||||
struct RouteTableItem {
|
||||
ipv4: String,
|
||||
hostname: String,
|
||||
proxy_cidrs: String,
|
||||
next_hop_ipv4: String,
|
||||
next_hop_hostname: String,
|
||||
next_hop_lat: f64,
|
||||
cost: i32,
|
||||
}
|
||||
|
||||
let mut items: Vec<RouteTableItem> = vec![];
|
||||
let peer_routes = self.list_peer_route_pair().await?;
|
||||
for p in peer_routes.iter() {
|
||||
let Some(next_hop_pair) = peer_routes
|
||||
.iter()
|
||||
.find(|pair| pair.route.peer_id == p.route.next_hop_peer_id)
|
||||
else {
|
||||
continue;
|
||||
};
|
||||
|
||||
if p.route.cost == 1 {
|
||||
items.push(RouteTableItem {
|
||||
ipv4: p.route.ipv4_addr.clone(),
|
||||
hostname: p.route.hostname.clone(),
|
||||
proxy_cidrs: p.route.proxy_cidrs.clone().join(",").to_string(),
|
||||
next_hop_ipv4: "DIRECT".to_string(),
|
||||
next_hop_hostname: "".to_string(),
|
||||
next_hop_lat: next_hop_pair.get_latency_ms().unwrap_or(0.0),
|
||||
cost: p.route.cost,
|
||||
});
|
||||
} else {
|
||||
items.push(RouteTableItem {
|
||||
ipv4: p.route.ipv4_addr.clone(),
|
||||
hostname: p.route.hostname.clone(),
|
||||
proxy_cidrs: p.route.proxy_cidrs.clone().join(",").to_string(),
|
||||
next_hop_ipv4: next_hop_pair.route.ipv4_addr.clone(),
|
||||
next_hop_hostname: next_hop_pair.route.hostname.clone(),
|
||||
next_hop_lat: next_hop_pair.get_latency_ms().unwrap_or(0.0),
|
||||
cost: p.route.cost,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
println!(
|
||||
"{}",
|
||||
tabled::Table::new(items).with(Style::modern()).to_string()
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn handle_connector_list(&self) -> Result<(), Error> {
|
||||
let mut client = self.get_connector_manager_client().await?;
|
||||
let request = tonic::Request::new(ListConnectorRequest::default());
|
||||
let response = client.list_connector(request).await?;
|
||||
println!("response: {:#?}", response.into_inner());
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
#[tracing::instrument]
|
||||
async fn main() -> Result<(), Error> {
|
||||
let cli = Cli::parse();
|
||||
let handler = CommandHandler {
|
||||
addr: format!("http://{}:{}", cli.rpc_portal.ip(), cli.rpc_portal.port()),
|
||||
};
|
||||
|
||||
match cli.sub_command {
|
||||
SubCommand::Peer(peer_args) => match &peer_args.sub_command {
|
||||
Some(PeerSubCommand::Add) => {
|
||||
println!("add peer");
|
||||
}
|
||||
Some(PeerSubCommand::Remove) => {
|
||||
println!("remove peer");
|
||||
}
|
||||
Some(PeerSubCommand::List(arg)) => {
|
||||
if arg.verbose {
|
||||
println!("{:#?}", handler.list_peer_route_pair().await?);
|
||||
} else {
|
||||
handler.handle_peer_list(&peer_args).await?;
|
||||
}
|
||||
}
|
||||
None => {
|
||||
handler.handle_peer_list(&peer_args).await?;
|
||||
}
|
||||
},
|
||||
SubCommand::Connector(conn_args) => match conn_args.sub_command {
|
||||
Some(ConnectorSubCommand::Add) => {
|
||||
println!("add connector");
|
||||
}
|
||||
Some(ConnectorSubCommand::Remove) => {
|
||||
println!("remove connector");
|
||||
}
|
||||
Some(ConnectorSubCommand::List) => {
|
||||
handler.handle_connector_list().await?;
|
||||
}
|
||||
None => {
|
||||
handler.handle_connector_list().await?;
|
||||
}
|
||||
},
|
||||
SubCommand::Route => {
|
||||
handler.handle_route_list().await?;
|
||||
}
|
||||
SubCommand::Stun => {
|
||||
let stun = UdpNatTypeDetector::new(StunInfoCollector::get_default_servers());
|
||||
println!("udp type: {:?}", stun.get_udp_nat_type(0).await);
|
||||
}
|
||||
SubCommand::PeerCenter => {
|
||||
let mut peer_center_client = handler.get_peer_center_client().await?;
|
||||
let resp = peer_center_client
|
||||
.get_global_peer_map(GetGlobalPeerMapRequest::default())
|
||||
.await?
|
||||
.into_inner();
|
||||
|
||||
#[derive(tabled::Tabled)]
|
||||
struct PeerCenterTableItem {
|
||||
node_id: String,
|
||||
direct_peers: String,
|
||||
}
|
||||
|
||||
let mut table_rows = vec![];
|
||||
for (k, v) in resp.global_peer_map.iter() {
|
||||
let node_id = k;
|
||||
let direct_peers = v
|
||||
.direct_peers
|
||||
.iter()
|
||||
.map(|(k, v)| {
|
||||
format!(
|
||||
"{}:{:?}",
|
||||
k,
|
||||
LatencyLevel::try_from(v.latency_level).unwrap()
|
||||
)
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
table_rows.push(PeerCenterTableItem {
|
||||
node_id: node_id.to_string(),
|
||||
direct_peers: direct_peers.join("\n"),
|
||||
});
|
||||
}
|
||||
|
||||
println!(
|
||||
"{}",
|
||||
tabled::Table::new(table_rows)
|
||||
.with(Style::modern())
|
||||
.to_string()
|
||||
);
|
||||
}
|
||||
SubCommand::VpnPortal => {
|
||||
let mut vpn_portal_client = handler.get_vpn_portal_client().await?;
|
||||
let resp = vpn_portal_client
|
||||
.get_vpn_portal_info(GetVpnPortalInfoRequest::default())
|
||||
.await?
|
||||
.into_inner()
|
||||
.vpn_portal_info
|
||||
.unwrap_or_default();
|
||||
println!("portal_name: {}\n", resp.vpn_type);
|
||||
println!("client_config:{}", resp.client_config);
|
||||
println!("connected_clients:\n{:#?}", resp.connected_clients);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -0,0 +1,428 @@
|
||||
#![allow(dead_code)]
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
|
||||
use std::{backtrace, io::Write as _, net::SocketAddr};
|
||||
|
||||
use anyhow::Context;
|
||||
use clap::Parser;
|
||||
|
||||
mod arch;
|
||||
mod common;
|
||||
mod connector;
|
||||
mod gateway;
|
||||
mod instance;
|
||||
mod peer_center;
|
||||
mod peers;
|
||||
mod rpc;
|
||||
mod tunnels;
|
||||
mod vpn_portal;
|
||||
|
||||
use common::{
|
||||
config::{ConsoleLoggerConfig, FileLoggerConfig, NetworkIdentity, PeerConfig, VpnPortalConfig},
|
||||
get_logger_timer_rfc3339,
|
||||
};
|
||||
use instance::instance::Instance;
|
||||
use tracing::level_filters::LevelFilter;
|
||||
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilter, Layer};
|
||||
|
||||
use crate::common::{
|
||||
config::{ConfigLoader, TomlConfigLoader},
|
||||
global_ctx::GlobalCtxEvent,
|
||||
};
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(author, version, about, long_about = None)]
|
||||
struct Cli {
|
||||
#[arg(
|
||||
long,
|
||||
help = "network name to identify this vpn network",
|
||||
default_value = "default"
|
||||
)]
|
||||
network_name: String,
|
||||
#[arg(
|
||||
long,
|
||||
help = "network secret to verify this node belongs to the vpn network",
|
||||
default_value = ""
|
||||
)]
|
||||
network_secret: String,
|
||||
|
||||
#[arg(short, long, help = "ipv4 address of this vpn node")]
|
||||
ipv4: Option<String>,
|
||||
|
||||
#[arg(short, long, help = "peers to connect initially")]
|
||||
peers: Vec<String>,
|
||||
|
||||
#[arg(short, long, help = "use a public shared node to discover peers")]
|
||||
external_node: Option<String>,
|
||||
|
||||
#[arg(
|
||||
short = 'n',
|
||||
long,
|
||||
help = "export local networks to other peers in the vpn"
|
||||
)]
|
||||
proxy_networks: Vec<String>,
|
||||
|
||||
#[arg(
|
||||
short,
|
||||
long,
|
||||
default_value = "127.0.0.1:15888",
|
||||
help = "rpc portal address to listen for management"
|
||||
)]
|
||||
rpc_portal: SocketAddr,
|
||||
|
||||
#[arg(short, long, help = "listeners to accept connections, pass '' to avoid listening.",
|
||||
default_values_t = ["tcp://0.0.0.0:11010".to_string(),
|
||||
"udp://0.0.0.0:11010".to_string(),
|
||||
"wg://0.0.0.0:11011".to_string()])]
|
||||
listeners: Vec<String>,
|
||||
|
||||
/// specify the linux network namespace, default is the root namespace
|
||||
#[arg(long)]
|
||||
net_ns: Option<String>,
|
||||
|
||||
#[arg(long, help = "console log level",
|
||||
value_parser = clap::builder::PossibleValuesParser::new(["trace", "debug", "info", "warn", "error", "off"]))]
|
||||
console_log_level: Option<String>,
|
||||
|
||||
#[arg(long, help = "file log level",
|
||||
value_parser = clap::builder::PossibleValuesParser::new(["trace", "debug", "info", "warn", "error", "off"]))]
|
||||
file_log_level: Option<String>,
|
||||
#[arg(long, help = "directory to store log files")]
|
||||
file_log_dir: Option<String>,
|
||||
|
||||
#[arg(
|
||||
short = 'm',
|
||||
long,
|
||||
default_value = "default",
|
||||
help = "instance name to identify this vpn node in same machine"
|
||||
)]
|
||||
instance_name: String,
|
||||
|
||||
#[arg(
|
||||
short = 'd',
|
||||
long,
|
||||
help = "instance uuid to identify this vpn node in whole vpn network example: 123e4567-e89b-12d3-a456-426614174000"
|
||||
)]
|
||||
instance_id: Option<String>,
|
||||
|
||||
#[arg(
|
||||
long,
|
||||
help = "url that defines the vpn portal, allow other vpn clients to connect.
|
||||
example: wg://0.0.0.0:11010/10.14.14.0/24, means the vpn portal is a wireguard server listening on vpn.example.com:11010,
|
||||
and the vpn client is in network of 10.14.14.0/24"
|
||||
)]
|
||||
vpn_portal: Option<String>,
|
||||
|
||||
#[arg(long, help = "default protocol to use when connecting to peers")]
|
||||
default_protocol: Option<String>,
|
||||
}
|
||||
|
||||
impl From<Cli> for TomlConfigLoader {
|
||||
fn from(cli: Cli) -> Self {
|
||||
let cfg = TomlConfigLoader::default();
|
||||
cfg.set_inst_name(cli.instance_name.clone());
|
||||
cfg.set_network_identity(NetworkIdentity {
|
||||
network_name: cli.network_name.clone(),
|
||||
network_secret: cli.network_secret.clone(),
|
||||
});
|
||||
|
||||
cfg.set_netns(cli.net_ns.clone());
|
||||
if let Some(ipv4) = &cli.ipv4 {
|
||||
cfg.set_ipv4(
|
||||
ipv4.parse()
|
||||
.with_context(|| format!("failed to parse ipv4 address: {}", ipv4))
|
||||
.unwrap(),
|
||||
)
|
||||
}
|
||||
|
||||
cfg.set_peers(
|
||||
cli.peers
|
||||
.iter()
|
||||
.map(|s| PeerConfig {
|
||||
uri: s
|
||||
.parse()
|
||||
.with_context(|| format!("failed to parse peer uri: {}", s))
|
||||
.unwrap(),
|
||||
})
|
||||
.collect(),
|
||||
);
|
||||
|
||||
cfg.set_listeners(
|
||||
cli.listeners
|
||||
.iter()
|
||||
.filter_map(|s| {
|
||||
if s.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
Some(
|
||||
s.parse()
|
||||
.with_context(|| format!("failed to parse listener uri: {}", s))
|
||||
.unwrap(),
|
||||
)
|
||||
})
|
||||
.collect(),
|
||||
);
|
||||
|
||||
for n in cli.proxy_networks.iter() {
|
||||
cfg.add_proxy_cidr(
|
||||
n.parse()
|
||||
.with_context(|| format!("failed to parse proxy network: {}", n))
|
||||
.unwrap(),
|
||||
);
|
||||
}
|
||||
|
||||
cfg.set_rpc_portal(cli.rpc_portal);
|
||||
|
||||
if cli.external_node.is_some() {
|
||||
let mut old_peers = cfg.get_peers();
|
||||
old_peers.push(PeerConfig {
|
||||
uri: cli
|
||||
.external_node
|
||||
.clone()
|
||||
.unwrap()
|
||||
.parse()
|
||||
.with_context(|| {
|
||||
format!(
|
||||
"failed to parse external node uri: {}",
|
||||
cli.external_node.unwrap()
|
||||
)
|
||||
})
|
||||
.unwrap(),
|
||||
});
|
||||
cfg.set_peers(old_peers);
|
||||
}
|
||||
|
||||
if cli.console_log_level.is_some() {
|
||||
cfg.set_console_logger_config(ConsoleLoggerConfig {
|
||||
level: cli.console_log_level.clone(),
|
||||
});
|
||||
}
|
||||
|
||||
if cli.file_log_dir.is_some() || cli.file_log_level.is_some() {
|
||||
cfg.set_file_logger_config(FileLoggerConfig {
|
||||
level: cli.file_log_level.clone(),
|
||||
dir: cli.file_log_dir.clone(),
|
||||
file: Some(format!("easytier-{}", cli.instance_name)),
|
||||
});
|
||||
}
|
||||
|
||||
if cli.vpn_portal.is_some() {
|
||||
let url: url::Url = cli
|
||||
.vpn_portal
|
||||
.clone()
|
||||
.unwrap()
|
||||
.parse()
|
||||
.with_context(|| {
|
||||
format!(
|
||||
"failed to parse vpn portal url: {}",
|
||||
cli.vpn_portal.unwrap()
|
||||
)
|
||||
})
|
||||
.unwrap();
|
||||
cfg.set_vpn_portal_config(VpnPortalConfig {
|
||||
client_cidr: url.path()[1..]
|
||||
.parse()
|
||||
.with_context(|| {
|
||||
format!("failed to parse vpn portal client cidr: {}", url.path())
|
||||
})
|
||||
.unwrap(),
|
||||
wireguard_listen: format!("{}:{}", url.host_str().unwrap(), url.port().unwrap())
|
||||
.parse()
|
||||
.with_context(|| {
|
||||
format!(
|
||||
"failed to parse vpn portal wireguard listen address: {}",
|
||||
url.host_str().unwrap()
|
||||
)
|
||||
})
|
||||
.unwrap(),
|
||||
});
|
||||
}
|
||||
|
||||
if cli.default_protocol.is_some() {
|
||||
let mut f = cfg.get_flags();
|
||||
f.default_protocol = cli.default_protocol.as_ref().unwrap().clone();
|
||||
cfg.set_flags(f);
|
||||
}
|
||||
|
||||
cfg
|
||||
}
|
||||
}
|
||||
|
||||
fn init_logger(config: impl ConfigLoader) {
|
||||
let file_config = config.get_file_logger_config();
|
||||
let file_level = file_config
|
||||
.level
|
||||
.map(|s| s.parse().unwrap())
|
||||
.unwrap_or(LevelFilter::OFF);
|
||||
|
||||
// logger to rolling file
|
||||
let mut file_layer = None;
|
||||
if file_level != LevelFilter::OFF {
|
||||
let mut l = tracing_subscriber::fmt::layer();
|
||||
l.set_ansi(false);
|
||||
let file_filter = EnvFilter::builder()
|
||||
.with_default_directive(file_level.into())
|
||||
.from_env()
|
||||
.unwrap();
|
||||
let file_appender = tracing_appender::rolling::Builder::new()
|
||||
.rotation(tracing_appender::rolling::Rotation::DAILY)
|
||||
.max_log_files(5)
|
||||
.filename_prefix(file_config.file.unwrap_or("easytier".to_string()))
|
||||
.build(file_config.dir.unwrap_or("./".to_string()))
|
||||
.expect("failed to initialize rolling file appender");
|
||||
file_layer = Some(
|
||||
l.with_writer(file_appender)
|
||||
.with_timer(get_logger_timer_rfc3339())
|
||||
.with_filter(file_filter),
|
||||
);
|
||||
}
|
||||
|
||||
// logger to console
|
||||
let console_config = config.get_console_logger_config();
|
||||
let console_level = console_config
|
||||
.level
|
||||
.map(|s| s.parse().unwrap())
|
||||
.unwrap_or(LevelFilter::OFF);
|
||||
|
||||
let console_filter = EnvFilter::builder()
|
||||
.with_default_directive(console_level.into())
|
||||
.from_env()
|
||||
.unwrap();
|
||||
let console_layer = tracing_subscriber::fmt::layer()
|
||||
.pretty()
|
||||
.with_timer(get_logger_timer_rfc3339())
|
||||
.with_writer(std::io::stderr)
|
||||
.with_filter(console_filter);
|
||||
|
||||
tracing_subscriber::Registry::default()
|
||||
.with(console_layer)
|
||||
.with(file_layer)
|
||||
.init();
|
||||
}
|
||||
|
||||
fn print_event(msg: String) {
|
||||
println!(
|
||||
"{}: {}",
|
||||
chrono::Local::now().format("%Y-%m-%d %H:%M:%S"),
|
||||
msg
|
||||
);
|
||||
}
|
||||
|
||||
fn peer_conn_info_to_string(p: crate::rpc::PeerConnInfo) -> String {
|
||||
format!(
|
||||
"my_peer_id: {}, dst_peer_id: {}, tunnel_info: {:?}",
|
||||
p.my_peer_id, p.peer_id, p.tunnel
|
||||
)
|
||||
}
|
||||
|
||||
fn setup_panic_handler() {
|
||||
std::panic::set_hook(Box::new(|info| {
|
||||
let backtrace = backtrace::Backtrace::force_capture();
|
||||
println!("panic occurred: {:?}", info);
|
||||
let _ = std::fs::File::create("easytier-panic.log")
|
||||
.and_then(|mut f| f.write_all(format!("{:?}\n{:#?}", info, backtrace).as_bytes()));
|
||||
std::process::exit(1);
|
||||
}));
|
||||
}
|
||||
|
||||
#[tokio::main(flavor = "current_thread")]
|
||||
#[tracing::instrument]
|
||||
pub async fn main() {
|
||||
setup_panic_handler();
|
||||
|
||||
let cli = Cli::parse();
|
||||
tracing::info!(cli = ?cli, "cli args parsed");
|
||||
|
||||
let cfg: TomlConfigLoader = cli.into();
|
||||
|
||||
init_logger(&cfg);
|
||||
let mut inst = Instance::new(cfg.clone());
|
||||
|
||||
let mut events = inst.get_global_ctx().subscribe();
|
||||
tokio::spawn(async move {
|
||||
while let Ok(e) = events.recv().await {
|
||||
match e {
|
||||
GlobalCtxEvent::PeerAdded(p) => {
|
||||
print_event(format!("new peer added. peer_id: {}", p));
|
||||
}
|
||||
|
||||
GlobalCtxEvent::PeerRemoved(p) => {
|
||||
print_event(format!("peer removed. peer_id: {}", p));
|
||||
}
|
||||
|
||||
GlobalCtxEvent::PeerConnAdded(p) => {
|
||||
print_event(format!(
|
||||
"new peer connection added. conn_info: {}",
|
||||
peer_conn_info_to_string(p)
|
||||
));
|
||||
}
|
||||
|
||||
GlobalCtxEvent::PeerConnRemoved(p) => {
|
||||
print_event(format!(
|
||||
"peer connection removed. conn_info: {}",
|
||||
peer_conn_info_to_string(p)
|
||||
));
|
||||
}
|
||||
|
||||
GlobalCtxEvent::ListenerAdded(p) => {
|
||||
if p.scheme() == "ring" {
|
||||
continue;
|
||||
}
|
||||
print_event(format!("new listener added. listener: {}", p));
|
||||
}
|
||||
|
||||
GlobalCtxEvent::ConnectionAccepted(local, remote) => {
|
||||
print_event(format!(
|
||||
"new connection accepted. local: {}, remote: {}",
|
||||
local, remote
|
||||
));
|
||||
}
|
||||
|
||||
GlobalCtxEvent::ConnectionError(local, remote, err) => {
|
||||
print_event(format!(
|
||||
"connection error. local: {}, remote: {}, err: {}",
|
||||
local, remote, err
|
||||
));
|
||||
}
|
||||
|
||||
GlobalCtxEvent::TunDeviceReady(dev) => {
|
||||
print_event(format!("tun device ready. dev: {}", dev));
|
||||
}
|
||||
|
||||
GlobalCtxEvent::Connecting(dst) => {
|
||||
print_event(format!("connecting to peer. dst: {}", dst));
|
||||
}
|
||||
|
||||
GlobalCtxEvent::ConnectError(dst, err) => {
|
||||
print_event(format!("connect to peer error. dst: {}, err: {}", dst, err));
|
||||
}
|
||||
|
||||
GlobalCtxEvent::VpnPortalClientConnected(portal, client_addr) => {
|
||||
print_event(format!(
|
||||
"vpn portal client connected. portal: {}, client_addr: {}",
|
||||
portal, client_addr
|
||||
));
|
||||
}
|
||||
|
||||
GlobalCtxEvent::VpnPortalClientDisconnected(portal, client_addr) => {
|
||||
print_event(format!(
|
||||
"vpn portal client disconnected. portal: {}, client_addr: {}",
|
||||
portal, client_addr
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
println!("Starting easytier with config:");
|
||||
println!("############### TOML ##############\n");
|
||||
println!("{}", cfg.dump());
|
||||
println!("-----------------------------------");
|
||||
|
||||
inst.run().await.unwrap();
|
||||
|
||||
inst.wait().await;
|
||||
}
|
||||
@@ -0,0 +1,293 @@
|
||||
use std::{
|
||||
mem::MaybeUninit,
|
||||
net::{IpAddr, Ipv4Addr, SocketAddrV4},
|
||||
sync::Arc,
|
||||
thread,
|
||||
};
|
||||
|
||||
use pnet::packet::{
|
||||
icmp::{self, IcmpTypes},
|
||||
ip::IpNextHeaderProtocols,
|
||||
ipv4::{self, Ipv4Packet, MutableIpv4Packet},
|
||||
Packet,
|
||||
};
|
||||
use socket2::Socket;
|
||||
use tokio::{
|
||||
sync::{mpsc::UnboundedSender, Mutex},
|
||||
task::JoinSet,
|
||||
};
|
||||
use tokio_util::bytes::Bytes;
|
||||
use tracing::Instrument;
|
||||
|
||||
use crate::{
|
||||
common::{error::Error, global_ctx::ArcGlobalCtx, PeerId},
|
||||
peers::{packet, peer_manager::PeerManager, PeerPacketFilter},
|
||||
};
|
||||
|
||||
use super::CidrSet;
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||
struct IcmpNatKey {
|
||||
dst_ip: std::net::IpAddr,
|
||||
icmp_id: u16,
|
||||
icmp_seq: u16,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct IcmpNatEntry {
|
||||
src_peer_id: PeerId,
|
||||
my_peer_id: PeerId,
|
||||
src_ip: IpAddr,
|
||||
start_time: std::time::Instant,
|
||||
}
|
||||
|
||||
impl IcmpNatEntry {
|
||||
fn new(src_peer_id: PeerId, my_peer_id: PeerId, src_ip: IpAddr) -> Result<Self, Error> {
|
||||
Ok(Self {
|
||||
src_peer_id,
|
||||
my_peer_id,
|
||||
src_ip,
|
||||
start_time: std::time::Instant::now(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type IcmpNatTable = Arc<dashmap::DashMap<IcmpNatKey, IcmpNatEntry>>;
|
||||
type NewPacketSender = tokio::sync::mpsc::UnboundedSender<IcmpNatKey>;
|
||||
type NewPacketReceiver = tokio::sync::mpsc::UnboundedReceiver<IcmpNatKey>;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct IcmpProxy {
|
||||
global_ctx: ArcGlobalCtx,
|
||||
peer_manager: Arc<PeerManager>,
|
||||
|
||||
cidr_set: CidrSet,
|
||||
socket: socket2::Socket,
|
||||
|
||||
nat_table: IcmpNatTable,
|
||||
|
||||
tasks: Mutex<JoinSet<()>>,
|
||||
}
|
||||
|
||||
fn socket_recv(socket: &Socket, buf: &mut [MaybeUninit<u8>]) -> Result<(usize, IpAddr), Error> {
|
||||
let (size, addr) = socket.recv_from(buf)?;
|
||||
let addr = match addr.as_socket() {
|
||||
None => IpAddr::V4(Ipv4Addr::UNSPECIFIED),
|
||||
Some(add) => add.ip(),
|
||||
};
|
||||
Ok((size, addr))
|
||||
}
|
||||
|
||||
fn socket_recv_loop(
|
||||
socket: Socket,
|
||||
nat_table: IcmpNatTable,
|
||||
sender: UnboundedSender<packet::Packet>,
|
||||
) {
|
||||
let mut buf = [0u8; 4096];
|
||||
let data: &mut [MaybeUninit<u8>] = unsafe { std::mem::transmute(&mut buf[12..]) };
|
||||
|
||||
loop {
|
||||
let Ok((len, peer_ip)) = socket_recv(&socket, data) else {
|
||||
continue;
|
||||
};
|
||||
|
||||
if !peer_ip.is_ipv4() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let Some(mut ipv4_packet) = MutableIpv4Packet::new(&mut buf[12..12 + len]) else {
|
||||
continue;
|
||||
};
|
||||
|
||||
let Some(icmp_packet) = icmp::echo_reply::EchoReplyPacket::new(ipv4_packet.payload())
|
||||
else {
|
||||
continue;
|
||||
};
|
||||
|
||||
if icmp_packet.get_icmp_type() != IcmpTypes::EchoReply {
|
||||
continue;
|
||||
}
|
||||
|
||||
let key = IcmpNatKey {
|
||||
dst_ip: peer_ip,
|
||||
icmp_id: icmp_packet.get_identifier(),
|
||||
icmp_seq: icmp_packet.get_sequence_number(),
|
||||
};
|
||||
|
||||
let Some((_, v)) = nat_table.remove(&key) else {
|
||||
continue;
|
||||
};
|
||||
|
||||
// send packet back to the peer where this request origin.
|
||||
let IpAddr::V4(dest_ip) = v.src_ip else {
|
||||
continue;
|
||||
};
|
||||
|
||||
ipv4_packet.set_destination(dest_ip);
|
||||
ipv4_packet.set_checksum(ipv4::checksum(&ipv4_packet.to_immutable()));
|
||||
|
||||
let peer_packet = packet::Packet::new_data_packet(
|
||||
v.my_peer_id,
|
||||
v.src_peer_id,
|
||||
&ipv4_packet.to_immutable().packet(),
|
||||
);
|
||||
|
||||
if let Err(e) = sender.send(peer_packet) {
|
||||
tracing::error!("send icmp packet to peer failed: {:?}, may exiting..", e);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl PeerPacketFilter for IcmpProxy {
|
||||
async fn try_process_packet_from_peer(
|
||||
&self,
|
||||
packet: &packet::ArchivedPacket,
|
||||
_: &Bytes,
|
||||
) -> Option<()> {
|
||||
let _ = self.global_ctx.get_ipv4()?;
|
||||
|
||||
if packet.packet_type != packet::PacketType::Data {
|
||||
return None;
|
||||
};
|
||||
|
||||
let ipv4 = Ipv4Packet::new(&packet.payload.as_bytes())?;
|
||||
|
||||
if ipv4.get_version() != 4 || ipv4.get_next_level_protocol() != IpNextHeaderProtocols::Icmp
|
||||
{
|
||||
return None;
|
||||
}
|
||||
|
||||
if !self.cidr_set.contains_v4(ipv4.get_destination()) {
|
||||
return None;
|
||||
}
|
||||
|
||||
let icmp_packet = icmp::echo_request::EchoRequestPacket::new(&ipv4.payload())?;
|
||||
|
||||
if icmp_packet.get_icmp_type() != IcmpTypes::EchoRequest {
|
||||
// drop it because we do not support other icmp types
|
||||
tracing::trace!("unsupported icmp type: {:?}", icmp_packet.get_icmp_type());
|
||||
return Some(());
|
||||
}
|
||||
|
||||
let icmp_id = icmp_packet.get_identifier();
|
||||
let icmp_seq = icmp_packet.get_sequence_number();
|
||||
|
||||
let key = IcmpNatKey {
|
||||
dst_ip: ipv4.get_destination().into(),
|
||||
icmp_id,
|
||||
icmp_seq,
|
||||
};
|
||||
|
||||
let value = IcmpNatEntry::new(
|
||||
packet.from_peer.into(),
|
||||
packet.to_peer.into(),
|
||||
ipv4.get_source().into(),
|
||||
)
|
||||
.ok()?;
|
||||
|
||||
if let Some(old) = self.nat_table.insert(key, value) {
|
||||
tracing::info!("icmp nat table entry replaced: {:?}", old);
|
||||
}
|
||||
|
||||
if let Err(e) = self.send_icmp_packet(ipv4.get_destination(), &icmp_packet) {
|
||||
tracing::error!("send icmp packet failed: {:?}", e);
|
||||
}
|
||||
|
||||
Some(())
|
||||
}
|
||||
}
|
||||
|
||||
impl IcmpProxy {
|
||||
pub fn new(
|
||||
global_ctx: ArcGlobalCtx,
|
||||
peer_manager: Arc<PeerManager>,
|
||||
) -> Result<Arc<Self>, Error> {
|
||||
let cidr_set = CidrSet::new(global_ctx.clone());
|
||||
|
||||
let _g = global_ctx.net_ns.guard();
|
||||
let socket = socket2::Socket::new(
|
||||
socket2::Domain::IPV4,
|
||||
socket2::Type::RAW,
|
||||
Some(socket2::Protocol::ICMPV4),
|
||||
)?;
|
||||
socket.bind(&socket2::SockAddr::from(SocketAddrV4::new(
|
||||
std::net::Ipv4Addr::UNSPECIFIED,
|
||||
0,
|
||||
)))?;
|
||||
|
||||
let ret = Self {
|
||||
global_ctx,
|
||||
peer_manager,
|
||||
cidr_set,
|
||||
socket,
|
||||
|
||||
nat_table: Arc::new(dashmap::DashMap::new()),
|
||||
tasks: Mutex::new(JoinSet::new()),
|
||||
};
|
||||
|
||||
Ok(Arc::new(ret))
|
||||
}
|
||||
|
||||
pub async fn start(self: &Arc<Self>) -> Result<(), Error> {
|
||||
self.start_icmp_proxy().await?;
|
||||
self.start_nat_table_cleaner().await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn start_nat_table_cleaner(self: &Arc<Self>) -> Result<(), Error> {
|
||||
let nat_table = self.nat_table.clone();
|
||||
self.tasks.lock().await.spawn(
|
||||
async move {
|
||||
loop {
|
||||
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
|
||||
nat_table.retain(|_, v| v.start_time.elapsed().as_secs() < 20);
|
||||
}
|
||||
}
|
||||
.instrument(tracing::info_span!("icmp proxy nat table cleaner")),
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn start_icmp_proxy(self: &Arc<Self>) -> Result<(), Error> {
|
||||
let socket = self.socket.try_clone()?;
|
||||
let (sender, mut receiver) = tokio::sync::mpsc::unbounded_channel();
|
||||
let nat_table = self.nat_table.clone();
|
||||
thread::spawn(|| {
|
||||
socket_recv_loop(socket, nat_table, sender);
|
||||
});
|
||||
|
||||
let peer_manager = self.peer_manager.clone();
|
||||
self.tasks.lock().await.spawn(
|
||||
async move {
|
||||
while let Some(msg) = receiver.recv().await {
|
||||
let to_peer_id = msg.to_peer.into();
|
||||
let ret = peer_manager.send_msg(msg.into(), to_peer_id).await;
|
||||
if ret.is_err() {
|
||||
tracing::error!("send icmp packet to peer failed: {:?}", ret);
|
||||
}
|
||||
}
|
||||
}
|
||||
.instrument(tracing::info_span!("icmp proxy send loop")),
|
||||
);
|
||||
|
||||
self.peer_manager
|
||||
.add_packet_process_pipeline(Box::new(self.clone()))
|
||||
.await;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn send_icmp_packet(
|
||||
&self,
|
||||
dst_ip: Ipv4Addr,
|
||||
icmp_packet: &icmp::echo_request::EchoRequestPacket,
|
||||
) -> Result<(), Error> {
|
||||
self.socket.send_to(
|
||||
icmp_packet.packet(),
|
||||
&SocketAddrV4::new(dst_ip.into(), 0).into(),
|
||||
)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,56 @@
|
||||
use dashmap::DashSet;
|
||||
use std::sync::Arc;
|
||||
use tokio::task::JoinSet;
|
||||
|
||||
use crate::common::global_ctx::ArcGlobalCtx;
|
||||
|
||||
pub mod icmp_proxy;
|
||||
pub mod tcp_proxy;
|
||||
pub mod udp_proxy;
|
||||
|
||||
#[derive(Debug)]
|
||||
struct CidrSet {
|
||||
global_ctx: ArcGlobalCtx,
|
||||
cidr_set: Arc<DashSet<cidr::IpCidr>>,
|
||||
tasks: JoinSet<()>,
|
||||
}
|
||||
|
||||
impl CidrSet {
|
||||
pub fn new(global_ctx: ArcGlobalCtx) -> Self {
|
||||
let mut ret = Self {
|
||||
global_ctx,
|
||||
cidr_set: Arc::new(DashSet::new()),
|
||||
tasks: JoinSet::new(),
|
||||
};
|
||||
ret.run_cidr_updater();
|
||||
ret
|
||||
}
|
||||
|
||||
fn run_cidr_updater(&mut self) {
|
||||
let global_ctx = self.global_ctx.clone();
|
||||
let cidr_set = self.cidr_set.clone();
|
||||
self.tasks.spawn(async move {
|
||||
let mut last_cidrs = vec![];
|
||||
loop {
|
||||
let cidrs = global_ctx.get_proxy_cidrs();
|
||||
if cidrs != last_cidrs {
|
||||
last_cidrs = cidrs.clone();
|
||||
cidr_set.clear();
|
||||
for cidr in cidrs.iter() {
|
||||
cidr_set.insert(cidr.clone());
|
||||
}
|
||||
}
|
||||
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
pub fn contains_v4(&self, ip: std::net::Ipv4Addr) -> bool {
|
||||
let ip = ip.into();
|
||||
return self.cidr_set.iter().any(|cidr| cidr.contains(&ip));
|
||||
}
|
||||
|
||||
pub fn is_empty(&self) -> bool {
|
||||
return self.cidr_set.is_empty();
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,407 @@
|
||||
use crossbeam::atomic::AtomicCell;
|
||||
use dashmap::DashMap;
|
||||
use pnet::packet::ip::IpNextHeaderProtocols;
|
||||
use pnet::packet::ipv4::{Ipv4Packet, MutableIpv4Packet};
|
||||
use pnet::packet::tcp::{ipv4_checksum, MutableTcpPacket};
|
||||
use std::net::{IpAddr, Ipv4Addr, SocketAddr, SocketAddrV4};
|
||||
use std::sync::atomic::AtomicU16;
|
||||
use std::sync::Arc;
|
||||
use std::time::{Duration, Instant};
|
||||
use tokio::io::copy_bidirectional;
|
||||
use tokio::net::{TcpListener, TcpSocket, TcpStream};
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::task::JoinSet;
|
||||
use tokio_util::bytes::{Bytes, BytesMut};
|
||||
use tracing::Instrument;
|
||||
|
||||
use crate::common::error::Result;
|
||||
use crate::common::global_ctx::GlobalCtx;
|
||||
use crate::common::join_joinset_background;
|
||||
use crate::common::netns::NetNS;
|
||||
use crate::peers::packet::{self, ArchivedPacket};
|
||||
use crate::peers::peer_manager::PeerManager;
|
||||
use crate::peers::{NicPacketFilter, PeerPacketFilter};
|
||||
|
||||
use super::CidrSet;
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq)]
|
||||
enum NatDstEntryState {
|
||||
// receive syn packet but not start connecting to dst
|
||||
SynReceived,
|
||||
// connecting to dst
|
||||
ConnectingDst,
|
||||
// connected to dst
|
||||
Connected,
|
||||
// connection closed
|
||||
Closed,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct NatDstEntry {
|
||||
id: uuid::Uuid,
|
||||
src: SocketAddr,
|
||||
dst: SocketAddr,
|
||||
start_time: Instant,
|
||||
tasks: Mutex<JoinSet<()>>,
|
||||
state: AtomicCell<NatDstEntryState>,
|
||||
}
|
||||
|
||||
impl NatDstEntry {
|
||||
pub fn new(src: SocketAddr, dst: SocketAddr) -> Self {
|
||||
Self {
|
||||
id: uuid::Uuid::new_v4(),
|
||||
src,
|
||||
dst,
|
||||
start_time: Instant::now(),
|
||||
tasks: Mutex::new(JoinSet::new()),
|
||||
state: AtomicCell::new(NatDstEntryState::SynReceived),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type ArcNatDstEntry = Arc<NatDstEntry>;
|
||||
|
||||
type SynSockMap = Arc<DashMap<SocketAddr, ArcNatDstEntry>>;
|
||||
type ConnSockMap = Arc<DashMap<uuid::Uuid, ArcNatDstEntry>>;
|
||||
// peer src addr to nat entry, when respond tcp packet, should modify the tcp src addr to the nat entry's dst addr
|
||||
type AddrConnSockMap = Arc<DashMap<SocketAddr, ArcNatDstEntry>>;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct TcpProxy {
|
||||
global_ctx: Arc<GlobalCtx>,
|
||||
peer_manager: Arc<PeerManager>,
|
||||
local_port: AtomicU16,
|
||||
|
||||
tasks: Arc<std::sync::Mutex<JoinSet<()>>>,
|
||||
|
||||
syn_map: SynSockMap,
|
||||
conn_map: ConnSockMap,
|
||||
addr_conn_map: AddrConnSockMap,
|
||||
|
||||
cidr_set: CidrSet,
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl PeerPacketFilter for TcpProxy {
|
||||
async fn try_process_packet_from_peer(&self, packet: &ArchivedPacket, _: &Bytes) -> Option<()> {
|
||||
let ipv4_addr = self.global_ctx.get_ipv4()?;
|
||||
|
||||
if packet.packet_type != packet::PacketType::Data {
|
||||
return None;
|
||||
};
|
||||
|
||||
let payload_bytes = packet.payload.as_bytes();
|
||||
|
||||
let ipv4 = Ipv4Packet::new(payload_bytes)?;
|
||||
if ipv4.get_version() != 4 || ipv4.get_next_level_protocol() != IpNextHeaderProtocols::Tcp {
|
||||
return None;
|
||||
}
|
||||
|
||||
if !self.cidr_set.contains_v4(ipv4.get_destination()) {
|
||||
return None;
|
||||
}
|
||||
|
||||
tracing::trace!(ipv4 = ?ipv4, cidr_set = ?self.cidr_set, "proxy tcp packet received");
|
||||
|
||||
let mut packet_buffer = BytesMut::with_capacity(payload_bytes.len());
|
||||
packet_buffer.extend_from_slice(&payload_bytes.to_vec());
|
||||
|
||||
let (ip_buffer, tcp_buffer) =
|
||||
packet_buffer.split_at_mut(ipv4.get_header_length() as usize * 4);
|
||||
|
||||
let mut ip_packet = MutableIpv4Packet::new(ip_buffer).unwrap();
|
||||
let mut tcp_packet = MutableTcpPacket::new(tcp_buffer).unwrap();
|
||||
|
||||
let is_tcp_syn = tcp_packet.get_flags() & pnet::packet::tcp::TcpFlags::SYN != 0;
|
||||
if is_tcp_syn {
|
||||
let source_ip = ip_packet.get_source();
|
||||
let source_port = tcp_packet.get_source();
|
||||
let src = SocketAddr::V4(SocketAddrV4::new(source_ip, source_port));
|
||||
|
||||
let dest_ip = ip_packet.get_destination();
|
||||
let dest_port = tcp_packet.get_destination();
|
||||
let dst = SocketAddr::V4(SocketAddrV4::new(dest_ip, dest_port));
|
||||
|
||||
let old_val = self
|
||||
.syn_map
|
||||
.insert(src, Arc::new(NatDstEntry::new(src, dst)));
|
||||
tracing::trace!(src = ?src, dst = ?dst, old_entry = ?old_val, "tcp syn received");
|
||||
}
|
||||
|
||||
ip_packet.set_destination(ipv4_addr);
|
||||
tcp_packet.set_destination(self.get_local_port());
|
||||
Self::update_ipv4_packet_checksum(&mut ip_packet, &mut tcp_packet);
|
||||
|
||||
tracing::trace!(ip_packet = ?ip_packet, tcp_packet = ?tcp_packet, "tcp packet forwarded");
|
||||
|
||||
if let Err(e) = self
|
||||
.peer_manager
|
||||
.get_nic_channel()
|
||||
.send(packet_buffer.freeze())
|
||||
.await
|
||||
{
|
||||
tracing::error!("send to nic failed: {:?}", e);
|
||||
}
|
||||
|
||||
Some(())
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl NicPacketFilter for TcpProxy {
|
||||
async fn try_process_packet_from_nic(&self, mut data: BytesMut) -> BytesMut {
|
||||
let Some(my_ipv4) = self.global_ctx.get_ipv4() else {
|
||||
return data;
|
||||
};
|
||||
|
||||
let header_len = {
|
||||
let Some(ipv4) = &Ipv4Packet::new(&data[..]) else {
|
||||
return data;
|
||||
};
|
||||
|
||||
if ipv4.get_version() != 4
|
||||
|| ipv4.get_source() != my_ipv4
|
||||
|| ipv4.get_next_level_protocol() != IpNextHeaderProtocols::Tcp
|
||||
{
|
||||
return data;
|
||||
}
|
||||
|
||||
ipv4.get_header_length() as usize * 4
|
||||
};
|
||||
|
||||
let (ip_buffer, tcp_buffer) = data.split_at_mut(header_len);
|
||||
let mut ip_packet = MutableIpv4Packet::new(ip_buffer).unwrap();
|
||||
let mut tcp_packet = MutableTcpPacket::new(tcp_buffer).unwrap();
|
||||
|
||||
if tcp_packet.get_source() != self.get_local_port() {
|
||||
return data;
|
||||
}
|
||||
|
||||
let dst_addr = SocketAddr::V4(SocketAddrV4::new(
|
||||
ip_packet.get_destination(),
|
||||
tcp_packet.get_destination(),
|
||||
));
|
||||
|
||||
tracing::trace!(dst_addr = ?dst_addr, "tcp packet try find entry");
|
||||
let entry = if let Some(entry) = self.addr_conn_map.get(&dst_addr) {
|
||||
entry
|
||||
} else {
|
||||
let Some(syn_entry) = self.syn_map.get(&dst_addr) else {
|
||||
return data;
|
||||
};
|
||||
syn_entry
|
||||
};
|
||||
let nat_entry = entry.clone();
|
||||
drop(entry);
|
||||
assert_eq!(nat_entry.src, dst_addr);
|
||||
|
||||
let IpAddr::V4(ip) = nat_entry.dst.ip() else {
|
||||
panic!("v4 nat entry src ip is not v4");
|
||||
};
|
||||
|
||||
ip_packet.set_source(ip);
|
||||
tcp_packet.set_source(nat_entry.dst.port());
|
||||
Self::update_ipv4_packet_checksum(&mut ip_packet, &mut tcp_packet);
|
||||
|
||||
tracing::trace!(dst_addr = ?dst_addr, nat_entry = ?nat_entry, packet = ?ip_packet, "tcp packet after modified");
|
||||
|
||||
data
|
||||
}
|
||||
}
|
||||
|
||||
impl TcpProxy {
|
||||
pub fn new(global_ctx: Arc<GlobalCtx>, peer_manager: Arc<PeerManager>) -> Arc<Self> {
|
||||
Arc::new(Self {
|
||||
global_ctx: global_ctx.clone(),
|
||||
peer_manager,
|
||||
|
||||
local_port: AtomicU16::new(0),
|
||||
tasks: Arc::new(std::sync::Mutex::new(JoinSet::new())),
|
||||
|
||||
syn_map: Arc::new(DashMap::new()),
|
||||
conn_map: Arc::new(DashMap::new()),
|
||||
addr_conn_map: Arc::new(DashMap::new()),
|
||||
|
||||
cidr_set: CidrSet::new(global_ctx),
|
||||
})
|
||||
}
|
||||
|
||||
fn update_ipv4_packet_checksum(
|
||||
ipv4_packet: &mut MutableIpv4Packet,
|
||||
tcp_packet: &mut MutableTcpPacket,
|
||||
) {
|
||||
tcp_packet.set_checksum(ipv4_checksum(
|
||||
&tcp_packet.to_immutable(),
|
||||
&ipv4_packet.get_source(),
|
||||
&ipv4_packet.get_destination(),
|
||||
));
|
||||
|
||||
ipv4_packet.set_checksum(pnet::packet::ipv4::checksum(&ipv4_packet.to_immutable()));
|
||||
}
|
||||
|
||||
pub async fn start(self: &Arc<Self>) -> Result<()> {
|
||||
self.run_syn_map_cleaner().await?;
|
||||
self.run_listener().await?;
|
||||
self.peer_manager
|
||||
.add_packet_process_pipeline(Box::new(self.clone()))
|
||||
.await;
|
||||
self.peer_manager
|
||||
.add_nic_packet_process_pipeline(Box::new(self.clone()))
|
||||
.await;
|
||||
join_joinset_background(self.tasks.clone(), "TcpProxy".to_owned());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn run_syn_map_cleaner(&self) -> Result<()> {
|
||||
let syn_map = self.syn_map.clone();
|
||||
let tasks = self.tasks.clone();
|
||||
let syn_map_cleaner_task = async move {
|
||||
loop {
|
||||
syn_map.retain(|_, entry| {
|
||||
if entry.start_time.elapsed() > Duration::from_secs(30) {
|
||||
tracing::warn!(entry = ?entry, "syn nat entry expired");
|
||||
entry.state.store(NatDstEntryState::Closed);
|
||||
false
|
||||
} else {
|
||||
true
|
||||
}
|
||||
});
|
||||
tokio::time::sleep(Duration::from_secs(10)).await;
|
||||
}
|
||||
};
|
||||
tasks.lock().unwrap().spawn(syn_map_cleaner_task);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn run_listener(&self) -> Result<()> {
|
||||
// bind on both v4 & v6
|
||||
let listen_addr = SocketAddr::new(Ipv4Addr::UNSPECIFIED.into(), 0);
|
||||
|
||||
let net_ns = self.global_ctx.net_ns.clone();
|
||||
let tcp_listener = net_ns
|
||||
.run_async(|| async { TcpListener::bind(&listen_addr).await })
|
||||
.await?;
|
||||
|
||||
self.local_port.store(
|
||||
tcp_listener.local_addr()?.port(),
|
||||
std::sync::atomic::Ordering::Relaxed,
|
||||
);
|
||||
|
||||
let tasks = self.tasks.clone();
|
||||
let syn_map = self.syn_map.clone();
|
||||
let conn_map = self.conn_map.clone();
|
||||
let addr_conn_map = self.addr_conn_map.clone();
|
||||
let accept_task = async move {
|
||||
tracing::info!(listener = ?tcp_listener, "tcp connection start accepting");
|
||||
|
||||
let conn_map = conn_map.clone();
|
||||
while let Ok((tcp_stream, socket_addr)) = tcp_listener.accept().await {
|
||||
let Some(entry) = syn_map.get(&socket_addr) else {
|
||||
tracing::error!("tcp connection from unknown source: {:?}", socket_addr);
|
||||
continue;
|
||||
};
|
||||
assert_eq!(entry.state.load(), NatDstEntryState::SynReceived);
|
||||
|
||||
let entry_clone = entry.clone();
|
||||
drop(entry);
|
||||
syn_map.remove_if(&socket_addr, |_, entry| entry.id == entry_clone.id);
|
||||
|
||||
entry_clone.state.store(NatDstEntryState::ConnectingDst);
|
||||
|
||||
let _ = addr_conn_map.insert(entry_clone.src, entry_clone.clone());
|
||||
let old_nat_val = conn_map.insert(entry_clone.id, entry_clone.clone());
|
||||
assert!(old_nat_val.is_none());
|
||||
|
||||
tasks.lock().unwrap().spawn(Self::connect_to_nat_dst(
|
||||
net_ns.clone(),
|
||||
tcp_stream,
|
||||
conn_map.clone(),
|
||||
addr_conn_map.clone(),
|
||||
entry_clone,
|
||||
));
|
||||
}
|
||||
tracing::error!("nat tcp listener exited");
|
||||
panic!("nat tcp listener exited");
|
||||
};
|
||||
self.tasks
|
||||
.lock()
|
||||
.unwrap()
|
||||
.spawn(accept_task.instrument(tracing::info_span!("tcp_proxy_listener")));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn remove_entry_from_all_conn_map(
|
||||
conn_map: ConnSockMap,
|
||||
addr_conn_map: AddrConnSockMap,
|
||||
nat_entry: ArcNatDstEntry,
|
||||
) {
|
||||
conn_map.remove(&nat_entry.id);
|
||||
addr_conn_map.remove_if(&nat_entry.src, |_, entry| entry.id == nat_entry.id);
|
||||
}
|
||||
|
||||
async fn connect_to_nat_dst(
|
||||
net_ns: NetNS,
|
||||
src_tcp_stream: TcpStream,
|
||||
conn_map: ConnSockMap,
|
||||
addr_conn_map: AddrConnSockMap,
|
||||
nat_entry: ArcNatDstEntry,
|
||||
) {
|
||||
if let Err(e) = src_tcp_stream.set_nodelay(true) {
|
||||
tracing::warn!("set_nodelay failed, ignore it: {:?}", e);
|
||||
}
|
||||
|
||||
let _guard = net_ns.guard();
|
||||
let socket = TcpSocket::new_v4().unwrap();
|
||||
if let Err(e) = socket.set_nodelay(true) {
|
||||
tracing::warn!("set_nodelay failed, ignore it: {:?}", e);
|
||||
}
|
||||
let Ok(Ok(dst_tcp_stream)) = tokio::time::timeout(
|
||||
Duration::from_secs(10),
|
||||
TcpSocket::new_v4().unwrap().connect(nat_entry.dst),
|
||||
)
|
||||
.await
|
||||
else {
|
||||
tracing::error!("connect to dst failed: {:?}", nat_entry);
|
||||
nat_entry.state.store(NatDstEntryState::Closed);
|
||||
Self::remove_entry_from_all_conn_map(conn_map, addr_conn_map, nat_entry);
|
||||
return;
|
||||
};
|
||||
drop(_guard);
|
||||
|
||||
assert_eq!(nat_entry.state.load(), NatDstEntryState::ConnectingDst);
|
||||
nat_entry.state.store(NatDstEntryState::Connected);
|
||||
|
||||
Self::handle_nat_connection(
|
||||
src_tcp_stream,
|
||||
dst_tcp_stream,
|
||||
conn_map,
|
||||
addr_conn_map,
|
||||
nat_entry,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
async fn handle_nat_connection(
|
||||
mut src_tcp_stream: TcpStream,
|
||||
mut dst_tcp_stream: TcpStream,
|
||||
conn_map: ConnSockMap,
|
||||
addr_conn_map: AddrConnSockMap,
|
||||
nat_entry: ArcNatDstEntry,
|
||||
) {
|
||||
let nat_entry_clone = nat_entry.clone();
|
||||
nat_entry.tasks.lock().await.spawn(async move {
|
||||
let ret = copy_bidirectional(&mut src_tcp_stream, &mut dst_tcp_stream).await;
|
||||
tracing::trace!(nat_entry = ?nat_entry_clone, ret = ?ret, "nat tcp connection closed");
|
||||
nat_entry_clone.state.store(NatDstEntryState::Closed);
|
||||
|
||||
Self::remove_entry_from_all_conn_map(conn_map, addr_conn_map, nat_entry_clone);
|
||||
});
|
||||
}
|
||||
|
||||
pub fn get_local_port(&self) -> u16 {
|
||||
self.local_port.load(std::sync::atomic::Ordering::Relaxed)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,383 @@
|
||||
use std::{
|
||||
net::{SocketAddr, SocketAddrV4},
|
||||
sync::{atomic::AtomicBool, Arc},
|
||||
time::Duration,
|
||||
};
|
||||
|
||||
use dashmap::DashMap;
|
||||
use pnet::packet::{
|
||||
ip::IpNextHeaderProtocols,
|
||||
ipv4::{self, Ipv4Flags, Ipv4Packet, MutableIpv4Packet},
|
||||
udp::{self, MutableUdpPacket},
|
||||
Packet,
|
||||
};
|
||||
use tokio::{
|
||||
net::UdpSocket,
|
||||
sync::{
|
||||
mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender},
|
||||
Mutex,
|
||||
},
|
||||
task::{JoinHandle, JoinSet},
|
||||
time::timeout,
|
||||
};
|
||||
|
||||
use tokio_util::bytes::Bytes;
|
||||
use tracing::Level;
|
||||
|
||||
use crate::{
|
||||
common::{error::Error, global_ctx::ArcGlobalCtx, PeerId},
|
||||
peers::{packet, peer_manager::PeerManager, PeerPacketFilter},
|
||||
tunnels::common::setup_sokcet2,
|
||||
};
|
||||
|
||||
use super::CidrSet;
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||
struct UdpNatKey {
|
||||
src_socket: SocketAddr,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct UdpNatEntry {
|
||||
src_peer_id: PeerId,
|
||||
my_peer_id: PeerId,
|
||||
src_socket: SocketAddr,
|
||||
socket: UdpSocket,
|
||||
forward_task: Mutex<Option<JoinHandle<()>>>,
|
||||
stopped: AtomicBool,
|
||||
start_time: std::time::Instant,
|
||||
}
|
||||
|
||||
impl UdpNatEntry {
|
||||
#[tracing::instrument(err(level = Level::WARN))]
|
||||
fn new(src_peer_id: PeerId, my_peer_id: PeerId, src_socket: SocketAddr) -> Result<Self, Error> {
|
||||
// TODO: try use src port, so we will be ip restricted nat type
|
||||
let socket2_socket = socket2::Socket::new(
|
||||
socket2::Domain::IPV4,
|
||||
socket2::Type::DGRAM,
|
||||
Some(socket2::Protocol::UDP),
|
||||
)?;
|
||||
let dst_socket_addr = "0.0.0.0:0".parse().unwrap();
|
||||
setup_sokcet2(&socket2_socket, &dst_socket_addr)?;
|
||||
let socket = UdpSocket::from_std(socket2_socket.into())?;
|
||||
|
||||
Ok(Self {
|
||||
src_peer_id,
|
||||
my_peer_id,
|
||||
src_socket,
|
||||
socket,
|
||||
forward_task: Mutex::new(None),
|
||||
stopped: AtomicBool::new(false),
|
||||
start_time: std::time::Instant::now(),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn stop(&self) {
|
||||
self.stopped
|
||||
.store(true, std::sync::atomic::Ordering::Relaxed);
|
||||
}
|
||||
|
||||
async fn compose_ipv4_packet(
|
||||
self: &Arc<Self>,
|
||||
packet_sender: &mut UnboundedSender<packet::Packet>,
|
||||
buf: &mut [u8],
|
||||
src_v4: &SocketAddrV4,
|
||||
payload_len: usize,
|
||||
payload_mtu: usize,
|
||||
ip_id: u16,
|
||||
) -> Result<(), Error> {
|
||||
let SocketAddr::V4(nat_src_v4) = self.src_socket else {
|
||||
return Err(Error::Unknown);
|
||||
};
|
||||
|
||||
assert_eq!(0, payload_mtu % 8);
|
||||
|
||||
// udp payload is in buf[20 + 8..]
|
||||
let mut udp_packet = MutableUdpPacket::new(&mut buf[20..28 + payload_len]).unwrap();
|
||||
udp_packet.set_source(src_v4.port());
|
||||
udp_packet.set_destination(self.src_socket.port());
|
||||
udp_packet.set_length(payload_len as u16 + 8);
|
||||
udp_packet.set_checksum(udp::ipv4_checksum(
|
||||
&udp_packet.to_immutable(),
|
||||
src_v4.ip(),
|
||||
nat_src_v4.ip(),
|
||||
));
|
||||
|
||||
let payload_len = payload_len + 8; // include udp header
|
||||
let total_pieces = (payload_len + payload_mtu - 1) / payload_mtu;
|
||||
let mut buf_offset = 0;
|
||||
let mut fragment_offset = 0;
|
||||
let mut cur_piece = 0;
|
||||
while fragment_offset < payload_len {
|
||||
let next_fragment_offset = std::cmp::min(fragment_offset + payload_mtu, payload_len);
|
||||
let fragment_len = next_fragment_offset - fragment_offset;
|
||||
let mut ipv4_packet =
|
||||
MutableIpv4Packet::new(&mut buf[buf_offset..buf_offset + fragment_len + 20])
|
||||
.unwrap();
|
||||
ipv4_packet.set_version(4);
|
||||
ipv4_packet.set_header_length(5);
|
||||
ipv4_packet.set_total_length((fragment_len + 20) as u16);
|
||||
ipv4_packet.set_identification(ip_id);
|
||||
if total_pieces > 1 {
|
||||
if cur_piece != total_pieces - 1 {
|
||||
ipv4_packet.set_flags(Ipv4Flags::MoreFragments);
|
||||
} else {
|
||||
ipv4_packet.set_flags(0);
|
||||
}
|
||||
assert_eq!(0, fragment_offset % 8);
|
||||
ipv4_packet.set_fragment_offset(fragment_offset as u16 / 8);
|
||||
} else {
|
||||
ipv4_packet.set_flags(Ipv4Flags::DontFragment);
|
||||
ipv4_packet.set_fragment_offset(0);
|
||||
}
|
||||
ipv4_packet.set_ecn(0);
|
||||
ipv4_packet.set_dscp(0);
|
||||
ipv4_packet.set_ttl(32);
|
||||
ipv4_packet.set_source(src_v4.ip().clone());
|
||||
ipv4_packet.set_destination(nat_src_v4.ip().clone());
|
||||
ipv4_packet.set_next_level_protocol(IpNextHeaderProtocols::Udp);
|
||||
ipv4_packet.set_checksum(ipv4::checksum(&ipv4_packet.to_immutable()));
|
||||
|
||||
tracing::trace!(?ipv4_packet, "udp nat packet response send");
|
||||
|
||||
let peer_packet = packet::Packet::new_data_packet(
|
||||
self.my_peer_id,
|
||||
self.src_peer_id,
|
||||
&ipv4_packet.to_immutable().packet(),
|
||||
);
|
||||
|
||||
if let Err(e) = packet_sender.send(peer_packet) {
|
||||
tracing::error!("send icmp packet to peer failed: {:?}, may exiting..", e);
|
||||
return Err(Error::AnyhowError(e.into()));
|
||||
}
|
||||
|
||||
buf_offset += next_fragment_offset - fragment_offset;
|
||||
fragment_offset = next_fragment_offset;
|
||||
cur_piece += 1;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn forward_task(self: Arc<Self>, mut packet_sender: UnboundedSender<packet::Packet>) {
|
||||
let mut buf = [0u8; 8192];
|
||||
let mut udp_body: &mut [u8] = unsafe { std::mem::transmute(&mut buf[20 + 8..]) };
|
||||
let mut ip_id = 1;
|
||||
|
||||
loop {
|
||||
let (len, src_socket) = match timeout(
|
||||
Duration::from_secs(120),
|
||||
self.socket.recv_from(&mut udp_body),
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(Ok(x)) => x,
|
||||
Ok(Err(err)) => {
|
||||
tracing::error!(?err, "udp nat recv failed");
|
||||
break;
|
||||
}
|
||||
Err(err) => {
|
||||
tracing::error!(?err, "udp nat recv timeout");
|
||||
break;
|
||||
}
|
||||
};
|
||||
|
||||
tracing::trace!(?len, ?src_socket, "udp nat packet response received");
|
||||
|
||||
if self.stopped.load(std::sync::atomic::Ordering::Relaxed) {
|
||||
break;
|
||||
}
|
||||
|
||||
let SocketAddr::V4(src_v4) = src_socket else {
|
||||
continue;
|
||||
};
|
||||
|
||||
let Ok(_) = Self::compose_ipv4_packet(
|
||||
&self,
|
||||
&mut packet_sender,
|
||||
&mut buf,
|
||||
&src_v4,
|
||||
len,
|
||||
1200,
|
||||
ip_id,
|
||||
)
|
||||
.await
|
||||
else {
|
||||
break;
|
||||
};
|
||||
ip_id = ip_id.wrapping_add(1);
|
||||
}
|
||||
|
||||
self.stop();
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct UdpProxy {
|
||||
global_ctx: ArcGlobalCtx,
|
||||
peer_manager: Arc<PeerManager>,
|
||||
|
||||
cidr_set: CidrSet,
|
||||
|
||||
nat_table: Arc<DashMap<UdpNatKey, Arc<UdpNatEntry>>>,
|
||||
|
||||
sender: UnboundedSender<packet::Packet>,
|
||||
receiver: Mutex<Option<UnboundedReceiver<packet::Packet>>>,
|
||||
|
||||
tasks: Mutex<JoinSet<()>>,
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl PeerPacketFilter for UdpProxy {
|
||||
async fn try_process_packet_from_peer(
|
||||
&self,
|
||||
packet: &packet::ArchivedPacket,
|
||||
_: &Bytes,
|
||||
) -> Option<()> {
|
||||
if self.cidr_set.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let _ = self.global_ctx.get_ipv4()?;
|
||||
|
||||
if packet.packet_type != packet::PacketType::Data {
|
||||
return None;
|
||||
};
|
||||
|
||||
let ipv4 = Ipv4Packet::new(packet.payload.as_bytes())?;
|
||||
|
||||
if ipv4.get_version() != 4 || ipv4.get_next_level_protocol() != IpNextHeaderProtocols::Udp {
|
||||
return None;
|
||||
}
|
||||
|
||||
if !self.cidr_set.contains_v4(ipv4.get_destination()) {
|
||||
return None;
|
||||
}
|
||||
|
||||
let udp_packet = udp::UdpPacket::new(ipv4.payload())?;
|
||||
|
||||
tracing::trace!(
|
||||
?packet,
|
||||
?ipv4,
|
||||
?udp_packet,
|
||||
"udp nat packet request received"
|
||||
);
|
||||
|
||||
let nat_key = UdpNatKey {
|
||||
src_socket: SocketAddr::new(ipv4.get_source().into(), udp_packet.get_source()),
|
||||
};
|
||||
let nat_entry = self
|
||||
.nat_table
|
||||
.entry(nat_key)
|
||||
.or_try_insert_with::<Error>(|| {
|
||||
tracing::info!(?packet, ?ipv4, ?udp_packet, "udp nat table entry created");
|
||||
let _g = self.global_ctx.net_ns.guard();
|
||||
Ok(Arc::new(UdpNatEntry::new(
|
||||
packet.from_peer.into(),
|
||||
packet.to_peer.into(),
|
||||
nat_key.src_socket,
|
||||
)?))
|
||||
})
|
||||
.ok()?
|
||||
.clone();
|
||||
|
||||
if nat_entry.forward_task.lock().await.is_none() {
|
||||
nat_entry
|
||||
.forward_task
|
||||
.lock()
|
||||
.await
|
||||
.replace(tokio::spawn(UdpNatEntry::forward_task(
|
||||
nat_entry.clone(),
|
||||
self.sender.clone(),
|
||||
)));
|
||||
}
|
||||
|
||||
// TODO: should it be async.
|
||||
let dst_socket =
|
||||
SocketAddr::new(ipv4.get_destination().into(), udp_packet.get_destination());
|
||||
let send_ret = {
|
||||
let _g = self.global_ctx.net_ns.guard();
|
||||
nat_entry
|
||||
.socket
|
||||
.send_to(udp_packet.payload(), dst_socket)
|
||||
.await
|
||||
};
|
||||
|
||||
if let Err(send_err) = send_ret {
|
||||
tracing::error!(
|
||||
?send_err,
|
||||
?nat_key,
|
||||
?nat_entry,
|
||||
?send_err,
|
||||
"udp nat send failed"
|
||||
);
|
||||
}
|
||||
|
||||
Some(())
|
||||
}
|
||||
}
|
||||
|
||||
impl UdpProxy {
|
||||
pub fn new(
|
||||
global_ctx: ArcGlobalCtx,
|
||||
peer_manager: Arc<PeerManager>,
|
||||
) -> Result<Arc<Self>, Error> {
|
||||
let cidr_set = CidrSet::new(global_ctx.clone());
|
||||
let (sender, receiver) = unbounded_channel();
|
||||
let ret = Self {
|
||||
global_ctx,
|
||||
peer_manager,
|
||||
cidr_set,
|
||||
nat_table: Arc::new(DashMap::new()),
|
||||
sender,
|
||||
receiver: Mutex::new(Some(receiver)),
|
||||
tasks: Mutex::new(JoinSet::new()),
|
||||
};
|
||||
Ok(Arc::new(ret))
|
||||
}
|
||||
|
||||
pub async fn start(self: &Arc<Self>) -> Result<(), Error> {
|
||||
self.peer_manager
|
||||
.add_packet_process_pipeline(Box::new(self.clone()))
|
||||
.await;
|
||||
|
||||
// clean up nat table
|
||||
let nat_table = self.nat_table.clone();
|
||||
self.tasks.lock().await.spawn(async move {
|
||||
loop {
|
||||
tokio::time::sleep(Duration::from_secs(15)).await;
|
||||
nat_table.retain(|_, v| {
|
||||
if v.start_time.elapsed().as_secs() > 120 {
|
||||
tracing::info!(?v, "udp nat table entry removed");
|
||||
v.stop();
|
||||
false
|
||||
} else {
|
||||
true
|
||||
}
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
// forward packets to peer manager
|
||||
let mut receiver = self.receiver.lock().await.take().unwrap();
|
||||
let peer_manager = self.peer_manager.clone();
|
||||
self.tasks.lock().await.spawn(async move {
|
||||
while let Some(msg) = receiver.recv().await {
|
||||
let to_peer_id: PeerId = msg.to_peer.into();
|
||||
tracing::trace!(?msg, ?to_peer_id, "udp nat packet response send");
|
||||
let ret = peer_manager.send_msg(msg.into(), to_peer_id).await;
|
||||
if ret.is_err() {
|
||||
tracing::error!("send icmp packet to peer failed: {:?}", ret);
|
||||
}
|
||||
}
|
||||
});
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for UdpProxy {
|
||||
fn drop(&mut self) {
|
||||
for v in self.nat_table.iter() {
|
||||
v.stop();
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,481 @@
|
||||
use std::borrow::BorrowMut;
|
||||
use std::net::Ipv4Addr;
|
||||
use std::sync::{Arc, Weak};
|
||||
|
||||
use anyhow::Context;
|
||||
use futures::StreamExt;
|
||||
use pnet::packet::ethernet::EthernetPacket;
|
||||
use pnet::packet::ipv4::Ipv4Packet;
|
||||
|
||||
use tokio::{sync::Mutex, task::JoinSet};
|
||||
use tokio_util::bytes::{Bytes, BytesMut};
|
||||
use tonic::transport::Server;
|
||||
|
||||
use crate::common::config::ConfigLoader;
|
||||
use crate::common::error::Error;
|
||||
use crate::common::global_ctx::{ArcGlobalCtx, GlobalCtx, GlobalCtxEvent};
|
||||
use crate::common::PeerId;
|
||||
use crate::connector::direct::DirectConnectorManager;
|
||||
use crate::connector::manual::{ConnectorManagerRpcService, ManualConnectorManager};
|
||||
use crate::connector::udp_hole_punch::UdpHolePunchConnector;
|
||||
use crate::gateway::icmp_proxy::IcmpProxy;
|
||||
use crate::gateway::tcp_proxy::TcpProxy;
|
||||
use crate::gateway::udp_proxy::UdpProxy;
|
||||
use crate::peer_center::instance::PeerCenterInstance;
|
||||
use crate::peers::peer_conn::PeerConnId;
|
||||
use crate::peers::peer_manager::{PeerManager, RouteAlgoType};
|
||||
use crate::peers::rpc_service::PeerManagerRpcService;
|
||||
use crate::rpc::vpn_portal_rpc_server::VpnPortalRpc;
|
||||
use crate::rpc::{GetVpnPortalInfoRequest, GetVpnPortalInfoResponse, VpnPortalInfo};
|
||||
use crate::tunnels::SinkItem;
|
||||
use crate::vpn_portal::{self, VpnPortal};
|
||||
|
||||
use tokio_stream::wrappers::ReceiverStream;
|
||||
|
||||
use super::listeners::ListenerManager;
|
||||
use super::virtual_nic;
|
||||
|
||||
pub struct Instance {
|
||||
inst_name: String,
|
||||
|
||||
id: uuid::Uuid,
|
||||
|
||||
virtual_nic: Option<Arc<virtual_nic::VirtualNic>>,
|
||||
peer_packet_receiver: Option<ReceiverStream<SinkItem>>,
|
||||
|
||||
tasks: JoinSet<()>,
|
||||
|
||||
peer_manager: Arc<PeerManager>,
|
||||
listener_manager: Arc<Mutex<ListenerManager<PeerManager>>>,
|
||||
conn_manager: Arc<ManualConnectorManager>,
|
||||
direct_conn_manager: Arc<DirectConnectorManager>,
|
||||
udp_hole_puncher: Arc<Mutex<UdpHolePunchConnector>>,
|
||||
|
||||
tcp_proxy: Arc<TcpProxy>,
|
||||
icmp_proxy: Arc<IcmpProxy>,
|
||||
udp_proxy: Arc<UdpProxy>,
|
||||
|
||||
peer_center: Arc<PeerCenterInstance>,
|
||||
|
||||
vpn_portal: Arc<Mutex<Box<dyn VpnPortal>>>,
|
||||
|
||||
global_ctx: ArcGlobalCtx,
|
||||
}
|
||||
|
||||
impl Instance {
|
||||
pub fn new(config: impl ConfigLoader + Send + Sync + 'static) -> Self {
|
||||
let global_ctx = Arc::new(GlobalCtx::new(config));
|
||||
|
||||
log::info!(
|
||||
"[INIT] instance creating. config: {}",
|
||||
global_ctx.config.dump()
|
||||
);
|
||||
|
||||
let (peer_packet_sender, peer_packet_receiver) = tokio::sync::mpsc::channel(100);
|
||||
|
||||
let id = global_ctx.get_id();
|
||||
|
||||
let peer_manager = Arc::new(PeerManager::new(
|
||||
RouteAlgoType::Ospf,
|
||||
global_ctx.clone(),
|
||||
peer_packet_sender.clone(),
|
||||
));
|
||||
|
||||
let listener_manager = Arc::new(Mutex::new(ListenerManager::new(
|
||||
global_ctx.clone(),
|
||||
peer_manager.clone(),
|
||||
)));
|
||||
|
||||
let conn_manager = Arc::new(ManualConnectorManager::new(
|
||||
global_ctx.clone(),
|
||||
peer_manager.clone(),
|
||||
));
|
||||
|
||||
let mut direct_conn_manager =
|
||||
DirectConnectorManager::new(global_ctx.clone(), peer_manager.clone());
|
||||
direct_conn_manager.run();
|
||||
|
||||
let udp_hole_puncher = UdpHolePunchConnector::new(global_ctx.clone(), peer_manager.clone());
|
||||
|
||||
let arc_tcp_proxy = TcpProxy::new(global_ctx.clone(), peer_manager.clone());
|
||||
let arc_icmp_proxy = IcmpProxy::new(global_ctx.clone(), peer_manager.clone())
|
||||
.with_context(|| "create icmp proxy failed")
|
||||
.unwrap();
|
||||
let arc_udp_proxy = UdpProxy::new(global_ctx.clone(), peer_manager.clone())
|
||||
.with_context(|| "create udp proxy failed")
|
||||
.unwrap();
|
||||
|
||||
let peer_center = Arc::new(PeerCenterInstance::new(peer_manager.clone()));
|
||||
|
||||
let vpn_portal_inst = vpn_portal::wireguard::WireGuard::default();
|
||||
|
||||
Instance {
|
||||
inst_name: global_ctx.inst_name.clone(),
|
||||
id,
|
||||
|
||||
virtual_nic: None,
|
||||
peer_packet_receiver: Some(ReceiverStream::new(peer_packet_receiver)),
|
||||
|
||||
tasks: JoinSet::new(),
|
||||
peer_manager,
|
||||
listener_manager,
|
||||
conn_manager,
|
||||
direct_conn_manager: Arc::new(direct_conn_manager),
|
||||
udp_hole_puncher: Arc::new(Mutex::new(udp_hole_puncher)),
|
||||
|
||||
tcp_proxy: arc_tcp_proxy,
|
||||
icmp_proxy: arc_icmp_proxy,
|
||||
udp_proxy: arc_udp_proxy,
|
||||
|
||||
peer_center,
|
||||
|
||||
vpn_portal: Arc::new(Mutex::new(Box::new(vpn_portal_inst))),
|
||||
|
||||
global_ctx,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_conn_manager(&self) -> Arc<ManualConnectorManager> {
|
||||
self.conn_manager.clone()
|
||||
}
|
||||
|
||||
async fn do_forward_nic_to_peers_ipv4(ret: BytesMut, mgr: &PeerManager) {
|
||||
if let Some(ipv4) = Ipv4Packet::new(&ret) {
|
||||
if ipv4.get_version() != 4 {
|
||||
tracing::info!("[USER_PACKET] not ipv4 packet: {:?}", ipv4);
|
||||
return;
|
||||
}
|
||||
let dst_ipv4 = ipv4.get_destination();
|
||||
tracing::trace!(
|
||||
?ret,
|
||||
"[USER_PACKET] recv new packet from tun device and forward to peers."
|
||||
);
|
||||
let send_ret = mgr.send_msg_ipv4(ret, dst_ipv4).await;
|
||||
if send_ret.is_err() {
|
||||
tracing::trace!(?send_ret, "[USER_PACKET] send_msg_ipv4 failed")
|
||||
}
|
||||
} else {
|
||||
tracing::warn!(?ret, "[USER_PACKET] not ipv4 packet");
|
||||
}
|
||||
}
|
||||
|
||||
async fn do_forward_nic_to_peers_ethernet(mut ret: BytesMut, mgr: &PeerManager) {
|
||||
if let Some(eth) = EthernetPacket::new(&ret) {
|
||||
log::warn!("begin to forward: {:?}, type: {}", eth, eth.get_ethertype());
|
||||
Self::do_forward_nic_to_peers_ipv4(ret.split_off(14), mgr).await;
|
||||
} else {
|
||||
log::warn!("not ipv4 packet: {:?}", ret);
|
||||
}
|
||||
}
|
||||
|
||||
fn do_forward_nic_to_peers(&mut self) -> Result<(), Error> {
|
||||
// read from nic and write to corresponding tunnel
|
||||
let nic = self.virtual_nic.as_ref().unwrap();
|
||||
let nic = nic.clone();
|
||||
let mgr = self.peer_manager.clone();
|
||||
|
||||
self.tasks.spawn(async move {
|
||||
let mut stream = nic.pin_recv_stream();
|
||||
while let Some(ret) = stream.next().await {
|
||||
if ret.is_err() {
|
||||
log::error!("read from nic failed: {:?}", ret);
|
||||
break;
|
||||
}
|
||||
Self::do_forward_nic_to_peers_ipv4(ret.unwrap(), mgr.as_ref()).await;
|
||||
// Self::do_forward_nic_to_peers_ethernet(ret.into(), mgr.as_ref()).await;
|
||||
}
|
||||
});
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn do_forward_peers_to_nic(
|
||||
tasks: &mut JoinSet<()>,
|
||||
nic: Arc<virtual_nic::VirtualNic>,
|
||||
channel: Option<ReceiverStream<Bytes>>,
|
||||
) {
|
||||
tasks.spawn(async move {
|
||||
let send = nic.pin_send_stream();
|
||||
let channel = channel.unwrap();
|
||||
let ret = channel
|
||||
.map(|packet| {
|
||||
log::trace!(
|
||||
"[USER_PACKET] forward packet from peers to nic. packet: {:?}",
|
||||
packet
|
||||
);
|
||||
Ok(packet)
|
||||
})
|
||||
.forward(send)
|
||||
.await;
|
||||
if ret.is_err() {
|
||||
panic!("do_forward_tunnel_to_nic");
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
async fn add_initial_peers(&mut self) -> Result<(), Error> {
|
||||
for peer in self.global_ctx.config.get_peers().iter() {
|
||||
self.get_conn_manager()
|
||||
.add_connector_by_url(peer.uri.as_str())
|
||||
.await?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn prepare_tun_device(&mut self) -> Result<(), Error> {
|
||||
let nic = virtual_nic::VirtualNic::new(self.get_global_ctx())
|
||||
.create_dev()
|
||||
.await?;
|
||||
|
||||
self.global_ctx
|
||||
.issue_event(GlobalCtxEvent::TunDeviceReady(nic.ifname().to_string()));
|
||||
|
||||
self.virtual_nic = Some(Arc::new(nic));
|
||||
|
||||
self.do_forward_nic_to_peers().unwrap();
|
||||
Self::do_forward_peers_to_nic(
|
||||
self.tasks.borrow_mut(),
|
||||
self.virtual_nic.as_ref().unwrap().clone(),
|
||||
self.peer_packet_receiver.take(),
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn assign_ipv4_to_tun_device(&mut self, ipv4_addr: Ipv4Addr) -> Result<(), Error> {
|
||||
let nic = self.virtual_nic.as_ref().unwrap().clone();
|
||||
nic.link_up().await?;
|
||||
nic.remove_ip(None).await?;
|
||||
nic.add_ip(ipv4_addr, 24).await?;
|
||||
if cfg!(target_os = "macos") {
|
||||
nic.add_route(ipv4_addr, 24).await?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn run(&mut self) -> Result<(), Error> {
|
||||
self.prepare_tun_device().await?;
|
||||
if let Some(ipv4_addr) = self.global_ctx.get_ipv4() {
|
||||
self.assign_ipv4_to_tun_device(ipv4_addr).await?;
|
||||
}
|
||||
|
||||
self.listener_manager
|
||||
.lock()
|
||||
.await
|
||||
.prepare_listeners()
|
||||
.await?;
|
||||
self.listener_manager.lock().await.run().await?;
|
||||
self.peer_manager.run().await?;
|
||||
|
||||
self.run_rpc_server().unwrap();
|
||||
|
||||
self.tcp_proxy.start().await.unwrap();
|
||||
self.icmp_proxy.start().await.unwrap();
|
||||
self.udp_proxy.start().await.unwrap();
|
||||
self.run_proxy_cidrs_route_updater();
|
||||
|
||||
self.udp_hole_puncher.lock().await.run().await?;
|
||||
|
||||
self.peer_center.init().await;
|
||||
|
||||
self.add_initial_peers().await?;
|
||||
|
||||
if let Some(_) = self.global_ctx.get_vpn_portal_cidr() {
|
||||
self.vpn_portal
|
||||
.lock()
|
||||
.await
|
||||
.start(self.get_global_ctx(), self.get_peer_manager())
|
||||
.await?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn get_peer_manager(&self) -> Arc<PeerManager> {
|
||||
self.peer_manager.clone()
|
||||
}
|
||||
|
||||
pub async fn close_peer_conn(
|
||||
&mut self,
|
||||
peer_id: PeerId,
|
||||
conn_id: &PeerConnId,
|
||||
) -> Result<(), Error> {
|
||||
self.peer_manager
|
||||
.get_peer_map()
|
||||
.close_peer_conn(peer_id, conn_id)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn wait(&mut self) {
|
||||
while let Some(ret) = self.tasks.join_next().await {
|
||||
log::info!("task finished: {:?}", ret);
|
||||
ret.unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
pub fn id(&self) -> uuid::Uuid {
|
||||
self.id
|
||||
}
|
||||
|
||||
pub fn peer_id(&self) -> PeerId {
|
||||
self.peer_manager.my_peer_id()
|
||||
}
|
||||
|
||||
fn get_vpn_portal_rpc_service(&self) -> impl VpnPortalRpc {
|
||||
struct VpnPortalRpcService {
|
||||
peer_mgr: Weak<PeerManager>,
|
||||
vpn_portal: Weak<Mutex<Box<dyn VpnPortal>>>,
|
||||
}
|
||||
|
||||
#[tonic::async_trait]
|
||||
impl VpnPortalRpc for VpnPortalRpcService {
|
||||
async fn get_vpn_portal_info(
|
||||
&self,
|
||||
_request: tonic::Request<GetVpnPortalInfoRequest>,
|
||||
) -> Result<tonic::Response<GetVpnPortalInfoResponse>, tonic::Status> {
|
||||
let Some(vpn_portal) = self.vpn_portal.upgrade() else {
|
||||
return Err(tonic::Status::unavailable("vpn portal not available"));
|
||||
};
|
||||
|
||||
let Some(peer_mgr) = self.peer_mgr.upgrade() else {
|
||||
return Err(tonic::Status::unavailable("peer manager not available"));
|
||||
};
|
||||
|
||||
let vpn_portal = vpn_portal.lock().await;
|
||||
let ret = GetVpnPortalInfoResponse {
|
||||
vpn_portal_info: Some(VpnPortalInfo {
|
||||
vpn_type: vpn_portal.name(),
|
||||
client_config: vpn_portal.dump_client_config(peer_mgr).await,
|
||||
connected_clients: vpn_portal.list_clients().await,
|
||||
}),
|
||||
};
|
||||
|
||||
Ok(tonic::Response::new(ret))
|
||||
}
|
||||
}
|
||||
|
||||
VpnPortalRpcService {
|
||||
peer_mgr: Arc::downgrade(&self.peer_manager),
|
||||
vpn_portal: Arc::downgrade(&self.vpn_portal),
|
||||
}
|
||||
}
|
||||
|
||||
fn run_rpc_server(&mut self) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let Some(addr) = self.global_ctx.config.get_rpc_portal() else {
|
||||
tracing::info!("rpc server not enabled, because rpc_portal is not set.");
|
||||
return Ok(());
|
||||
};
|
||||
let peer_mgr = self.peer_manager.clone();
|
||||
let conn_manager = self.conn_manager.clone();
|
||||
let net_ns = self.global_ctx.net_ns.clone();
|
||||
let peer_center = self.peer_center.clone();
|
||||
let vpn_portal_rpc = self.get_vpn_portal_rpc_service();
|
||||
|
||||
self.tasks.spawn(async move {
|
||||
let _g = net_ns.guard();
|
||||
Server::builder()
|
||||
.add_service(
|
||||
crate::rpc::peer_manage_rpc_server::PeerManageRpcServer::new(
|
||||
PeerManagerRpcService::new(peer_mgr),
|
||||
),
|
||||
)
|
||||
.add_service(
|
||||
crate::rpc::connector_manage_rpc_server::ConnectorManageRpcServer::new(
|
||||
ConnectorManagerRpcService(conn_manager.clone()),
|
||||
),
|
||||
)
|
||||
.add_service(
|
||||
crate::rpc::peer_center_rpc_server::PeerCenterRpcServer::new(
|
||||
peer_center.get_rpc_service(),
|
||||
),
|
||||
)
|
||||
.add_service(crate::rpc::vpn_portal_rpc_server::VpnPortalRpcServer::new(
|
||||
vpn_portal_rpc,
|
||||
))
|
||||
.serve(addr)
|
||||
.await
|
||||
.with_context(|| format!("rpc server failed. addr: {}", addr))
|
||||
.unwrap();
|
||||
});
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn run_proxy_cidrs_route_updater(&mut self) {
|
||||
let peer_mgr = self.peer_manager.clone();
|
||||
let global_ctx = self.global_ctx.clone();
|
||||
let net_ns = self.global_ctx.net_ns.clone();
|
||||
let nic = self.virtual_nic.as_ref().unwrap().clone();
|
||||
|
||||
self.tasks.spawn(async move {
|
||||
let mut cur_proxy_cidrs = vec![];
|
||||
loop {
|
||||
let mut proxy_cidrs = vec![];
|
||||
let routes = peer_mgr.list_routes().await;
|
||||
for r in routes {
|
||||
for cidr in r.proxy_cidrs {
|
||||
let Ok(cidr) = cidr.parse::<cidr::Ipv4Cidr>() else {
|
||||
continue;
|
||||
};
|
||||
proxy_cidrs.push(cidr);
|
||||
}
|
||||
}
|
||||
// add vpn portal cidr to proxy_cidrs
|
||||
if let Some(vpn_cfg) = global_ctx.config.get_vpn_portal_config() {
|
||||
proxy_cidrs.push(vpn_cfg.client_cidr);
|
||||
}
|
||||
|
||||
// if route is in cur_proxy_cidrs but not in proxy_cidrs, delete it.
|
||||
for cidr in cur_proxy_cidrs.iter() {
|
||||
if proxy_cidrs.contains(cidr) {
|
||||
continue;
|
||||
}
|
||||
|
||||
let _g = net_ns.guard();
|
||||
let ret = nic
|
||||
.get_ifcfg()
|
||||
.remove_ipv4_route(
|
||||
nic.ifname(),
|
||||
cidr.first_address(),
|
||||
cidr.network_length(),
|
||||
)
|
||||
.await;
|
||||
|
||||
if ret.is_err() {
|
||||
tracing::trace!(
|
||||
cidr = ?cidr,
|
||||
err = ?ret,
|
||||
"remove route failed.",
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
for cidr in proxy_cidrs.iter() {
|
||||
if cur_proxy_cidrs.contains(cidr) {
|
||||
continue;
|
||||
}
|
||||
let _g = net_ns.guard();
|
||||
let ret = nic
|
||||
.get_ifcfg()
|
||||
.add_ipv4_route(nic.ifname(), cidr.first_address(), cidr.network_length())
|
||||
.await;
|
||||
|
||||
if ret.is_err() {
|
||||
tracing::trace!(
|
||||
cidr = ?cidr,
|
||||
err = ?ret,
|
||||
"add route failed.",
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
cur_proxy_cidrs = proxy_cidrs;
|
||||
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
pub fn get_global_ctx(&self) -> ArcGlobalCtx {
|
||||
self.global_ctx.clone()
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,202 @@
|
||||
use std::{fmt::Debug, sync::Arc};
|
||||
|
||||
use anyhow::Context;
|
||||
use async_trait::async_trait;
|
||||
use tokio::{sync::Mutex, task::JoinSet};
|
||||
|
||||
use crate::{
|
||||
common::{
|
||||
error::Error,
|
||||
global_ctx::{ArcGlobalCtx, GlobalCtxEvent},
|
||||
netns::NetNS,
|
||||
},
|
||||
peers::peer_manager::PeerManager,
|
||||
tunnels::{
|
||||
ring_tunnel::RingTunnelListener,
|
||||
tcp_tunnel::TcpTunnelListener,
|
||||
udp_tunnel::UdpTunnelListener,
|
||||
wireguard::{WgConfig, WgTunnelListener},
|
||||
Tunnel, TunnelListener,
|
||||
},
|
||||
};
|
||||
|
||||
#[async_trait]
|
||||
pub trait TunnelHandlerForListener {
|
||||
async fn handle_tunnel(&self, tunnel: Box<dyn Tunnel>) -> Result<(), Error>;
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl TunnelHandlerForListener for PeerManager {
|
||||
#[tracing::instrument]
|
||||
async fn handle_tunnel(&self, tunnel: Box<dyn Tunnel>) -> Result<(), Error> {
|
||||
self.add_tunnel_as_server(tunnel).await
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ListenerManager<H> {
|
||||
global_ctx: ArcGlobalCtx,
|
||||
net_ns: NetNS,
|
||||
listeners: Vec<Arc<Mutex<dyn TunnelListener>>>,
|
||||
peer_manager: Arc<H>,
|
||||
|
||||
tasks: JoinSet<()>,
|
||||
}
|
||||
|
||||
impl<H: TunnelHandlerForListener + Send + Sync + 'static + Debug> ListenerManager<H> {
|
||||
pub fn new(global_ctx: ArcGlobalCtx, peer_manager: Arc<H>) -> Self {
|
||||
Self {
|
||||
global_ctx: global_ctx.clone(),
|
||||
net_ns: global_ctx.net_ns.clone(),
|
||||
listeners: Vec::new(),
|
||||
peer_manager,
|
||||
tasks: JoinSet::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn prepare_listeners(&mut self) -> Result<(), Error> {
|
||||
self.add_listener(RingTunnelListener::new(
|
||||
format!("ring://{}", self.global_ctx.get_id())
|
||||
.parse()
|
||||
.unwrap(),
|
||||
))
|
||||
.await?;
|
||||
|
||||
for l in self.global_ctx.config.get_listener_uris().iter() {
|
||||
match l.scheme() {
|
||||
"tcp" => {
|
||||
self.add_listener(TcpTunnelListener::new(l.clone())).await?;
|
||||
}
|
||||
"udp" => {
|
||||
self.add_listener(UdpTunnelListener::new(l.clone())).await?;
|
||||
}
|
||||
"wg" => {
|
||||
let nid = self.global_ctx.get_network_identity();
|
||||
let wg_config =
|
||||
WgConfig::new_from_network_identity(&nid.network_name, &nid.network_secret);
|
||||
self.add_listener(WgTunnelListener::new(l.clone(), wg_config))
|
||||
.await?;
|
||||
}
|
||||
_ => {
|
||||
log::warn!("unsupported listener uri: {}", l);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn add_listener<Listener>(&mut self, listener: Listener) -> Result<(), Error>
|
||||
where
|
||||
Listener: TunnelListener + 'static,
|
||||
{
|
||||
let listener = Arc::new(Mutex::new(listener));
|
||||
self.listeners.push(listener);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tracing::instrument]
|
||||
async fn run_listener(
|
||||
listener: Arc<Mutex<dyn TunnelListener>>,
|
||||
peer_manager: Arc<H>,
|
||||
global_ctx: ArcGlobalCtx,
|
||||
) {
|
||||
let mut l = listener.lock().await;
|
||||
global_ctx.add_running_listener(l.local_url());
|
||||
global_ctx.issue_event(GlobalCtxEvent::ListenerAdded(l.local_url()));
|
||||
while let Ok(ret) = l.accept().await {
|
||||
let tunnel_info = ret.info().unwrap();
|
||||
global_ctx.issue_event(GlobalCtxEvent::ConnectionAccepted(
|
||||
tunnel_info.local_addr.clone(),
|
||||
tunnel_info.remote_addr.clone(),
|
||||
));
|
||||
tracing::info!(ret = ?ret, "conn accepted");
|
||||
let peer_manager = peer_manager.clone();
|
||||
let global_ctx = global_ctx.clone();
|
||||
tokio::spawn(async move {
|
||||
let server_ret = peer_manager.handle_tunnel(ret).await;
|
||||
if let Err(e) = &server_ret {
|
||||
global_ctx.issue_event(GlobalCtxEvent::ConnectionError(
|
||||
tunnel_info.local_addr,
|
||||
tunnel_info.remote_addr,
|
||||
e.to_string(),
|
||||
));
|
||||
tracing::error!(error = ?e, "handle conn error");
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn run(&mut self) -> Result<(), Error> {
|
||||
for listener in &self.listeners {
|
||||
let _guard = self.net_ns.guard();
|
||||
let addr = listener.lock().await.local_url();
|
||||
log::warn!("run listener: {:?}", listener);
|
||||
listener
|
||||
.lock()
|
||||
.await
|
||||
.listen()
|
||||
.await
|
||||
.with_context(|| format!("failed to add listener {}", addr))?;
|
||||
self.tasks.spawn(Self::run_listener(
|
||||
listener.clone(),
|
||||
self.peer_manager.clone(),
|
||||
self.global_ctx.clone(),
|
||||
));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use futures::{SinkExt, StreamExt};
|
||||
use tokio::time::timeout;
|
||||
|
||||
use crate::{
|
||||
common::global_ctx::tests::get_mock_global_ctx,
|
||||
tunnels::{ring_tunnel::RingTunnelConnector, TunnelConnector},
|
||||
};
|
||||
|
||||
use super::*;
|
||||
|
||||
#[derive(Debug)]
|
||||
struct MockListenerHandler {}
|
||||
|
||||
#[async_trait]
|
||||
impl TunnelHandlerForListener for MockListenerHandler {
|
||||
async fn handle_tunnel(&self, _tunnel: Box<dyn Tunnel>) -> Result<(), Error> {
|
||||
let data = "abc";
|
||||
_tunnel.pin_sink().send(data.into()).await.unwrap();
|
||||
Err(Error::Unknown)
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn handle_error_in_accept() {
|
||||
let handler = Arc::new(MockListenerHandler {});
|
||||
let mut listener_mgr = ListenerManager::new(get_mock_global_ctx(), handler.clone());
|
||||
|
||||
let ring_id = format!("ring://{}", uuid::Uuid::new_v4());
|
||||
|
||||
listener_mgr
|
||||
.add_listener(RingTunnelListener::new(ring_id.parse().unwrap()))
|
||||
.await
|
||||
.unwrap();
|
||||
listener_mgr.run().await.unwrap();
|
||||
|
||||
let connect_once = |ring_id| async move {
|
||||
let tunnel = RingTunnelConnector::new(ring_id).connect().await.unwrap();
|
||||
assert_eq!(tunnel.pin_stream().next().await.unwrap().unwrap(), "abc");
|
||||
tunnel
|
||||
};
|
||||
|
||||
timeout(std::time::Duration::from_secs(1), async move {
|
||||
connect_once(ring_id.parse().unwrap()).await;
|
||||
// handle tunnel fail should not impact the second connect
|
||||
connect_once(ring_id.parse().unwrap()).await;
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,4 @@
|
||||
pub mod instance;
|
||||
pub mod listeners;
|
||||
pub mod tun_codec;
|
||||
pub mod virtual_nic;
|
||||
@@ -0,0 +1,179 @@
|
||||
use std::io;
|
||||
|
||||
use byteorder::{NativeEndian, NetworkEndian, WriteBytesExt};
|
||||
use tokio_util::bytes::{BufMut, Bytes, BytesMut};
|
||||
use tokio_util::codec::{Decoder, Encoder};
|
||||
|
||||
/// A packet protocol IP version
|
||||
#[derive(Debug, Clone, Copy, Default)]
|
||||
enum PacketProtocol {
|
||||
#[default]
|
||||
IPv4,
|
||||
IPv6,
|
||||
Other(u8),
|
||||
}
|
||||
|
||||
// Note: the protocol in the packet information header is platform dependent.
|
||||
impl PacketProtocol {
|
||||
#[cfg(any(target_os = "linux", target_os = "android"))]
|
||||
fn into_pi_field(self) -> Result<u16, io::Error> {
|
||||
use nix::libc;
|
||||
match self {
|
||||
PacketProtocol::IPv4 => Ok(libc::ETH_P_IP as u16),
|
||||
PacketProtocol::IPv6 => Ok(libc::ETH_P_IPV6 as u16),
|
||||
PacketProtocol::Other(_) => Err(io::Error::new(
|
||||
io::ErrorKind::Other,
|
||||
"neither an IPv4 nor IPv6 packet",
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(any(target_os = "macos", target_os = "ios"))]
|
||||
fn into_pi_field(self) -> Result<u16, io::Error> {
|
||||
use nix::libc;
|
||||
match self {
|
||||
PacketProtocol::IPv4 => Ok(libc::PF_INET as u16),
|
||||
PacketProtocol::IPv6 => Ok(libc::PF_INET6 as u16),
|
||||
PacketProtocol::Other(_) => Err(io::Error::new(
|
||||
io::ErrorKind::Other,
|
||||
"neither an IPv4 nor IPv6 packet",
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(target_os = "windows")]
|
||||
fn into_pi_field(self) -> Result<u16, io::Error> {
|
||||
unimplemented!()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum TunPacketBuffer {
|
||||
Bytes(Bytes),
|
||||
BytesMut(BytesMut),
|
||||
}
|
||||
|
||||
impl From<TunPacketBuffer> for Bytes {
|
||||
fn from(buf: TunPacketBuffer) -> Self {
|
||||
match buf {
|
||||
TunPacketBuffer::Bytes(bytes) => bytes,
|
||||
TunPacketBuffer::BytesMut(bytes) => bytes.freeze(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AsRef<[u8]> for TunPacketBuffer {
|
||||
fn as_ref(&self) -> &[u8] {
|
||||
match self {
|
||||
TunPacketBuffer::Bytes(bytes) => bytes.as_ref(),
|
||||
TunPacketBuffer::BytesMut(bytes) => bytes.as_ref(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A Tun Packet to be sent or received on the TUN interface.
|
||||
#[derive(Debug)]
|
||||
pub struct TunPacket(PacketProtocol, TunPacketBuffer);
|
||||
|
||||
/// Infer the protocol based on the first nibble in the packet buffer.
|
||||
fn infer_proto(buf: &[u8]) -> PacketProtocol {
|
||||
match buf[0] >> 4 {
|
||||
4 => PacketProtocol::IPv4,
|
||||
6 => PacketProtocol::IPv6,
|
||||
p => PacketProtocol::Other(p),
|
||||
}
|
||||
}
|
||||
|
||||
impl TunPacket {
|
||||
/// Create a new `TunPacket` based on a byte slice.
|
||||
pub fn new(buffer: TunPacketBuffer) -> TunPacket {
|
||||
let proto = infer_proto(buffer.as_ref());
|
||||
TunPacket(proto, buffer)
|
||||
}
|
||||
|
||||
/// Return this packet's bytes.
|
||||
pub fn get_bytes(&self) -> &[u8] {
|
||||
match &self.1 {
|
||||
TunPacketBuffer::Bytes(bytes) => bytes.as_ref(),
|
||||
TunPacketBuffer::BytesMut(bytes) => bytes.as_ref(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn into_bytes(self) -> Bytes {
|
||||
match self.1 {
|
||||
TunPacketBuffer::Bytes(bytes) => bytes,
|
||||
TunPacketBuffer::BytesMut(bytes) => bytes.freeze(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn into_bytes_mut(self) -> BytesMut {
|
||||
match self.1 {
|
||||
TunPacketBuffer::Bytes(_) => panic!("cannot into_bytes_mut from bytes"),
|
||||
TunPacketBuffer::BytesMut(bytes) => bytes,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A TunPacket Encoder/Decoder.
|
||||
pub struct TunPacketCodec(bool, i32);
|
||||
|
||||
impl TunPacketCodec {
|
||||
/// Create a new `TunPacketCodec` specifying whether the underlying
|
||||
/// tunnel Device has enabled the packet information header.
|
||||
pub fn new(pi: bool, mtu: i32) -> TunPacketCodec {
|
||||
TunPacketCodec(pi, mtu)
|
||||
}
|
||||
}
|
||||
|
||||
impl Decoder for TunPacketCodec {
|
||||
type Item = TunPacket;
|
||||
type Error = io::Error;
|
||||
|
||||
fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
|
||||
if buf.is_empty() {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let mut pkt = buf.split_to(buf.len());
|
||||
|
||||
// reserve enough space for the next packet
|
||||
if self.0 {
|
||||
buf.reserve(self.1 as usize + 4);
|
||||
} else {
|
||||
buf.reserve(self.1 as usize);
|
||||
}
|
||||
|
||||
// if the packet information is enabled we have to ignore the first 4 bytes
|
||||
if self.0 {
|
||||
let _ = pkt.split_to(4);
|
||||
}
|
||||
|
||||
let proto = infer_proto(pkt.as_ref());
|
||||
Ok(Some(TunPacket(proto, TunPacketBuffer::BytesMut(pkt))))
|
||||
}
|
||||
}
|
||||
|
||||
impl Encoder<TunPacket> for TunPacketCodec {
|
||||
type Error = io::Error;
|
||||
|
||||
fn encode(&mut self, item: TunPacket, dst: &mut BytesMut) -> Result<(), Self::Error> {
|
||||
dst.reserve(item.get_bytes().len() + 4);
|
||||
match item {
|
||||
TunPacket(proto, bytes) if self.0 => {
|
||||
// build the packet information header comprising of 2 u16
|
||||
// fields: flags and protocol.
|
||||
let mut buf = Vec::<u8>::with_capacity(4);
|
||||
|
||||
// flags is always 0
|
||||
buf.write_u16::<NativeEndian>(0)?;
|
||||
// write the protocol as network byte order
|
||||
buf.write_u16::<NetworkEndian>(proto.into_pi_field()?)?;
|
||||
|
||||
dst.put_slice(&buf);
|
||||
dst.put(Bytes::from(bytes));
|
||||
}
|
||||
TunPacket(_, bytes) => dst.put(Bytes::from(bytes)),
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,196 @@
|
||||
use std::{net::Ipv4Addr, pin::Pin};
|
||||
|
||||
use crate::{
|
||||
common::{
|
||||
error::Result,
|
||||
global_ctx::ArcGlobalCtx,
|
||||
ifcfg::{IfConfiger, IfConfiguerTrait},
|
||||
},
|
||||
tunnels::{
|
||||
codec::BytesCodec, common::FramedTunnel, DatagramSink, DatagramStream, Tunnel, TunnelError,
|
||||
},
|
||||
};
|
||||
|
||||
use futures::{SinkExt, StreamExt};
|
||||
use tokio_util::{bytes::Bytes, codec::Framed};
|
||||
use tun::Device;
|
||||
|
||||
use super::tun_codec::{TunPacket, TunPacketCodec};
|
||||
|
||||
pub struct VirtualNic {
|
||||
dev_name: String,
|
||||
queue_num: usize,
|
||||
|
||||
global_ctx: ArcGlobalCtx,
|
||||
|
||||
ifname: Option<String>,
|
||||
tun: Option<Box<dyn Tunnel>>,
|
||||
ifcfg: Box<dyn IfConfiguerTrait + Send + Sync + 'static>,
|
||||
}
|
||||
|
||||
impl VirtualNic {
|
||||
pub fn new(global_ctx: ArcGlobalCtx) -> Self {
|
||||
Self {
|
||||
dev_name: "".to_owned(),
|
||||
queue_num: 1,
|
||||
global_ctx,
|
||||
ifname: None,
|
||||
tun: None,
|
||||
ifcfg: Box::new(IfConfiger {}),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set_dev_name(mut self, dev_name: &str) -> Result<Self> {
|
||||
self.dev_name = dev_name.to_owned();
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
pub fn set_queue_num(mut self, queue_num: usize) -> Result<Self> {
|
||||
self.queue_num = queue_num;
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
async fn create_dev_ret_err(&mut self) -> Result<()> {
|
||||
let mut config = tun::Configuration::default();
|
||||
let has_packet_info = cfg!(target_os = "macos");
|
||||
config.layer(tun::Layer::L3);
|
||||
|
||||
#[cfg(target_os = "linux")]
|
||||
{
|
||||
config.platform(|config| {
|
||||
// detect protocol by ourselves for cross platform
|
||||
config.packet_information(false);
|
||||
});
|
||||
}
|
||||
|
||||
if self.queue_num != 1 {
|
||||
todo!("queue_num != 1")
|
||||
}
|
||||
config.queues(self.queue_num);
|
||||
config.up();
|
||||
|
||||
let dev = {
|
||||
let _g = self.global_ctx.net_ns.guard();
|
||||
tun::create_as_async(&config)?
|
||||
};
|
||||
let ifname = dev.get_ref().name()?;
|
||||
self.ifcfg.wait_interface_show(ifname.as_str()).await?;
|
||||
|
||||
let ft: Box<dyn Tunnel> = if has_packet_info {
|
||||
let framed = Framed::new(dev, TunPacketCodec::new(true, 2500));
|
||||
let (sink, stream) = framed.split();
|
||||
|
||||
let new_stream = stream.map(|item| match item {
|
||||
Ok(item) => Ok(item.into_bytes_mut()),
|
||||
Err(err) => {
|
||||
println!("tun stream error: {:?}", err);
|
||||
Err(TunnelError::TunError(err.to_string()))
|
||||
}
|
||||
});
|
||||
|
||||
let new_sink = Box::pin(sink.with(|item: Bytes| async move {
|
||||
if false {
|
||||
return Err(TunnelError::TunError("tun sink error".to_owned()));
|
||||
}
|
||||
Ok(TunPacket::new(super::tun_codec::TunPacketBuffer::Bytes(
|
||||
item,
|
||||
)))
|
||||
}));
|
||||
|
||||
Box::new(FramedTunnel::new(new_stream, new_sink, None))
|
||||
} else {
|
||||
let framed = Framed::new(dev, BytesCodec::new(2500));
|
||||
let (sink, stream) = framed.split();
|
||||
Box::new(FramedTunnel::new(stream, sink, None))
|
||||
};
|
||||
|
||||
self.ifname = Some(ifname.to_owned());
|
||||
self.tun = Some(ft);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn create_dev(mut self) -> Result<Self> {
|
||||
self.create_dev_ret_err().await?;
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
pub fn ifname(&self) -> &str {
|
||||
self.ifname.as_ref().unwrap().as_str()
|
||||
}
|
||||
|
||||
pub async fn link_up(&self) -> Result<()> {
|
||||
let _g = self.global_ctx.net_ns.guard();
|
||||
self.ifcfg.set_link_status(self.ifname(), true).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn add_route(&self, address: Ipv4Addr, cidr: u8) -> Result<()> {
|
||||
let _g = self.global_ctx.net_ns.guard();
|
||||
self.ifcfg
|
||||
.add_ipv4_route(self.ifname(), address, cidr)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn remove_ip(&self, ip: Option<Ipv4Addr>) -> Result<()> {
|
||||
let _g = self.global_ctx.net_ns.guard();
|
||||
self.ifcfg.remove_ip(self.ifname(), ip).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn add_ip(&self, ip: Ipv4Addr, cidr: i32) -> Result<()> {
|
||||
let _g = self.global_ctx.net_ns.guard();
|
||||
self.ifcfg
|
||||
.add_ipv4_ip(self.ifname(), ip, cidr as u8)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn pin_recv_stream(&self) -> Pin<Box<dyn DatagramStream>> {
|
||||
self.tun.as_ref().unwrap().pin_stream()
|
||||
}
|
||||
|
||||
pub fn pin_send_stream(&self) -> Pin<Box<dyn DatagramSink>> {
|
||||
self.tun.as_ref().unwrap().pin_sink()
|
||||
}
|
||||
|
||||
pub fn get_ifcfg(&self) -> &dyn IfConfiguerTrait {
|
||||
self.ifcfg.as_ref()
|
||||
}
|
||||
}
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::common::{error::Error, global_ctx::tests::get_mock_global_ctx};
|
||||
|
||||
use super::VirtualNic;
|
||||
|
||||
async fn run_test_helper() -> Result<VirtualNic, Error> {
|
||||
let dev = VirtualNic::new(get_mock_global_ctx()).create_dev().await?;
|
||||
|
||||
tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
|
||||
|
||||
dev.link_up().await?;
|
||||
dev.remove_ip(None).await?;
|
||||
dev.add_ip("10.144.111.1".parse().unwrap(), 24).await?;
|
||||
Ok(dev)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn tun_test() {
|
||||
let _dev = run_test_helper().await.unwrap();
|
||||
|
||||
// let mut stream = nic.pin_recv_stream();
|
||||
// while let Some(item) = stream.next().await {
|
||||
// println!("item: {:?}", item);
|
||||
// }
|
||||
|
||||
// let framed = dev.into_framed();
|
||||
// let (mut s, mut b) = framed.split();
|
||||
// loop {
|
||||
// let tmp = b.next().await.unwrap().unwrap();
|
||||
// let tmp = EthernetPacket::new(tmp.get_bytes());
|
||||
// println!("ret: {:?}", tmp.unwrap());
|
||||
// }
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,382 @@
|
||||
use std::{
|
||||
collections::hash_map::DefaultHasher,
|
||||
hash::{Hash, Hasher},
|
||||
sync::{
|
||||
atomic::{AtomicBool, Ordering},
|
||||
Arc,
|
||||
},
|
||||
time::{Duration, SystemTime},
|
||||
};
|
||||
|
||||
use crossbeam::atomic::AtomicCell;
|
||||
use futures::Future;
|
||||
use tokio::{
|
||||
sync::{Mutex, RwLock},
|
||||
task::JoinSet,
|
||||
};
|
||||
use tracing::Instrument;
|
||||
|
||||
use crate::{
|
||||
common::PeerId,
|
||||
peers::{peer_manager::PeerManager, rpc_service::PeerManagerRpcService},
|
||||
rpc::{GetGlobalPeerMapRequest, GetGlobalPeerMapResponse},
|
||||
};
|
||||
|
||||
use super::{
|
||||
server::PeerCenterServer,
|
||||
service::{GlobalPeerMap, PeerCenterService, PeerCenterServiceClient, PeerInfoForGlobalMap},
|
||||
Digest, Error,
|
||||
};
|
||||
|
||||
struct PeerCenterBase {
|
||||
peer_mgr: Arc<PeerManager>,
|
||||
tasks: Arc<Mutex<JoinSet<()>>>,
|
||||
lock: Arc<Mutex<()>>,
|
||||
}
|
||||
|
||||
static SERVICE_ID: u32 = 5;
|
||||
|
||||
struct PeridicJobCtx<T> {
|
||||
peer_mgr: Arc<PeerManager>,
|
||||
center_peer: AtomicCell<PeerId>,
|
||||
job_ctx: T,
|
||||
}
|
||||
|
||||
impl PeerCenterBase {
|
||||
pub async fn init(&self) -> Result<(), Error> {
|
||||
self.peer_mgr.get_peer_rpc_mgr().run_service(
|
||||
SERVICE_ID,
|
||||
PeerCenterServer::new(self.peer_mgr.my_peer_id()).serve(),
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn select_center_peer(peer_mgr: &Arc<PeerManager>) -> Option<PeerId> {
|
||||
let peers = peer_mgr.list_routes().await;
|
||||
if peers.is_empty() {
|
||||
return None;
|
||||
}
|
||||
// find peer with alphabetical smallest id.
|
||||
let mut min_peer = peer_mgr.my_peer_id();
|
||||
for peer in peers.iter() {
|
||||
let peer_id = peer.peer_id;
|
||||
if peer_id < min_peer {
|
||||
min_peer = peer_id;
|
||||
}
|
||||
}
|
||||
Some(min_peer)
|
||||
}
|
||||
|
||||
async fn init_periodic_job<
|
||||
T: Send + Sync + 'static + Clone,
|
||||
Fut: Future<Output = Result<u32, tarpc::client::RpcError>> + Send + 'static,
|
||||
>(
|
||||
&self,
|
||||
job_ctx: T,
|
||||
job_fn: (impl Fn(PeerCenterServiceClient, Arc<PeridicJobCtx<T>>) -> Fut + Send + Sync + 'static),
|
||||
) -> () {
|
||||
let my_peer_id = self.peer_mgr.my_peer_id();
|
||||
let peer_mgr = self.peer_mgr.clone();
|
||||
let lock = self.lock.clone();
|
||||
self.tasks.lock().await.spawn(
|
||||
async move {
|
||||
let ctx = Arc::new(PeridicJobCtx {
|
||||
peer_mgr: peer_mgr.clone(),
|
||||
center_peer: AtomicCell::new(PeerId::default()),
|
||||
job_ctx,
|
||||
});
|
||||
loop {
|
||||
let Some(center_peer) = Self::select_center_peer(&peer_mgr).await else {
|
||||
tracing::trace!("no center peer found, sleep 1 second");
|
||||
tokio::time::sleep(Duration::from_secs(1)).await;
|
||||
continue;
|
||||
};
|
||||
ctx.center_peer.store(center_peer.clone());
|
||||
tracing::trace!(?center_peer, "run periodic job");
|
||||
let rpc_mgr = peer_mgr.get_peer_rpc_mgr();
|
||||
let _g = lock.lock().await;
|
||||
let ret = rpc_mgr
|
||||
.do_client_rpc_scoped(SERVICE_ID, center_peer, |c| async {
|
||||
let client =
|
||||
PeerCenterServiceClient::new(tarpc::client::Config::default(), c)
|
||||
.spawn();
|
||||
job_fn(client, ctx.clone()).await
|
||||
})
|
||||
.await;
|
||||
drop(_g);
|
||||
|
||||
let Ok(sleep_time_ms) = ret else {
|
||||
tracing::error!("periodic job to center server rpc failed: {:?}", ret);
|
||||
tokio::time::sleep(Duration::from_secs(3)).await;
|
||||
continue;
|
||||
};
|
||||
|
||||
if sleep_time_ms > 0 {
|
||||
tokio::time::sleep(Duration::from_millis(sleep_time_ms as u64)).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
.instrument(tracing::info_span!("periodic_job", ?my_peer_id)),
|
||||
);
|
||||
}
|
||||
|
||||
pub fn new(peer_mgr: Arc<PeerManager>) -> Self {
|
||||
PeerCenterBase {
|
||||
peer_mgr,
|
||||
tasks: Arc::new(Mutex::new(JoinSet::new())),
|
||||
lock: Arc::new(Mutex::new(())),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct PeerCenterInstanceService {
|
||||
global_peer_map: Arc<RwLock<GlobalPeerMap>>,
|
||||
global_peer_map_digest: Arc<RwLock<Digest>>,
|
||||
}
|
||||
|
||||
#[tonic::async_trait]
|
||||
impl crate::rpc::cli::peer_center_rpc_server::PeerCenterRpc for PeerCenterInstanceService {
|
||||
async fn get_global_peer_map(
|
||||
&self,
|
||||
_request: tonic::Request<GetGlobalPeerMapRequest>,
|
||||
) -> Result<tonic::Response<GetGlobalPeerMapResponse>, tonic::Status> {
|
||||
let global_peer_map = self.global_peer_map.read().await.clone();
|
||||
Ok(tonic::Response::new(GetGlobalPeerMapResponse {
|
||||
global_peer_map: global_peer_map
|
||||
.map
|
||||
.into_iter()
|
||||
.map(|(k, v)| (k, v))
|
||||
.collect(),
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
pub struct PeerCenterInstance {
|
||||
peer_mgr: Arc<PeerManager>,
|
||||
|
||||
client: Arc<PeerCenterBase>,
|
||||
global_peer_map: Arc<RwLock<GlobalPeerMap>>,
|
||||
global_peer_map_digest: Arc<RwLock<Digest>>,
|
||||
}
|
||||
|
||||
impl PeerCenterInstance {
|
||||
pub fn new(peer_mgr: Arc<PeerManager>) -> Self {
|
||||
PeerCenterInstance {
|
||||
peer_mgr: peer_mgr.clone(),
|
||||
client: Arc::new(PeerCenterBase::new(peer_mgr.clone())),
|
||||
global_peer_map: Arc::new(RwLock::new(GlobalPeerMap::new())),
|
||||
global_peer_map_digest: Arc::new(RwLock::new(Digest::default())),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn init(&self) {
|
||||
self.client.init().await.unwrap();
|
||||
self.init_get_global_info_job().await;
|
||||
self.init_report_peers_job().await;
|
||||
}
|
||||
|
||||
async fn init_get_global_info_job(&self) {
|
||||
struct Ctx {
|
||||
global_peer_map: Arc<RwLock<GlobalPeerMap>>,
|
||||
global_peer_map_digest: Arc<RwLock<Digest>>,
|
||||
}
|
||||
|
||||
let ctx = Arc::new(Ctx {
|
||||
global_peer_map: self.global_peer_map.clone(),
|
||||
global_peer_map_digest: self.global_peer_map_digest.clone(),
|
||||
});
|
||||
|
||||
self.client
|
||||
.init_periodic_job(ctx, |client, ctx| async move {
|
||||
let mut rpc_ctx = tarpc::context::current();
|
||||
rpc_ctx.deadline = SystemTime::now() + Duration::from_secs(3);
|
||||
|
||||
let ret = client
|
||||
.get_global_peer_map(
|
||||
rpc_ctx,
|
||||
ctx.job_ctx.global_peer_map_digest.read().await.clone(),
|
||||
)
|
||||
.await?;
|
||||
|
||||
let Ok(resp) = ret else {
|
||||
tracing::error!(
|
||||
"get global info from center server got error result: {:?}",
|
||||
ret
|
||||
);
|
||||
return Ok(1000);
|
||||
};
|
||||
|
||||
let Some(resp) = resp else {
|
||||
return Ok(5000);
|
||||
};
|
||||
|
||||
tracing::info!(
|
||||
"get global info from center server: {:?}, digest: {:?}",
|
||||
resp.global_peer_map,
|
||||
resp.digest
|
||||
);
|
||||
|
||||
*ctx.job_ctx.global_peer_map.write().await = resp.global_peer_map;
|
||||
*ctx.job_ctx.global_peer_map_digest.write().await = resp.digest;
|
||||
|
||||
Ok(10000)
|
||||
})
|
||||
.await;
|
||||
}
|
||||
|
||||
async fn init_report_peers_job(&self) {
|
||||
struct Ctx {
|
||||
service: PeerManagerRpcService,
|
||||
need_send_peers: AtomicBool,
|
||||
last_report_peers: Mutex<PeerInfoForGlobalMap>,
|
||||
last_center_peer: AtomicCell<PeerId>,
|
||||
}
|
||||
let ctx = Arc::new(Ctx {
|
||||
service: PeerManagerRpcService::new(self.peer_mgr.clone()),
|
||||
need_send_peers: AtomicBool::new(true),
|
||||
last_report_peers: Mutex::new(PeerInfoForGlobalMap::default()),
|
||||
last_center_peer: AtomicCell::new(PeerId::default()),
|
||||
});
|
||||
|
||||
self.client
|
||||
.init_periodic_job(ctx, |client, ctx| async move {
|
||||
let my_node_id = ctx.peer_mgr.my_peer_id();
|
||||
|
||||
// if peers are not same in next 10 seconds, report peers to center server
|
||||
let mut peers = PeerInfoForGlobalMap::default();
|
||||
for _ in 1..10 {
|
||||
peers = ctx.job_ctx.service.list_peers().await.into();
|
||||
if ctx.center_peer.load() != ctx.job_ctx.last_center_peer.load() {
|
||||
// if center peer changed, report peers immediately
|
||||
break;
|
||||
}
|
||||
if peers == *ctx.job_ctx.last_report_peers.lock().await {
|
||||
return Ok(3000);
|
||||
}
|
||||
tokio::time::sleep(Duration::from_secs(2)).await;
|
||||
}
|
||||
|
||||
*ctx.job_ctx.last_report_peers.lock().await = peers.clone();
|
||||
let mut hasher = DefaultHasher::new();
|
||||
peers.hash(&mut hasher);
|
||||
|
||||
let peers = if ctx.job_ctx.need_send_peers.load(Ordering::Relaxed) {
|
||||
Some(peers)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let mut rpc_ctx = tarpc::context::current();
|
||||
rpc_ctx.deadline = SystemTime::now() + Duration::from_secs(3);
|
||||
|
||||
let ret = client
|
||||
.report_peers(
|
||||
rpc_ctx,
|
||||
my_node_id.clone(),
|
||||
peers,
|
||||
hasher.finish() as Digest,
|
||||
)
|
||||
.await?;
|
||||
|
||||
if matches!(ret.as_ref().err(), Some(Error::DigestMismatch)) {
|
||||
ctx.job_ctx.need_send_peers.store(true, Ordering::Relaxed);
|
||||
return Ok(0);
|
||||
} else if ret.is_err() {
|
||||
tracing::error!("report peers to center server got error result: {:?}", ret);
|
||||
return Ok(500);
|
||||
}
|
||||
|
||||
ctx.job_ctx.last_center_peer.store(ctx.center_peer.load());
|
||||
ctx.job_ctx.need_send_peers.store(false, Ordering::Relaxed);
|
||||
Ok(3000)
|
||||
})
|
||||
.await;
|
||||
}
|
||||
|
||||
pub fn get_rpc_service(&self) -> PeerCenterInstanceService {
|
||||
PeerCenterInstanceService {
|
||||
global_peer_map: self.global_peer_map.clone(),
|
||||
global_peer_map_digest: self.global_peer_map_digest.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::ops::Deref;
|
||||
|
||||
use crate::{
|
||||
peer_center::server::get_global_data,
|
||||
peers::tests::{connect_peer_manager, create_mock_peer_manager, wait_route_appear},
|
||||
};
|
||||
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_peer_center_instance() {
|
||||
let peer_mgr_a = create_mock_peer_manager().await;
|
||||
let peer_mgr_b = create_mock_peer_manager().await;
|
||||
let peer_mgr_c = create_mock_peer_manager().await;
|
||||
|
||||
let peer_center_a = PeerCenterInstance::new(peer_mgr_a.clone());
|
||||
let peer_center_b = PeerCenterInstance::new(peer_mgr_b.clone());
|
||||
let peer_center_c = PeerCenterInstance::new(peer_mgr_c.clone());
|
||||
|
||||
let peer_centers = vec![&peer_center_a, &peer_center_b, &peer_center_c];
|
||||
for pc in peer_centers.iter() {
|
||||
pc.init().await;
|
||||
}
|
||||
|
||||
connect_peer_manager(peer_mgr_a.clone(), peer_mgr_b.clone()).await;
|
||||
connect_peer_manager(peer_mgr_b.clone(), peer_mgr_c.clone()).await;
|
||||
|
||||
wait_route_appear(peer_mgr_a.clone(), peer_mgr_c.clone())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let center_peer = PeerCenterBase::select_center_peer(&peer_mgr_a)
|
||||
.await
|
||||
.unwrap();
|
||||
let center_data = get_global_data(center_peer);
|
||||
|
||||
// wait center_data has 3 records for 10 seconds
|
||||
let now = std::time::Instant::now();
|
||||
loop {
|
||||
if center_data.read().await.global_peer_map.map.len() == 3 {
|
||||
println!(
|
||||
"center data ready, {:#?}",
|
||||
center_data.read().await.global_peer_map
|
||||
);
|
||||
break;
|
||||
}
|
||||
if now.elapsed().as_secs() > 60 {
|
||||
panic!("center data not ready");
|
||||
}
|
||||
tokio::time::sleep(Duration::from_millis(100)).await;
|
||||
}
|
||||
|
||||
let mut digest = None;
|
||||
for pc in peer_centers.iter() {
|
||||
let rpc_service = pc.get_rpc_service();
|
||||
let now = std::time::Instant::now();
|
||||
while now.elapsed().as_secs() < 10 {
|
||||
if rpc_service.global_peer_map.read().await.map.len() == 3 {
|
||||
break;
|
||||
}
|
||||
tokio::time::sleep(Duration::from_millis(100)).await;
|
||||
}
|
||||
assert_eq!(rpc_service.global_peer_map.read().await.map.len(), 3);
|
||||
println!("rpc service ready, {:#?}", rpc_service.global_peer_map);
|
||||
|
||||
if digest.is_none() {
|
||||
digest = Some(rpc_service.global_peer_map_digest.read().await.clone());
|
||||
} else {
|
||||
let v = rpc_service.global_peer_map_digest.read().await;
|
||||
assert_eq!(digest.as_ref().unwrap(), v.deref());
|
||||
}
|
||||
}
|
||||
|
||||
let global_digest = get_global_data(center_peer).read().await.digest.clone();
|
||||
assert_eq!(digest.as_ref().unwrap(), &global_digest);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,20 @@
|
||||
// peer_center is used to collect peer info into one peer node.
|
||||
// the center node is selected with the following rules:
|
||||
// 1. has smallest peer id
|
||||
// 2. TODO: has allow_to_be_center peer feature
|
||||
// peer center is not guaranteed to be stable and can be changed when peer enter or leave.
|
||||
// it's used to reduce the cost to exchange infos between peers.
|
||||
|
||||
pub mod instance;
|
||||
mod server;
|
||||
mod service;
|
||||
|
||||
#[derive(thiserror::Error, Debug, serde::Deserialize, serde::Serialize)]
|
||||
pub enum Error {
|
||||
#[error("Digest not match, need provide full peer info to center server.")]
|
||||
DigestMismatch,
|
||||
#[error("Not center server")]
|
||||
NotCenterServer,
|
||||
}
|
||||
|
||||
pub type Digest = u64;
|
||||
@@ -0,0 +1,152 @@
|
||||
use std::{
|
||||
hash::{Hash, Hasher},
|
||||
sync::Arc,
|
||||
};
|
||||
|
||||
use dashmap::DashMap;
|
||||
use once_cell::sync::Lazy;
|
||||
use tokio::{sync::RwLock, task::JoinSet};
|
||||
|
||||
use crate::common::PeerId;
|
||||
|
||||
use super::{
|
||||
service::{GetGlobalPeerMapResponse, GlobalPeerMap, PeerCenterService, PeerInfoForGlobalMap},
|
||||
Digest, Error,
|
||||
};
|
||||
|
||||
pub(crate) struct PeerCenterServerGlobalData {
|
||||
pub global_peer_map: GlobalPeerMap,
|
||||
pub digest: Digest,
|
||||
pub update_time: std::time::Instant,
|
||||
pub peer_update_time: DashMap<PeerId, std::time::Instant>,
|
||||
}
|
||||
|
||||
impl PeerCenterServerGlobalData {
|
||||
fn new() -> Self {
|
||||
PeerCenterServerGlobalData {
|
||||
global_peer_map: GlobalPeerMap::new(),
|
||||
digest: Digest::default(),
|
||||
update_time: std::time::Instant::now(),
|
||||
peer_update_time: DashMap::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// a global unique instance for PeerCenterServer
|
||||
pub(crate) static GLOBAL_DATA: Lazy<DashMap<PeerId, Arc<RwLock<PeerCenterServerGlobalData>>>> =
|
||||
Lazy::new(DashMap::new);
|
||||
|
||||
pub(crate) fn get_global_data(node_id: PeerId) -> Arc<RwLock<PeerCenterServerGlobalData>> {
|
||||
GLOBAL_DATA
|
||||
.entry(node_id)
|
||||
.or_insert_with(|| Arc::new(RwLock::new(PeerCenterServerGlobalData::new())))
|
||||
.value()
|
||||
.clone()
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct PeerCenterServer {
|
||||
// every peer has its own server, so use per-struct dash map is ok.
|
||||
my_node_id: PeerId,
|
||||
digest_map: DashMap<PeerId, Digest>,
|
||||
|
||||
tasks: Arc<JoinSet<()>>,
|
||||
}
|
||||
|
||||
impl PeerCenterServer {
|
||||
pub fn new(my_node_id: PeerId) -> Self {
|
||||
let mut tasks = JoinSet::new();
|
||||
tasks.spawn(async move {
|
||||
loop {
|
||||
tokio::time::sleep(std::time::Duration::from_secs(10)).await;
|
||||
PeerCenterServer::clean_outdated_peer(my_node_id).await;
|
||||
}
|
||||
});
|
||||
|
||||
PeerCenterServer {
|
||||
my_node_id,
|
||||
digest_map: DashMap::new(),
|
||||
|
||||
tasks: Arc::new(tasks),
|
||||
}
|
||||
}
|
||||
|
||||
async fn clean_outdated_peer(my_node_id: PeerId) {
|
||||
let data = get_global_data(my_node_id);
|
||||
let mut locked_data = data.write().await;
|
||||
let now = std::time::Instant::now();
|
||||
let mut to_remove = Vec::new();
|
||||
for kv in locked_data.peer_update_time.iter() {
|
||||
if now.duration_since(*kv.value()).as_secs() > 20 {
|
||||
to_remove.push(*kv.key());
|
||||
}
|
||||
}
|
||||
for peer_id in to_remove {
|
||||
locked_data.global_peer_map.map.remove(&peer_id);
|
||||
locked_data.peer_update_time.remove(&peer_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[tarpc::server]
|
||||
impl PeerCenterService for PeerCenterServer {
|
||||
#[tracing::instrument()]
|
||||
async fn report_peers(
|
||||
self,
|
||||
_: tarpc::context::Context,
|
||||
my_peer_id: PeerId,
|
||||
peers: Option<PeerInfoForGlobalMap>,
|
||||
digest: Digest,
|
||||
) -> Result<(), Error> {
|
||||
tracing::trace!("receive report_peers");
|
||||
|
||||
let data = get_global_data(self.my_node_id);
|
||||
let mut locked_data = data.write().await;
|
||||
locked_data
|
||||
.peer_update_time
|
||||
.insert(my_peer_id, std::time::Instant::now());
|
||||
|
||||
let old_digest = self.digest_map.get(&my_peer_id);
|
||||
// if digest match, no need to update
|
||||
if let Some(old_digest) = old_digest {
|
||||
if *old_digest == digest {
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
|
||||
if peers.is_none() {
|
||||
return Err(Error::DigestMismatch);
|
||||
}
|
||||
|
||||
self.digest_map.insert(my_peer_id, digest);
|
||||
locked_data
|
||||
.global_peer_map
|
||||
.map
|
||||
.insert(my_peer_id, peers.unwrap());
|
||||
|
||||
let mut hasher = std::collections::hash_map::DefaultHasher::new();
|
||||
locked_data.global_peer_map.map.hash(&mut hasher);
|
||||
locked_data.digest = hasher.finish() as Digest;
|
||||
locked_data.update_time = std::time::Instant::now();
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn get_global_peer_map(
|
||||
self,
|
||||
_: tarpc::context::Context,
|
||||
digest: Digest,
|
||||
) -> Result<Option<GetGlobalPeerMapResponse>, Error> {
|
||||
let data = get_global_data(self.my_node_id);
|
||||
if digest == data.read().await.digest {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let data = get_global_data(self.my_node_id);
|
||||
let locked_data = data.read().await;
|
||||
Ok(Some(GetGlobalPeerMapResponse {
|
||||
global_peer_map: locked_data.global_peer_map.clone(),
|
||||
digest: locked_data.digest,
|
||||
}))
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,84 @@
|
||||
use std::collections::BTreeMap;
|
||||
|
||||
use crate::{common::PeerId, rpc::DirectConnectedPeerInfo};
|
||||
|
||||
use super::{Digest, Error};
|
||||
use crate::rpc::PeerInfo;
|
||||
|
||||
pub type LatencyLevel = crate::rpc::cli::LatencyLevel;
|
||||
|
||||
impl LatencyLevel {
|
||||
pub const fn from_latency_ms(lat_ms: u32) -> Self {
|
||||
if lat_ms < 10 {
|
||||
LatencyLevel::VeryLow
|
||||
} else if lat_ms < 50 {
|
||||
LatencyLevel::Low
|
||||
} else if lat_ms < 100 {
|
||||
LatencyLevel::Normal
|
||||
} else if lat_ms < 200 {
|
||||
LatencyLevel::High
|
||||
} else {
|
||||
LatencyLevel::VeryHigh
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub type PeerInfoForGlobalMap = crate::rpc::cli::PeerInfoForGlobalMap;
|
||||
|
||||
impl From<Vec<PeerInfo>> for PeerInfoForGlobalMap {
|
||||
fn from(peers: Vec<PeerInfo>) -> Self {
|
||||
let mut peer_map = BTreeMap::new();
|
||||
for peer in peers {
|
||||
let min_lat = peer
|
||||
.conns
|
||||
.iter()
|
||||
.map(|conn| conn.stats.as_ref().unwrap().latency_us)
|
||||
.min()
|
||||
.unwrap_or(0);
|
||||
|
||||
let dp_info = DirectConnectedPeerInfo {
|
||||
latency_level: LatencyLevel::from_latency_ms(min_lat as u32 / 1000) as i32,
|
||||
};
|
||||
|
||||
// sort conn info so hash result is stable
|
||||
peer_map.insert(peer.peer_id, dp_info);
|
||||
}
|
||||
PeerInfoForGlobalMap {
|
||||
direct_peers: peer_map,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// a global peer topology map, peers can use it to find optimal path to other peers
|
||||
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
|
||||
pub struct GlobalPeerMap {
|
||||
pub map: BTreeMap<PeerId, PeerInfoForGlobalMap>,
|
||||
}
|
||||
|
||||
impl GlobalPeerMap {
|
||||
pub fn new() -> Self {
|
||||
GlobalPeerMap {
|
||||
map: BTreeMap::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
|
||||
pub struct GetGlobalPeerMapResponse {
|
||||
pub global_peer_map: GlobalPeerMap,
|
||||
pub digest: Digest,
|
||||
}
|
||||
|
||||
#[tarpc::service]
|
||||
pub trait PeerCenterService {
|
||||
// report center server which peer is directly connected to me
|
||||
// digest is a hash of current peer map, if digest not match, we need to transfer the whole map
|
||||
async fn report_peers(
|
||||
my_peer_id: PeerId,
|
||||
peers: Option<PeerInfoForGlobalMap>,
|
||||
digest: Digest,
|
||||
) -> Result<(), Error>;
|
||||
|
||||
async fn get_global_peer_map(digest: Digest)
|
||||
-> Result<Option<GetGlobalPeerMapResponse>, Error>;
|
||||
}
|
||||
@@ -0,0 +1,200 @@
|
||||
use std::{
|
||||
sync::Arc,
|
||||
time::{Duration, SystemTime},
|
||||
};
|
||||
|
||||
use dashmap::DashMap;
|
||||
use tokio::{
|
||||
sync::{mpsc, Mutex},
|
||||
task::JoinSet,
|
||||
};
|
||||
use tokio_util::bytes::Bytes;
|
||||
|
||||
use crate::common::{
|
||||
error::Error,
|
||||
global_ctx::{ArcGlobalCtx, NetworkIdentity},
|
||||
PeerId,
|
||||
};
|
||||
|
||||
use super::{
|
||||
foreign_network_manager::{ForeignNetworkServiceClient, FOREIGN_NETWORK_SERVICE_ID},
|
||||
peer_conn::PeerConn,
|
||||
peer_map::PeerMap,
|
||||
peer_rpc::PeerRpcManager,
|
||||
};
|
||||
|
||||
pub struct ForeignNetworkClient {
|
||||
global_ctx: ArcGlobalCtx,
|
||||
peer_rpc: Arc<PeerRpcManager>,
|
||||
my_peer_id: PeerId,
|
||||
|
||||
peer_map: Arc<PeerMap>,
|
||||
|
||||
next_hop: Arc<DashMap<PeerId, PeerId>>,
|
||||
tasks: Mutex<JoinSet<()>>,
|
||||
}
|
||||
|
||||
impl ForeignNetworkClient {
|
||||
pub fn new(
|
||||
global_ctx: ArcGlobalCtx,
|
||||
packet_sender_to_mgr: mpsc::Sender<Bytes>,
|
||||
peer_rpc: Arc<PeerRpcManager>,
|
||||
my_peer_id: PeerId,
|
||||
) -> Self {
|
||||
let peer_map = Arc::new(PeerMap::new(
|
||||
packet_sender_to_mgr,
|
||||
global_ctx.clone(),
|
||||
my_peer_id,
|
||||
));
|
||||
let next_hop = Arc::new(DashMap::new());
|
||||
|
||||
Self {
|
||||
global_ctx,
|
||||
peer_rpc,
|
||||
my_peer_id,
|
||||
|
||||
peer_map,
|
||||
|
||||
next_hop,
|
||||
tasks: Mutex::new(JoinSet::new()),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn add_new_peer_conn(&self, peer_conn: PeerConn) {
|
||||
tracing::warn!(peer_conn = ?peer_conn.get_conn_info(), network = ?peer_conn.get_network_identity(), "add new peer conn in foreign network client");
|
||||
self.peer_map.add_new_peer_conn(peer_conn).await
|
||||
}
|
||||
|
||||
async fn collect_next_hop_in_foreign_network_task(
|
||||
network_identity: NetworkIdentity,
|
||||
peer_map: Arc<PeerMap>,
|
||||
peer_rpc: Arc<PeerRpcManager>,
|
||||
next_hop: Arc<DashMap<PeerId, PeerId>>,
|
||||
) {
|
||||
loop {
|
||||
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
|
||||
|
||||
peer_map.clean_peer_without_conn().await;
|
||||
|
||||
let new_next_hop = Self::collect_next_hop_in_foreign_network(
|
||||
network_identity.clone(),
|
||||
peer_map.clone(),
|
||||
peer_rpc.clone(),
|
||||
)
|
||||
.await;
|
||||
|
||||
next_hop.clear();
|
||||
for (k, v) in new_next_hop.into_iter() {
|
||||
next_hop.insert(k, v);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn collect_next_hop_in_foreign_network(
|
||||
network_identity: NetworkIdentity,
|
||||
peer_map: Arc<PeerMap>,
|
||||
peer_rpc: Arc<PeerRpcManager>,
|
||||
) -> DashMap<PeerId, PeerId> {
|
||||
let peers = peer_map.list_peers().await;
|
||||
let mut tasks = JoinSet::new();
|
||||
if !peers.is_empty() {
|
||||
tracing::warn!(?peers, my_peer_id = ?peer_rpc.my_peer_id(), "collect next hop in foreign network");
|
||||
}
|
||||
for peer in peers {
|
||||
let peer_rpc = peer_rpc.clone();
|
||||
let network_identity = network_identity.clone();
|
||||
tasks.spawn(async move {
|
||||
let Ok(Some(peers_in_foreign)) = peer_rpc
|
||||
.do_client_rpc_scoped(FOREIGN_NETWORK_SERVICE_ID, peer, |c| async {
|
||||
let c =
|
||||
ForeignNetworkServiceClient::new(tarpc::client::Config::default(), c)
|
||||
.spawn();
|
||||
let mut rpc_ctx = tarpc::context::current();
|
||||
rpc_ctx.deadline = SystemTime::now() + Duration::from_secs(2);
|
||||
let ret = c.list_network_peers(rpc_ctx, network_identity).await;
|
||||
ret
|
||||
})
|
||||
.await
|
||||
else {
|
||||
return (peer, vec![]);
|
||||
};
|
||||
|
||||
(peer, peers_in_foreign)
|
||||
});
|
||||
}
|
||||
|
||||
let new_next_hop = DashMap::new();
|
||||
while let Some(join_ret) = tasks.join_next().await {
|
||||
let Ok((gateway, peer_ids)) = join_ret else {
|
||||
tracing::error!(?join_ret, "collect next hop in foreign network failed");
|
||||
continue;
|
||||
};
|
||||
for ret in peer_ids {
|
||||
new_next_hop.insert(ret, gateway);
|
||||
}
|
||||
}
|
||||
|
||||
new_next_hop
|
||||
}
|
||||
|
||||
pub fn has_next_hop(&self, peer_id: PeerId) -> bool {
|
||||
self.get_next_hop(peer_id).is_some()
|
||||
}
|
||||
|
||||
pub fn get_next_hop(&self, peer_id: PeerId) -> Option<PeerId> {
|
||||
if self.peer_map.has_peer(peer_id) {
|
||||
return Some(peer_id.clone());
|
||||
}
|
||||
self.next_hop.get(&peer_id).map(|v| v.clone())
|
||||
}
|
||||
|
||||
pub async fn send_msg(&self, msg: Bytes, peer_id: PeerId) -> Result<(), Error> {
|
||||
if let Some(next_hop) = self.get_next_hop(peer_id) {
|
||||
let ret = self.peer_map.send_msg_directly(msg, next_hop).await;
|
||||
if ret.is_err() {
|
||||
tracing::error!(
|
||||
?ret,
|
||||
?peer_id,
|
||||
?next_hop,
|
||||
"foreign network client send msg failed"
|
||||
);
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
Err(Error::RouteError(Some("no next hop".to_string())))
|
||||
}
|
||||
|
||||
pub fn list_foreign_peers(&self) -> Vec<PeerId> {
|
||||
let mut peers = vec![];
|
||||
for item in self.next_hop.iter() {
|
||||
if item.key() != &self.my_peer_id {
|
||||
peers.push(item.key().clone());
|
||||
}
|
||||
}
|
||||
peers
|
||||
}
|
||||
|
||||
pub async fn run(&self) {
|
||||
self.tasks
|
||||
.lock()
|
||||
.await
|
||||
.spawn(Self::collect_next_hop_in_foreign_network_task(
|
||||
self.global_ctx.get_network_identity(),
|
||||
self.peer_map.clone(),
|
||||
self.peer_rpc.clone(),
|
||||
self.next_hop.clone(),
|
||||
));
|
||||
}
|
||||
|
||||
pub fn get_next_hop_table(&self) -> DashMap<PeerId, PeerId> {
|
||||
let next_hop = DashMap::new();
|
||||
for item in self.next_hop.iter() {
|
||||
next_hop.insert(item.key().clone(), item.value().clone());
|
||||
}
|
||||
next_hop
|
||||
}
|
||||
|
||||
pub fn get_peer_map(&self) -> Arc<PeerMap> {
|
||||
self.peer_map.clone()
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,459 @@
|
||||
/*
|
||||
foreign_network_manager is used to forward packets of other networks. currently
|
||||
only forward packets of peers that directly connected to this node.
|
||||
|
||||
in future, with the help wo peer center we can forward packets of peers that
|
||||
connected to any node in the local network.
|
||||
*/
|
||||
use std::sync::Arc;
|
||||
|
||||
use dashmap::DashMap;
|
||||
use tokio::{
|
||||
sync::{
|
||||
mpsc::{self, UnboundedReceiver, UnboundedSender},
|
||||
Mutex,
|
||||
},
|
||||
task::JoinSet,
|
||||
};
|
||||
use tokio_util::bytes::Bytes;
|
||||
|
||||
use crate::common::{
|
||||
error::Error,
|
||||
global_ctx::{ArcGlobalCtx, GlobalCtxEvent, NetworkIdentity},
|
||||
PeerId,
|
||||
};
|
||||
|
||||
use super::{
|
||||
packet::{self},
|
||||
peer_conn::PeerConn,
|
||||
peer_map::PeerMap,
|
||||
peer_rpc::{PeerRpcManager, PeerRpcManagerTransport},
|
||||
};
|
||||
|
||||
struct ForeignNetworkEntry {
|
||||
network: NetworkIdentity,
|
||||
peer_map: Arc<PeerMap>,
|
||||
}
|
||||
|
||||
impl ForeignNetworkEntry {
|
||||
fn new(
|
||||
network: NetworkIdentity,
|
||||
packet_sender: mpsc::Sender<Bytes>,
|
||||
global_ctx: ArcGlobalCtx,
|
||||
my_peer_id: PeerId,
|
||||
) -> Self {
|
||||
let peer_map = Arc::new(PeerMap::new(packet_sender, global_ctx, my_peer_id));
|
||||
Self { network, peer_map }
|
||||
}
|
||||
}
|
||||
|
||||
struct ForeignNetworkManagerData {
|
||||
network_peer_maps: DashMap<String, Arc<ForeignNetworkEntry>>,
|
||||
peer_network_map: DashMap<PeerId, String>,
|
||||
}
|
||||
|
||||
impl ForeignNetworkManagerData {
|
||||
async fn send_msg(&self, msg: Bytes, dst_peer_id: PeerId) -> Result<(), Error> {
|
||||
let network_name = self
|
||||
.peer_network_map
|
||||
.get(&dst_peer_id)
|
||||
.ok_or_else(|| Error::RouteError(Some("network not found".to_string())))?
|
||||
.clone();
|
||||
let entry = self
|
||||
.network_peer_maps
|
||||
.get(&network_name)
|
||||
.ok_or_else(|| Error::RouteError(Some("no peer in network".to_string())))?
|
||||
.clone();
|
||||
entry.peer_map.send_msg(msg, dst_peer_id).await
|
||||
}
|
||||
|
||||
fn get_peer_network(&self, peer_id: PeerId) -> Option<String> {
|
||||
self.peer_network_map.get(&peer_id).map(|v| v.clone())
|
||||
}
|
||||
|
||||
fn get_network_entry(&self, network_name: &str) -> Option<Arc<ForeignNetworkEntry>> {
|
||||
self.network_peer_maps.get(network_name).map(|v| v.clone())
|
||||
}
|
||||
|
||||
fn remove_peer(&self, peer_id: PeerId) {
|
||||
self.peer_network_map.remove(&peer_id);
|
||||
self.network_peer_maps.retain(|_, v| !v.peer_map.is_empty());
|
||||
}
|
||||
|
||||
fn clear_no_conn_peer(&self) {
|
||||
for item in self.network_peer_maps.iter() {
|
||||
let peer_map = item.value().peer_map.clone();
|
||||
tokio::spawn(async move {
|
||||
peer_map.clean_peer_without_conn().await;
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct RpcTransport {
|
||||
my_peer_id: PeerId,
|
||||
data: Arc<ForeignNetworkManagerData>,
|
||||
|
||||
packet_recv: Mutex<UnboundedReceiver<Bytes>>,
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl PeerRpcManagerTransport for RpcTransport {
|
||||
fn my_peer_id(&self) -> PeerId {
|
||||
self.my_peer_id
|
||||
}
|
||||
|
||||
async fn send(&self, msg: Bytes, dst_peer_id: PeerId) -> Result<(), Error> {
|
||||
self.data.send_msg(msg, dst_peer_id).await
|
||||
}
|
||||
|
||||
async fn recv(&self) -> Result<Bytes, Error> {
|
||||
if let Some(o) = self.packet_recv.lock().await.recv().await {
|
||||
Ok(o)
|
||||
} else {
|
||||
Err(Error::Unknown)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub const FOREIGN_NETWORK_SERVICE_ID: u32 = 1;
|
||||
|
||||
#[tarpc::service]
|
||||
pub trait ForeignNetworkService {
|
||||
async fn list_network_peers(network_identy: NetworkIdentity) -> Option<Vec<PeerId>>;
|
||||
}
|
||||
|
||||
#[tarpc::server]
|
||||
impl ForeignNetworkService for Arc<ForeignNetworkManagerData> {
|
||||
async fn list_network_peers(
|
||||
self,
|
||||
_: tarpc::context::Context,
|
||||
network_identy: NetworkIdentity,
|
||||
) -> Option<Vec<PeerId>> {
|
||||
let entry = self.network_peer_maps.get(&network_identy.network_name)?;
|
||||
Some(entry.peer_map.list_peers().await)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ForeignNetworkManager {
|
||||
my_peer_id: PeerId,
|
||||
global_ctx: ArcGlobalCtx,
|
||||
packet_sender_to_mgr: mpsc::Sender<Bytes>,
|
||||
|
||||
packet_sender: mpsc::Sender<Bytes>,
|
||||
packet_recv: Mutex<Option<mpsc::Receiver<Bytes>>>,
|
||||
|
||||
data: Arc<ForeignNetworkManagerData>,
|
||||
rpc_mgr: Arc<PeerRpcManager>,
|
||||
rpc_transport_sender: UnboundedSender<Bytes>,
|
||||
|
||||
tasks: Mutex<JoinSet<()>>,
|
||||
}
|
||||
|
||||
impl ForeignNetworkManager {
|
||||
pub fn new(
|
||||
my_peer_id: PeerId,
|
||||
global_ctx: ArcGlobalCtx,
|
||||
packet_sender_to_mgr: mpsc::Sender<Bytes>,
|
||||
) -> Self {
|
||||
// recv packet from all foreign networks
|
||||
let (packet_sender, packet_recv) = mpsc::channel(1000);
|
||||
|
||||
let data = Arc::new(ForeignNetworkManagerData {
|
||||
network_peer_maps: DashMap::new(),
|
||||
peer_network_map: DashMap::new(),
|
||||
});
|
||||
|
||||
// handle rpc from foreign networks
|
||||
let (rpc_transport_sender, peer_rpc_tspt_recv) = mpsc::unbounded_channel();
|
||||
let rpc_mgr = Arc::new(PeerRpcManager::new(RpcTransport {
|
||||
my_peer_id,
|
||||
data: data.clone(),
|
||||
packet_recv: Mutex::new(peer_rpc_tspt_recv),
|
||||
}));
|
||||
|
||||
Self {
|
||||
my_peer_id,
|
||||
global_ctx,
|
||||
packet_sender_to_mgr,
|
||||
|
||||
packet_sender,
|
||||
packet_recv: Mutex::new(Some(packet_recv)),
|
||||
|
||||
data,
|
||||
rpc_mgr,
|
||||
rpc_transport_sender,
|
||||
|
||||
tasks: Mutex::new(JoinSet::new()),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn add_peer_conn(&self, peer_conn: PeerConn) -> Result<(), Error> {
|
||||
tracing::info!(peer_conn = ?peer_conn.get_conn_info(), network = ?peer_conn.get_network_identity(), "add new peer conn in foreign network manager");
|
||||
|
||||
let entry = self
|
||||
.data
|
||||
.network_peer_maps
|
||||
.entry(peer_conn.get_network_identity().network_name.clone())
|
||||
.or_insert_with(|| {
|
||||
Arc::new(ForeignNetworkEntry::new(
|
||||
peer_conn.get_network_identity(),
|
||||
self.packet_sender.clone(),
|
||||
self.global_ctx.clone(),
|
||||
self.my_peer_id,
|
||||
))
|
||||
})
|
||||
.clone();
|
||||
|
||||
self.data.peer_network_map.insert(
|
||||
peer_conn.get_peer_id(),
|
||||
peer_conn.get_network_identity().network_name.clone(),
|
||||
);
|
||||
|
||||
if entry.network.network_secret != peer_conn.get_network_identity().network_secret {
|
||||
return Err(anyhow::anyhow!("network secret not match").into());
|
||||
}
|
||||
|
||||
Ok(entry.peer_map.add_new_peer_conn(peer_conn).await)
|
||||
}
|
||||
|
||||
async fn start_global_event_handler(&self) {
|
||||
let data = self.data.clone();
|
||||
let mut s = self.global_ctx.subscribe();
|
||||
self.tasks.lock().await.spawn(async move {
|
||||
while let Ok(e) = s.recv().await {
|
||||
if let GlobalCtxEvent::PeerRemoved(peer_id) = &e {
|
||||
tracing::info!(?e, "remove peer from foreign network manager");
|
||||
data.remove_peer(*peer_id);
|
||||
} else if let GlobalCtxEvent::PeerConnRemoved(..) = &e {
|
||||
tracing::info!(?e, "clear no conn peer from foreign network manager");
|
||||
data.clear_no_conn_peer();
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
async fn start_packet_recv(&self) {
|
||||
let mut recv = self.packet_recv.lock().await.take().unwrap();
|
||||
let sender_to_mgr = self.packet_sender_to_mgr.clone();
|
||||
let my_node_id = self.my_peer_id;
|
||||
let rpc_sender = self.rpc_transport_sender.clone();
|
||||
let data = self.data.clone();
|
||||
|
||||
self.tasks.lock().await.spawn(async move {
|
||||
while let Some(packet_bytes) = recv.recv().await {
|
||||
let packet = packet::Packet::decode(&packet_bytes);
|
||||
let from_peer_id = packet.from_peer.into();
|
||||
let to_peer_id = packet.to_peer.into();
|
||||
if to_peer_id == my_node_id {
|
||||
if packet.packet_type == packet::PacketType::TaRpc {
|
||||
rpc_sender.send(packet_bytes.clone()).unwrap();
|
||||
continue;
|
||||
}
|
||||
if let Err(e) = sender_to_mgr.send(packet_bytes).await {
|
||||
tracing::error!("send packet to mgr failed: {:?}", e);
|
||||
}
|
||||
} else {
|
||||
let Some(from_network) = data.get_peer_network(from_peer_id) else {
|
||||
continue;
|
||||
};
|
||||
let Some(to_network) = data.get_peer_network(to_peer_id) else {
|
||||
continue;
|
||||
};
|
||||
if from_network != to_network {
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Some(entry) = data.get_network_entry(&from_network) {
|
||||
let ret = entry.peer_map.send_msg(packet_bytes, to_peer_id).await;
|
||||
if ret.is_err() {
|
||||
tracing::error!("forward packet to peer failed: {:?}", ret.err());
|
||||
}
|
||||
} else {
|
||||
tracing::error!("foreign network not found: {}", from_network);
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
async fn register_peer_rpc_service(&self) {
|
||||
self.rpc_mgr.run();
|
||||
self.rpc_mgr
|
||||
.run_service(FOREIGN_NETWORK_SERVICE_ID, self.data.clone().serve())
|
||||
}
|
||||
|
||||
pub async fn run(&self) {
|
||||
self.start_global_event_handler().await;
|
||||
self.start_packet_recv().await;
|
||||
self.register_peer_rpc_service().await;
|
||||
}
|
||||
|
||||
pub async fn list_foreign_networks(&self) -> DashMap<String, Vec<PeerId>> {
|
||||
let ret = DashMap::new();
|
||||
for item in self.data.network_peer_maps.iter() {
|
||||
let network_name = item.key().clone();
|
||||
ret.insert(network_name, vec![]);
|
||||
}
|
||||
|
||||
for mut n in ret.iter_mut() {
|
||||
let network_name = n.key().clone();
|
||||
let Some(item) = self
|
||||
.data
|
||||
.network_peer_maps
|
||||
.get(&network_name)
|
||||
.map(|v| v.clone())
|
||||
else {
|
||||
continue;
|
||||
};
|
||||
n.value_mut().extend(item.peer_map.list_peers().await);
|
||||
}
|
||||
ret
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::{
|
||||
common::global_ctx::tests::get_mock_global_ctx_with_network,
|
||||
connector::udp_hole_punch::tests::{
|
||||
create_mock_peer_manager_with_mock_stun, replace_stun_info_collector,
|
||||
},
|
||||
peers::{
|
||||
peer_manager::{PeerManager, RouteAlgoType},
|
||||
tests::{connect_peer_manager, wait_route_appear},
|
||||
},
|
||||
rpc::NatType,
|
||||
};
|
||||
|
||||
use super::*;
|
||||
|
||||
async fn create_mock_peer_manager_for_foreign_network(network: &str) -> Arc<PeerManager> {
|
||||
let (s, _r) = tokio::sync::mpsc::channel(1000);
|
||||
let peer_mgr = Arc::new(PeerManager::new(
|
||||
RouteAlgoType::Ospf,
|
||||
get_mock_global_ctx_with_network(Some(NetworkIdentity {
|
||||
network_name: network.to_string(),
|
||||
network_secret: network.to_string(),
|
||||
})),
|
||||
s,
|
||||
));
|
||||
replace_stun_info_collector(peer_mgr.clone(), NatType::Unknown);
|
||||
peer_mgr.run().await.unwrap();
|
||||
peer_mgr
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_foreign_network_manager() {
|
||||
let pm_center = create_mock_peer_manager_with_mock_stun(crate::rpc::NatType::Unknown).await;
|
||||
let pm_center2 =
|
||||
create_mock_peer_manager_with_mock_stun(crate::rpc::NatType::Unknown).await;
|
||||
connect_peer_manager(pm_center.clone(), pm_center2.clone()).await;
|
||||
|
||||
let pma_net1 = create_mock_peer_manager_for_foreign_network("net1").await;
|
||||
let pmb_net1 = create_mock_peer_manager_for_foreign_network("net1").await;
|
||||
connect_peer_manager(pma_net1.clone(), pm_center.clone()).await;
|
||||
connect_peer_manager(pmb_net1.clone(), pm_center.clone()).await;
|
||||
|
||||
let now = std::time::Instant::now();
|
||||
let mut succ = false;
|
||||
while now.elapsed().as_secs() < 10 {
|
||||
let table = pma_net1.get_foreign_network_client().get_next_hop_table();
|
||||
if table.len() >= 1 {
|
||||
succ = true;
|
||||
break;
|
||||
}
|
||||
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
|
||||
}
|
||||
assert!(succ);
|
||||
|
||||
assert_eq!(
|
||||
vec![pm_center.my_peer_id()],
|
||||
pma_net1
|
||||
.get_foreign_network_client()
|
||||
.get_peer_map()
|
||||
.list_peers()
|
||||
.await
|
||||
);
|
||||
assert_eq!(
|
||||
vec![pm_center.my_peer_id()],
|
||||
pmb_net1
|
||||
.get_foreign_network_client()
|
||||
.get_peer_map()
|
||||
.list_peers()
|
||||
.await
|
||||
);
|
||||
wait_route_appear(pma_net1.clone(), pmb_net1.clone())
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(1, pma_net1.list_routes().await.len());
|
||||
assert_eq!(1, pmb_net1.list_routes().await.len());
|
||||
|
||||
let pmc_net1 = create_mock_peer_manager_for_foreign_network("net1").await;
|
||||
connect_peer_manager(pmc_net1.clone(), pm_center.clone()).await;
|
||||
wait_route_appear(pma_net1.clone(), pmc_net1.clone())
|
||||
.await
|
||||
.unwrap();
|
||||
wait_route_appear(pmb_net1.clone(), pmc_net1.clone())
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(2, pmc_net1.list_routes().await.len());
|
||||
|
||||
let pma_net2 = create_mock_peer_manager_for_foreign_network("net2").await;
|
||||
let pmb_net2 = create_mock_peer_manager_for_foreign_network("net2").await;
|
||||
connect_peer_manager(pma_net2.clone(), pm_center.clone()).await;
|
||||
connect_peer_manager(pmb_net2.clone(), pm_center.clone()).await;
|
||||
wait_route_appear(pma_net2.clone(), pmb_net2.clone())
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(1, pma_net2.list_routes().await.len());
|
||||
assert_eq!(1, pmb_net2.list_routes().await.len());
|
||||
|
||||
assert_eq!(
|
||||
5,
|
||||
pm_center
|
||||
.get_foreign_network_manager()
|
||||
.data
|
||||
.peer_network_map
|
||||
.len()
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
2,
|
||||
pm_center
|
||||
.get_foreign_network_manager()
|
||||
.data
|
||||
.network_peer_maps
|
||||
.len()
|
||||
);
|
||||
|
||||
drop(pmb_net2);
|
||||
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
|
||||
assert_eq!(
|
||||
4,
|
||||
pm_center
|
||||
.get_foreign_network_manager()
|
||||
.data
|
||||
.peer_network_map
|
||||
.len()
|
||||
);
|
||||
drop(pma_net2);
|
||||
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
|
||||
assert_eq!(
|
||||
3,
|
||||
pm_center
|
||||
.get_foreign_network_manager()
|
||||
.data
|
||||
.peer_network_map
|
||||
.len()
|
||||
);
|
||||
assert_eq!(
|
||||
1,
|
||||
pm_center
|
||||
.get_foreign_network_manager()
|
||||
.data
|
||||
.network_peer_maps
|
||||
.len()
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,39 @@
|
||||
pub mod packet;
|
||||
pub mod peer;
|
||||
pub mod peer_conn;
|
||||
pub mod peer_manager;
|
||||
pub mod peer_map;
|
||||
pub mod peer_ospf_route;
|
||||
pub mod peer_rip_route;
|
||||
pub mod peer_rpc;
|
||||
pub mod route_trait;
|
||||
pub mod rpc_service;
|
||||
|
||||
pub mod foreign_network_client;
|
||||
pub mod foreign_network_manager;
|
||||
|
||||
#[cfg(test)]
|
||||
pub mod tests;
|
||||
|
||||
use tokio_util::bytes::{Bytes, BytesMut};
|
||||
|
||||
#[async_trait::async_trait]
|
||||
#[auto_impl::auto_impl(Arc)]
|
||||
pub trait PeerPacketFilter {
|
||||
async fn try_process_packet_from_peer(
|
||||
&self,
|
||||
_packet: &packet::ArchivedPacket,
|
||||
_data: &Bytes,
|
||||
) -> Option<()> {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
#[auto_impl::auto_impl(Arc)]
|
||||
pub trait NicPacketFilter {
|
||||
async fn try_process_packet_from_nic(&self, data: BytesMut) -> BytesMut;
|
||||
}
|
||||
|
||||
type BoxPeerPacketFilter = Box<dyn PeerPacketFilter + Send + Sync>;
|
||||
type BoxNicPacketFilter = Box<dyn NicPacketFilter + Send + Sync>;
|
||||
@@ -0,0 +1,254 @@
|
||||
use std::fmt::Debug;
|
||||
|
||||
use rkyv::{Archive, Deserialize, Serialize};
|
||||
use tokio_util::bytes::Bytes;
|
||||
|
||||
use crate::common::{
|
||||
global_ctx::NetworkIdentity,
|
||||
rkyv_util::{decode_from_bytes, encode_to_bytes, vec_to_string},
|
||||
PeerId,
|
||||
};
|
||||
|
||||
const MAGIC: u32 = 0xd1e1a5e1;
|
||||
const VERSION: u32 = 1;
|
||||
|
||||
#[derive(Archive, Deserialize, Serialize, PartialEq, Clone)]
|
||||
#[archive(compare(PartialEq), check_bytes)]
|
||||
// Derives can be passed through to the generated type:
|
||||
#[archive_attr(derive(Debug))]
|
||||
pub struct UUID(uuid::Bytes);
|
||||
|
||||
// impl Debug for UUID
|
||||
impl std::fmt::Debug for UUID {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
let uuid = uuid::Uuid::from_bytes(self.0);
|
||||
write!(f, "{}", uuid)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<uuid::Uuid> for UUID {
|
||||
fn from(uuid: uuid::Uuid) -> Self {
|
||||
UUID(*uuid.as_bytes())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<UUID> for uuid::Uuid {
|
||||
fn from(uuid: UUID) -> Self {
|
||||
uuid::Uuid::from_bytes(uuid.0)
|
||||
}
|
||||
}
|
||||
|
||||
impl ArchivedUUID {
|
||||
pub fn to_uuid(&self) -> uuid::Uuid {
|
||||
uuid::Uuid::from_bytes(self.0)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&ArchivedUUID> for UUID {
|
||||
fn from(uuid: &ArchivedUUID) -> Self {
|
||||
UUID(uuid.0)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(serde::Serialize, serde::Deserialize, Debug)]
|
||||
pub struct HandShake {
|
||||
pub magic: u32,
|
||||
pub my_peer_id: PeerId,
|
||||
pub version: u32,
|
||||
pub features: Vec<String>,
|
||||
pub network_identity: NetworkIdentity,
|
||||
}
|
||||
|
||||
#[derive(serde::Serialize, serde::Deserialize, Debug)]
|
||||
pub struct RoutePacket {
|
||||
pub route_id: u8,
|
||||
pub body: Vec<u8>,
|
||||
}
|
||||
|
||||
#[derive(Debug, serde::Serialize, serde::Deserialize)]
|
||||
pub enum CtrlPacketPayload {
|
||||
HandShake(HandShake),
|
||||
RoutePacket(RoutePacket),
|
||||
Ping(u32),
|
||||
Pong(u32),
|
||||
TaRpc(u32, u32, bool, Vec<u8>), // u32: service_id, u32: transact_id, bool: is_req, Vec<u8>: rpc body
|
||||
}
|
||||
|
||||
impl CtrlPacketPayload {
|
||||
pub fn from_packet(p: &ArchivedPacket) -> CtrlPacketPayload {
|
||||
assert_ne!(p.packet_type, PacketType::Data);
|
||||
postcard::from_bytes(p.payload.as_bytes()).unwrap()
|
||||
}
|
||||
|
||||
pub fn from_packet2(p: &Packet) -> CtrlPacketPayload {
|
||||
postcard::from_bytes(p.payload.as_bytes()).unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
#[repr(u8)]
|
||||
#[derive(Archive, Deserialize, Serialize, Debug)]
|
||||
#[archive(compare(PartialEq), check_bytes)]
|
||||
// Derives can be passed through to the generated type:
|
||||
#[archive_attr(derive(Debug))]
|
||||
pub enum PacketType {
|
||||
Data = 1,
|
||||
HandShake = 2,
|
||||
RoutePacket = 3,
|
||||
Ping = 4,
|
||||
Pong = 5,
|
||||
TaRpc = 6,
|
||||
}
|
||||
|
||||
#[derive(Archive, Deserialize, Serialize)]
|
||||
#[archive(compare(PartialEq), check_bytes)]
|
||||
// Derives can be passed through to the generated type:
|
||||
pub struct Packet {
|
||||
pub from_peer: PeerId,
|
||||
pub to_peer: PeerId,
|
||||
pub packet_type: PacketType,
|
||||
pub payload: String,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for Packet {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(
|
||||
f,
|
||||
"Packet {{ from_peer: {}, to_peer: {}, packet_type: {:?}, payload: {:?} }}",
|
||||
self.from_peer,
|
||||
self.to_peer,
|
||||
self.packet_type,
|
||||
&self.payload.as_bytes()
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for ArchivedPacket {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(
|
||||
f,
|
||||
"Packet {{ from_peer: {}, to_peer: {}, packet_type: {:?}, payload: {:?} }}",
|
||||
self.from_peer,
|
||||
self.to_peer,
|
||||
self.packet_type,
|
||||
&self.payload.as_bytes()
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl Packet {
|
||||
pub fn decode(v: &[u8]) -> &ArchivedPacket {
|
||||
decode_from_bytes::<Packet>(v).unwrap()
|
||||
}
|
||||
|
||||
pub fn new(
|
||||
from_peer: PeerId,
|
||||
to_peer: PeerId,
|
||||
packet_type: PacketType,
|
||||
payload: Vec<u8>,
|
||||
) -> Self {
|
||||
Packet {
|
||||
from_peer,
|
||||
to_peer,
|
||||
packet_type,
|
||||
payload: vec_to_string(payload),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Packet> for Bytes {
|
||||
fn from(val: Packet) -> Self {
|
||||
encode_to_bytes::<_, 4096>(&val)
|
||||
}
|
||||
}
|
||||
|
||||
impl Packet {
|
||||
pub fn new_handshake(from_peer: PeerId, network: &NetworkIdentity) -> Self {
|
||||
let handshake = CtrlPacketPayload::HandShake(HandShake {
|
||||
magic: MAGIC,
|
||||
my_peer_id: from_peer,
|
||||
version: VERSION,
|
||||
features: Vec::new(),
|
||||
network_identity: network.clone().into(),
|
||||
});
|
||||
Packet::new(
|
||||
from_peer.into(),
|
||||
0,
|
||||
PacketType::HandShake,
|
||||
postcard::to_allocvec(&handshake).unwrap(),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn new_data_packet(from_peer: PeerId, to_peer: PeerId, data: &[u8]) -> Self {
|
||||
Packet::new(from_peer, to_peer, PacketType::Data, data.to_vec())
|
||||
}
|
||||
|
||||
pub fn new_route_packet(from_peer: PeerId, to_peer: PeerId, route_id: u8, data: &[u8]) -> Self {
|
||||
let route = CtrlPacketPayload::RoutePacket(RoutePacket {
|
||||
route_id,
|
||||
body: data.to_vec(),
|
||||
});
|
||||
Packet::new(
|
||||
from_peer,
|
||||
to_peer,
|
||||
PacketType::RoutePacket,
|
||||
postcard::to_allocvec(&route).unwrap(),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn new_ping_packet(from_peer: PeerId, to_peer: PeerId, seq: u32) -> Self {
|
||||
let ping = CtrlPacketPayload::Ping(seq);
|
||||
Packet::new(
|
||||
from_peer,
|
||||
to_peer,
|
||||
PacketType::Ping,
|
||||
postcard::to_allocvec(&ping).unwrap(),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn new_pong_packet(from_peer: PeerId, to_peer: PeerId, seq: u32) -> Self {
|
||||
let pong = CtrlPacketPayload::Pong(seq);
|
||||
Packet::new(
|
||||
from_peer,
|
||||
to_peer,
|
||||
PacketType::Pong,
|
||||
postcard::to_allocvec(&pong).unwrap(),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn new_tarpc_packet(
|
||||
from_peer: PeerId,
|
||||
to_peer: PeerId,
|
||||
service_id: u32,
|
||||
transact_id: u32,
|
||||
is_req: bool,
|
||||
body: Vec<u8>,
|
||||
) -> Self {
|
||||
let ta_rpc = CtrlPacketPayload::TaRpc(service_id, transact_id, is_req, body);
|
||||
Packet::new(
|
||||
from_peer,
|
||||
to_peer,
|
||||
PacketType::TaRpc,
|
||||
postcard::to_allocvec(&ta_rpc).unwrap(),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::common::new_peer_id;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn serialize() {
|
||||
let a = "abcde";
|
||||
let out = Packet::new_data_packet(new_peer_id(), new_peer_id(), a.as_bytes());
|
||||
// let out = T::new(a.as_bytes());
|
||||
let out_bytes: Bytes = out.into();
|
||||
println!("out str: {:?}", a.as_bytes());
|
||||
println!("out bytes: {:?}", out_bytes);
|
||||
|
||||
let archived = Packet::decode(&out_bytes[..]);
|
||||
println!("in packet: {:?}", archived);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,213 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use dashmap::DashMap;
|
||||
|
||||
use tokio::{
|
||||
select,
|
||||
sync::{mpsc, Mutex},
|
||||
task::JoinHandle,
|
||||
};
|
||||
use tokio_util::bytes::Bytes;
|
||||
use tracing::Instrument;
|
||||
|
||||
use super::peer_conn::{PeerConn, PeerConnId};
|
||||
use crate::common::{
|
||||
error::Error,
|
||||
global_ctx::{ArcGlobalCtx, GlobalCtxEvent},
|
||||
PeerId,
|
||||
};
|
||||
use crate::rpc::PeerConnInfo;
|
||||
|
||||
type ArcPeerConn = Arc<Mutex<PeerConn>>;
|
||||
type ConnMap = Arc<DashMap<PeerConnId, ArcPeerConn>>;
|
||||
|
||||
pub struct Peer {
|
||||
pub peer_node_id: PeerId,
|
||||
conns: ConnMap,
|
||||
global_ctx: ArcGlobalCtx,
|
||||
|
||||
packet_recv_chan: mpsc::Sender<Bytes>,
|
||||
|
||||
close_event_sender: mpsc::Sender<PeerConnId>,
|
||||
close_event_listener: JoinHandle<()>,
|
||||
|
||||
shutdown_notifier: Arc<tokio::sync::Notify>,
|
||||
}
|
||||
|
||||
impl Peer {
|
||||
pub fn new(
|
||||
peer_node_id: PeerId,
|
||||
packet_recv_chan: mpsc::Sender<Bytes>,
|
||||
global_ctx: ArcGlobalCtx,
|
||||
) -> Self {
|
||||
let conns: ConnMap = Arc::new(DashMap::new());
|
||||
let (close_event_sender, mut close_event_receiver) = mpsc::channel(10);
|
||||
let shutdown_notifier = Arc::new(tokio::sync::Notify::new());
|
||||
|
||||
let conns_copy = conns.clone();
|
||||
let shutdown_notifier_copy = shutdown_notifier.clone();
|
||||
let global_ctx_copy = global_ctx.clone();
|
||||
let close_event_listener = tokio::spawn(
|
||||
async move {
|
||||
loop {
|
||||
select! {
|
||||
ret = close_event_receiver.recv() => {
|
||||
if ret.is_none() {
|
||||
break;
|
||||
}
|
||||
let ret = ret.unwrap();
|
||||
tracing::warn!(
|
||||
?peer_node_id,
|
||||
?ret,
|
||||
"notified that peer conn is closed",
|
||||
);
|
||||
|
||||
if let Some((_, conn)) = conns_copy.remove(&ret) {
|
||||
global_ctx_copy.issue_event(GlobalCtxEvent::PeerConnRemoved(
|
||||
conn.lock().await.get_conn_info(),
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
_ = shutdown_notifier_copy.notified() => {
|
||||
close_event_receiver.close();
|
||||
tracing::warn!(?peer_node_id, "peer close event listener notified");
|
||||
}
|
||||
}
|
||||
}
|
||||
tracing::info!("peer {} close event listener exit", peer_node_id);
|
||||
}
|
||||
.instrument(tracing::info_span!(
|
||||
"peer_close_event_listener",
|
||||
?peer_node_id,
|
||||
)),
|
||||
);
|
||||
|
||||
Peer {
|
||||
peer_node_id,
|
||||
conns: conns.clone(),
|
||||
packet_recv_chan,
|
||||
global_ctx,
|
||||
|
||||
close_event_sender,
|
||||
close_event_listener,
|
||||
|
||||
shutdown_notifier,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn add_peer_conn(&self, mut conn: PeerConn) {
|
||||
conn.set_close_event_sender(self.close_event_sender.clone());
|
||||
conn.start_recv_loop(self.packet_recv_chan.clone());
|
||||
conn.start_pingpong();
|
||||
self.global_ctx
|
||||
.issue_event(GlobalCtxEvent::PeerConnAdded(conn.get_conn_info()));
|
||||
self.conns
|
||||
.insert(conn.get_conn_id(), Arc::new(Mutex::new(conn)));
|
||||
}
|
||||
|
||||
pub async fn send_msg(&self, msg: Bytes) -> Result<(), Error> {
|
||||
let Some(conn) = self.conns.iter().next() else {
|
||||
return Err(Error::PeerNoConnectionError(self.peer_node_id));
|
||||
};
|
||||
|
||||
let conn_clone = conn.clone();
|
||||
drop(conn);
|
||||
conn_clone.lock().await.send_msg(msg).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn close_peer_conn(&self, conn_id: &PeerConnId) -> Result<(), Error> {
|
||||
let has_key = self.conns.contains_key(conn_id);
|
||||
if !has_key {
|
||||
return Err(Error::NotFound);
|
||||
}
|
||||
self.close_event_sender.send(conn_id.clone()).await.unwrap();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn list_peer_conns(&self) -> Vec<PeerConnInfo> {
|
||||
let mut conns = vec![];
|
||||
for conn in self.conns.iter() {
|
||||
// do not lock here, otherwise it will cause dashmap deadlock
|
||||
conns.push(conn.clone());
|
||||
}
|
||||
|
||||
let mut ret = Vec::new();
|
||||
for conn in conns {
|
||||
ret.push(conn.lock().await.get_conn_info());
|
||||
}
|
||||
ret
|
||||
}
|
||||
}
|
||||
|
||||
// pritn on drop
|
||||
impl Drop for Peer {
|
||||
fn drop(&mut self) {
|
||||
self.shutdown_notifier.notify_one();
|
||||
tracing::info!("peer {} drop", self.peer_node_id);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
|
||||
use tokio::{sync::mpsc, time::timeout};
|
||||
|
||||
use crate::{
|
||||
common::{global_ctx::tests::get_mock_global_ctx, new_peer_id},
|
||||
peers::peer_conn::PeerConn,
|
||||
tunnels::ring_tunnel::create_ring_tunnel_pair,
|
||||
};
|
||||
|
||||
use super::Peer;
|
||||
|
||||
#[tokio::test]
|
||||
async fn close_peer() {
|
||||
let (local_packet_send, _local_packet_recv) = mpsc::channel(10);
|
||||
let (remote_packet_send, _remote_packet_recv) = mpsc::channel(10);
|
||||
let global_ctx = get_mock_global_ctx();
|
||||
let local_peer = Peer::new(new_peer_id(), local_packet_send, global_ctx.clone());
|
||||
let remote_peer = Peer::new(new_peer_id(), remote_packet_send, global_ctx.clone());
|
||||
|
||||
let (local_tunnel, remote_tunnel) = create_ring_tunnel_pair();
|
||||
let mut local_peer_conn =
|
||||
PeerConn::new(local_peer.peer_node_id, global_ctx.clone(), local_tunnel);
|
||||
let mut remote_peer_conn =
|
||||
PeerConn::new(remote_peer.peer_node_id, global_ctx.clone(), remote_tunnel);
|
||||
|
||||
assert!(!local_peer_conn.handshake_done());
|
||||
assert!(!remote_peer_conn.handshake_done());
|
||||
|
||||
let (a, b) = tokio::join!(
|
||||
local_peer_conn.do_handshake_as_client(),
|
||||
remote_peer_conn.do_handshake_as_server()
|
||||
);
|
||||
a.unwrap();
|
||||
b.unwrap();
|
||||
|
||||
let local_conn_id = local_peer_conn.get_conn_id();
|
||||
|
||||
local_peer.add_peer_conn(local_peer_conn).await;
|
||||
remote_peer.add_peer_conn(remote_peer_conn).await;
|
||||
|
||||
assert_eq!(local_peer.list_peer_conns().await.len(), 1);
|
||||
assert_eq!(remote_peer.list_peer_conns().await.len(), 1);
|
||||
|
||||
let close_handler =
|
||||
tokio::spawn(async move { local_peer.close_peer_conn(&local_conn_id).await });
|
||||
|
||||
// wait for remote peer conn close
|
||||
timeout(std::time::Duration::from_secs(5), async {
|
||||
while (&remote_peer).list_peer_conns().await.len() != 0 {
|
||||
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
|
||||
}
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
println!("wait for close handler");
|
||||
close_handler.await.unwrap().unwrap();
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,652 @@
|
||||
use std::{
|
||||
fmt::Debug,
|
||||
pin::Pin,
|
||||
sync::{
|
||||
atomic::{AtomicU32, Ordering},
|
||||
Arc,
|
||||
},
|
||||
};
|
||||
|
||||
use futures::{SinkExt, StreamExt};
|
||||
use pnet::datalink::NetworkInterface;
|
||||
|
||||
use tokio::{
|
||||
sync::{broadcast, mpsc, Mutex},
|
||||
task::JoinSet,
|
||||
time::{timeout, Duration},
|
||||
};
|
||||
|
||||
use tokio_util::{bytes::Bytes, sync::PollSender};
|
||||
use tracing::Instrument;
|
||||
|
||||
use crate::{
|
||||
common::{
|
||||
global_ctx::{ArcGlobalCtx, NetworkIdentity},
|
||||
PeerId,
|
||||
},
|
||||
define_tunnel_filter_chain,
|
||||
peers::packet::{ArchivedPacketType, CtrlPacketPayload, PacketType},
|
||||
rpc::{PeerConnInfo, PeerConnStats},
|
||||
tunnels::{
|
||||
stats::{Throughput, WindowLatency},
|
||||
tunnel_filter::StatsRecorderTunnelFilter,
|
||||
DatagramSink, Tunnel, TunnelError,
|
||||
},
|
||||
};
|
||||
|
||||
use super::packet::{self, HandShake, Packet};
|
||||
|
||||
pub type PacketRecvChan = mpsc::Sender<Bytes>;
|
||||
|
||||
pub type PeerConnId = uuid::Uuid;
|
||||
|
||||
macro_rules! wait_response {
|
||||
($stream: ident, $out_var:ident, $pattern:pat_param => $value:expr) => {
|
||||
let rsp_vec = timeout(Duration::from_secs(1), $stream.next()).await;
|
||||
if rsp_vec.is_err() {
|
||||
return Err(TunnelError::WaitRespError(
|
||||
"wait handshake response timeout".to_owned(),
|
||||
));
|
||||
}
|
||||
let rsp_vec = rsp_vec.unwrap().unwrap()?;
|
||||
|
||||
let $out_var;
|
||||
let rsp_bytes = Packet::decode(&rsp_vec);
|
||||
if rsp_bytes.packet_type != PacketType::HandShake {
|
||||
tracing::error!("unexpected packet type: {:?}", rsp_bytes);
|
||||
return Err(TunnelError::WaitRespError(
|
||||
"unexpected packet type".to_owned(),
|
||||
));
|
||||
}
|
||||
let resp_payload = CtrlPacketPayload::from_packet(&rsp_bytes);
|
||||
match &resp_payload {
|
||||
$pattern => $out_var = $value,
|
||||
_ => {
|
||||
tracing::error!(
|
||||
"unexpected packet: {:?}, pattern: {:?}",
|
||||
rsp_bytes,
|
||||
stringify!($pattern)
|
||||
);
|
||||
return Err(TunnelError::WaitRespError("unexpected packet".to_owned()));
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct PeerInfo {
|
||||
magic: u32,
|
||||
pub my_peer_id: PeerId,
|
||||
version: u32,
|
||||
pub features: Vec<String>,
|
||||
pub interfaces: Vec<NetworkInterface>,
|
||||
pub network_identity: NetworkIdentity,
|
||||
}
|
||||
|
||||
impl<'a> From<&HandShake> for PeerInfo {
|
||||
fn from(hs: &HandShake) -> Self {
|
||||
PeerInfo {
|
||||
magic: hs.magic.into(),
|
||||
my_peer_id: hs.my_peer_id.into(),
|
||||
version: hs.version.into(),
|
||||
features: hs.features.iter().map(|x| x.to_string()).collect(),
|
||||
interfaces: Vec::new(),
|
||||
network_identity: hs.network_identity.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct PeerConnPinger {
|
||||
my_peer_id: PeerId,
|
||||
peer_id: PeerId,
|
||||
sink: Arc<Mutex<Pin<Box<dyn DatagramSink>>>>,
|
||||
ctrl_sender: broadcast::Sender<Bytes>,
|
||||
latency_stats: Arc<WindowLatency>,
|
||||
loss_rate_stats: Arc<AtomicU32>,
|
||||
tasks: JoinSet<Result<(), TunnelError>>,
|
||||
}
|
||||
|
||||
impl Debug for PeerConnPinger {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("PeerConnPinger")
|
||||
.field("my_peer_id", &self.my_peer_id)
|
||||
.field("peer_id", &self.peer_id)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl PeerConnPinger {
|
||||
pub fn new(
|
||||
my_peer_id: PeerId,
|
||||
peer_id: PeerId,
|
||||
sink: Pin<Box<dyn DatagramSink>>,
|
||||
ctrl_sender: broadcast::Sender<Bytes>,
|
||||
latency_stats: Arc<WindowLatency>,
|
||||
loss_rate_stats: Arc<AtomicU32>,
|
||||
) -> Self {
|
||||
Self {
|
||||
my_peer_id,
|
||||
peer_id,
|
||||
sink: Arc::new(Mutex::new(sink)),
|
||||
tasks: JoinSet::new(),
|
||||
latency_stats,
|
||||
ctrl_sender,
|
||||
loss_rate_stats,
|
||||
}
|
||||
}
|
||||
|
||||
async fn do_pingpong_once(
|
||||
my_node_id: PeerId,
|
||||
peer_id: PeerId,
|
||||
sink: Arc<Mutex<Pin<Box<dyn DatagramSink>>>>,
|
||||
receiver: &mut broadcast::Receiver<Bytes>,
|
||||
seq: u32,
|
||||
) -> Result<u128, TunnelError> {
|
||||
// should add seq here. so latency can be calculated more accurately
|
||||
let req = packet::Packet::new_ping_packet(my_node_id, peer_id, seq).into();
|
||||
tracing::trace!("send ping packet: {:?}", req);
|
||||
sink.lock().await.send(req).await.map_err(|e| {
|
||||
tracing::warn!("send ping packet error: {:?}", e);
|
||||
TunnelError::CommonError("send ping packet error".to_owned())
|
||||
})?;
|
||||
|
||||
let now = std::time::Instant::now();
|
||||
|
||||
// wait until we get a pong packet in ctrl_resp_receiver
|
||||
let resp = timeout(Duration::from_secs(1), async {
|
||||
loop {
|
||||
match receiver.recv().await {
|
||||
Ok(p) => {
|
||||
let ctrl_payload =
|
||||
packet::CtrlPacketPayload::from_packet(Packet::decode(&p));
|
||||
if let packet::CtrlPacketPayload::Pong(resp_seq) = ctrl_payload {
|
||||
if resp_seq == seq {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
log::warn!("recv pong resp error: {:?}", e);
|
||||
return Err(TunnelError::WaitRespError(
|
||||
"recv pong resp error".to_owned(),
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
})
|
||||
.await;
|
||||
|
||||
tracing::trace!(?resp, "wait ping response done");
|
||||
|
||||
if resp.is_err() {
|
||||
return Err(TunnelError::WaitRespError(
|
||||
"wait ping response timeout".to_owned(),
|
||||
));
|
||||
}
|
||||
|
||||
if resp.as_ref().unwrap().is_err() {
|
||||
return Err(resp.unwrap().err().unwrap());
|
||||
}
|
||||
|
||||
Ok(now.elapsed().as_micros())
|
||||
}
|
||||
|
||||
async fn pingpong(&mut self) {
|
||||
let sink = self.sink.clone();
|
||||
let my_node_id = self.my_peer_id;
|
||||
let peer_id = self.peer_id;
|
||||
let latency_stats = self.latency_stats.clone();
|
||||
|
||||
let (ping_res_sender, mut ping_res_receiver) = tokio::sync::mpsc::channel(100);
|
||||
|
||||
let stopped = Arc::new(AtomicU32::new(0));
|
||||
|
||||
// generate a pingpong task every 200ms
|
||||
let mut pingpong_tasks = JoinSet::new();
|
||||
let ctrl_resp_sender = self.ctrl_sender.clone();
|
||||
let stopped_clone = stopped.clone();
|
||||
self.tasks.spawn(async move {
|
||||
let mut req_seq = 0;
|
||||
loop {
|
||||
let receiver = ctrl_resp_sender.subscribe();
|
||||
let ping_res_sender = ping_res_sender.clone();
|
||||
let sink = sink.clone();
|
||||
|
||||
if stopped_clone.load(Ordering::Relaxed) != 0 {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
while pingpong_tasks.len() > 5 {
|
||||
pingpong_tasks.join_next().await;
|
||||
}
|
||||
|
||||
pingpong_tasks.spawn(async move {
|
||||
let mut receiver = receiver.resubscribe();
|
||||
let pingpong_once_ret = Self::do_pingpong_once(
|
||||
my_node_id,
|
||||
peer_id,
|
||||
sink.clone(),
|
||||
&mut receiver,
|
||||
req_seq,
|
||||
)
|
||||
.await;
|
||||
|
||||
if let Err(e) = ping_res_sender.send(pingpong_once_ret).await {
|
||||
tracing::info!(?e, "pingpong task send result error, exit..");
|
||||
};
|
||||
});
|
||||
|
||||
req_seq = req_seq.wrapping_add(1);
|
||||
tokio::time::sleep(Duration::from_millis(1000)).await;
|
||||
}
|
||||
});
|
||||
|
||||
// one with 1% precision
|
||||
let loss_rate_stats_1 = WindowLatency::new(100);
|
||||
// one with 20% precision, so we can fast fail this conn.
|
||||
let loss_rate_stats_20 = WindowLatency::new(5);
|
||||
|
||||
let mut counter: u64 = 0;
|
||||
|
||||
while let Some(ret) = ping_res_receiver.recv().await {
|
||||
counter += 1;
|
||||
|
||||
if let Ok(lat) = ret {
|
||||
latency_stats.record_latency(lat as u32);
|
||||
|
||||
loss_rate_stats_1.record_latency(0);
|
||||
loss_rate_stats_20.record_latency(0);
|
||||
} else {
|
||||
loss_rate_stats_1.record_latency(1);
|
||||
loss_rate_stats_20.record_latency(1);
|
||||
}
|
||||
|
||||
let loss_rate_20: f64 = loss_rate_stats_20.get_latency_us();
|
||||
let loss_rate_1: f64 = loss_rate_stats_1.get_latency_us();
|
||||
|
||||
tracing::trace!(
|
||||
?ret,
|
||||
?self,
|
||||
?loss_rate_1,
|
||||
?loss_rate_20,
|
||||
"pingpong task recv pingpong_once result"
|
||||
);
|
||||
|
||||
if (counter > 5 && loss_rate_20 > 0.74) || (counter > 150 && loss_rate_1 > 0.20) {
|
||||
tracing::warn!(
|
||||
?ret,
|
||||
?self,
|
||||
?loss_rate_1,
|
||||
?loss_rate_20,
|
||||
"pingpong loss rate too high, closing"
|
||||
);
|
||||
break;
|
||||
}
|
||||
|
||||
self.loss_rate_stats
|
||||
.store((loss_rate_1 * 100.0) as u32, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
stopped.store(1, Ordering::Relaxed);
|
||||
ping_res_receiver.close();
|
||||
}
|
||||
}
|
||||
|
||||
define_tunnel_filter_chain!(PeerConnTunnel, stats = StatsRecorderTunnelFilter);
|
||||
|
||||
pub struct PeerConn {
|
||||
conn_id: PeerConnId,
|
||||
|
||||
my_peer_id: PeerId,
|
||||
global_ctx: ArcGlobalCtx,
|
||||
|
||||
sink: Pin<Box<dyn DatagramSink>>,
|
||||
tunnel: Box<dyn Tunnel>,
|
||||
|
||||
tasks: JoinSet<Result<(), TunnelError>>,
|
||||
|
||||
info: Option<PeerInfo>,
|
||||
|
||||
close_event_sender: Option<mpsc::Sender<PeerConnId>>,
|
||||
|
||||
ctrl_resp_sender: broadcast::Sender<Bytes>,
|
||||
|
||||
latency_stats: Arc<WindowLatency>,
|
||||
throughput: Arc<Throughput>,
|
||||
loss_rate_stats: Arc<AtomicU32>,
|
||||
}
|
||||
|
||||
enum PeerConnPacketType {
|
||||
Data(Bytes),
|
||||
CtrlReq(Bytes),
|
||||
CtrlResp(Bytes),
|
||||
}
|
||||
|
||||
impl PeerConn {
|
||||
pub fn new(my_peer_id: PeerId, global_ctx: ArcGlobalCtx, tunnel: Box<dyn Tunnel>) -> Self {
|
||||
let (ctrl_sender, _ctrl_receiver) = broadcast::channel(100);
|
||||
let peer_conn_tunnel = PeerConnTunnel::new();
|
||||
let tunnel = peer_conn_tunnel.wrap_tunnel(tunnel);
|
||||
|
||||
PeerConn {
|
||||
conn_id: PeerConnId::new_v4(),
|
||||
|
||||
my_peer_id,
|
||||
global_ctx,
|
||||
|
||||
sink: tunnel.pin_sink(),
|
||||
tunnel: Box::new(tunnel),
|
||||
|
||||
tasks: JoinSet::new(),
|
||||
|
||||
info: None,
|
||||
close_event_sender: None,
|
||||
|
||||
ctrl_resp_sender: ctrl_sender,
|
||||
|
||||
latency_stats: Arc::new(WindowLatency::new(15)),
|
||||
throughput: peer_conn_tunnel.stats.get_throughput().clone(),
|
||||
loss_rate_stats: Arc::new(AtomicU32::new(0)),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_conn_id(&self) -> PeerConnId {
|
||||
self.conn_id
|
||||
}
|
||||
|
||||
#[tracing::instrument]
|
||||
pub async fn do_handshake_as_server(&mut self) -> Result<(), TunnelError> {
|
||||
let mut stream = self.tunnel.pin_stream();
|
||||
let mut sink = self.tunnel.pin_sink();
|
||||
|
||||
tracing::info!("waiting for handshake request from client");
|
||||
wait_response!(stream, hs_req, CtrlPacketPayload::HandShake(x) => x);
|
||||
self.info = Some(PeerInfo::from(hs_req));
|
||||
tracing::info!("handshake request: {:?}", hs_req);
|
||||
|
||||
let hs_req = self
|
||||
.global_ctx
|
||||
.net_ns
|
||||
.run(|| packet::Packet::new_handshake(self.my_peer_id, &self.global_ctx.network));
|
||||
sink.send(hs_req.into()).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tracing::instrument]
|
||||
pub async fn do_handshake_as_client(&mut self) -> Result<(), TunnelError> {
|
||||
let mut stream = self.tunnel.pin_stream();
|
||||
let mut sink = self.tunnel.pin_sink();
|
||||
|
||||
let hs_req = self
|
||||
.global_ctx
|
||||
.net_ns
|
||||
.run(|| packet::Packet::new_handshake(self.my_peer_id, &self.global_ctx.network));
|
||||
sink.send(hs_req.into()).await?;
|
||||
|
||||
tracing::info!("waiting for handshake request from server");
|
||||
wait_response!(stream, hs_rsp, CtrlPacketPayload::HandShake(x) => x);
|
||||
self.info = Some(PeerInfo::from(hs_rsp));
|
||||
tracing::info!("handshake response: {:?}", hs_rsp);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn handshake_done(&self) -> bool {
|
||||
self.info.is_some()
|
||||
}
|
||||
|
||||
pub fn start_pingpong(&mut self) {
|
||||
let mut pingpong = PeerConnPinger::new(
|
||||
self.my_peer_id,
|
||||
self.get_peer_id(),
|
||||
self.tunnel.pin_sink(),
|
||||
self.ctrl_resp_sender.clone(),
|
||||
self.latency_stats.clone(),
|
||||
self.loss_rate_stats.clone(),
|
||||
);
|
||||
|
||||
let close_event_sender = self.close_event_sender.clone().unwrap();
|
||||
let conn_id = self.conn_id;
|
||||
|
||||
self.tasks.spawn(async move {
|
||||
pingpong.pingpong().await;
|
||||
|
||||
tracing::warn!(?pingpong, "pingpong task exit");
|
||||
|
||||
if let Err(e) = close_event_sender.send(conn_id).await {
|
||||
log::warn!("close event sender error: {:?}", e);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
});
|
||||
}
|
||||
|
||||
pub fn start_recv_loop(&mut self, packet_recv_chan: PacketRecvChan) {
|
||||
let mut stream = self.tunnel.pin_stream();
|
||||
let mut sink = self.tunnel.pin_sink();
|
||||
let mut sender = PollSender::new(packet_recv_chan.clone());
|
||||
let close_event_sender = self.close_event_sender.clone().unwrap();
|
||||
let conn_id = self.conn_id;
|
||||
let ctrl_sender = self.ctrl_resp_sender.clone();
|
||||
let conn_info = self.get_conn_info();
|
||||
let conn_info_for_instrument = self.get_conn_info();
|
||||
|
||||
self.tasks.spawn(
|
||||
async move {
|
||||
tracing::info!("start recving peer conn packet");
|
||||
let mut task_ret = Ok(());
|
||||
while let Some(ret) = stream.next().await {
|
||||
if ret.is_err() {
|
||||
tracing::error!(error = ?ret, "peer conn recv error");
|
||||
task_ret = Err(ret.err().unwrap());
|
||||
break;
|
||||
}
|
||||
|
||||
let buf = ret.unwrap();
|
||||
let p = Packet::decode(&buf);
|
||||
match p.packet_type {
|
||||
ArchivedPacketType::Ping => {
|
||||
let CtrlPacketPayload::Ping(seq) = CtrlPacketPayload::from_packet(p)
|
||||
else {
|
||||
log::error!("unexpected packet: {:?}", p);
|
||||
continue;
|
||||
};
|
||||
|
||||
let pong = packet::Packet::new_pong_packet(
|
||||
conn_info.my_peer_id,
|
||||
conn_info.peer_id,
|
||||
seq.into(),
|
||||
);
|
||||
|
||||
if let Err(e) = sink.send(pong.into()).await {
|
||||
tracing::error!(?e, "peer conn send req error");
|
||||
}
|
||||
}
|
||||
ArchivedPacketType::Pong => {
|
||||
if let Err(e) = ctrl_sender.send(buf.into()) {
|
||||
tracing::error!(?e, "peer conn send ctrl resp error");
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
if sender.send(buf.into()).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tracing::info!("end recving peer conn packet");
|
||||
|
||||
if let Err(close_ret) = sink.close().await {
|
||||
tracing::error!(error = ?close_ret, "peer conn sink close error, ignore it");
|
||||
}
|
||||
if let Err(e) = close_event_sender.send(conn_id).await {
|
||||
tracing::error!(error = ?e, "peer conn close event send error");
|
||||
}
|
||||
|
||||
task_ret
|
||||
}
|
||||
.instrument(
|
||||
tracing::info_span!("peer conn recv loop", conn_info = ?conn_info_for_instrument),
|
||||
),
|
||||
);
|
||||
}
|
||||
|
||||
pub async fn send_msg(&mut self, msg: Bytes) -> Result<(), TunnelError> {
|
||||
self.sink.send(msg).await
|
||||
}
|
||||
|
||||
pub fn get_peer_id(&self) -> PeerId {
|
||||
self.info.as_ref().unwrap().my_peer_id
|
||||
}
|
||||
|
||||
pub fn get_network_identity(&self) -> NetworkIdentity {
|
||||
self.info.as_ref().unwrap().network_identity.clone()
|
||||
}
|
||||
|
||||
pub fn set_close_event_sender(&mut self, sender: mpsc::Sender<PeerConnId>) {
|
||||
self.close_event_sender = Some(sender);
|
||||
}
|
||||
|
||||
pub fn get_stats(&self) -> PeerConnStats {
|
||||
PeerConnStats {
|
||||
latency_us: self.latency_stats.get_latency_us(),
|
||||
|
||||
tx_bytes: self.throughput.tx_bytes(),
|
||||
rx_bytes: self.throughput.rx_bytes(),
|
||||
|
||||
tx_packets: self.throughput.tx_packets(),
|
||||
rx_packets: self.throughput.rx_packets(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_conn_info(&self) -> PeerConnInfo {
|
||||
PeerConnInfo {
|
||||
conn_id: self.conn_id.to_string(),
|
||||
my_peer_id: self.my_peer_id,
|
||||
peer_id: self.get_peer_id(),
|
||||
features: self.info.as_ref().unwrap().features.clone(),
|
||||
tunnel: self.tunnel.info(),
|
||||
stats: Some(self.get_stats()),
|
||||
loss_rate: (f64::from(self.loss_rate_stats.load(Ordering::Relaxed)) / 100.0) as f32,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for PeerConn {
|
||||
fn drop(&mut self) {
|
||||
let mut sink = self.tunnel.pin_sink();
|
||||
tokio::spawn(async move {
|
||||
let ret = sink.close().await;
|
||||
tracing::info!(error = ?ret, "peer conn tunnel closed.");
|
||||
});
|
||||
log::info!("peer conn {:?} drop", self.conn_id);
|
||||
}
|
||||
}
|
||||
|
||||
impl Debug for PeerConn {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("PeerConn")
|
||||
.field("conn_id", &self.conn_id)
|
||||
.field("my_peer_id", &self.my_peer_id)
|
||||
.field("info", &self.info)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::sync::Arc;
|
||||
|
||||
use super::*;
|
||||
use crate::common::global_ctx::tests::get_mock_global_ctx;
|
||||
use crate::common::new_peer_id;
|
||||
use crate::tunnels::tunnel_filter::tests::DropSendTunnelFilter;
|
||||
use crate::tunnels::tunnel_filter::{PacketRecorderTunnelFilter, TunnelWithFilter};
|
||||
|
||||
#[tokio::test]
|
||||
async fn peer_conn_handshake() {
|
||||
use crate::tunnels::ring_tunnel::create_ring_tunnel_pair;
|
||||
let (c, s) = create_ring_tunnel_pair();
|
||||
|
||||
let c_recorder = Arc::new(PacketRecorderTunnelFilter::new());
|
||||
let s_recorder = Arc::new(PacketRecorderTunnelFilter::new());
|
||||
|
||||
let c = TunnelWithFilter::new(c, c_recorder.clone());
|
||||
let s = TunnelWithFilter::new(s, s_recorder.clone());
|
||||
|
||||
let c_peer_id = new_peer_id();
|
||||
let s_peer_id = new_peer_id();
|
||||
|
||||
let mut c_peer = PeerConn::new(c_peer_id, get_mock_global_ctx(), Box::new(c));
|
||||
|
||||
let mut s_peer = PeerConn::new(s_peer_id, get_mock_global_ctx(), Box::new(s));
|
||||
|
||||
let (c_ret, s_ret) = tokio::join!(
|
||||
c_peer.do_handshake_as_client(),
|
||||
s_peer.do_handshake_as_server()
|
||||
);
|
||||
|
||||
c_ret.unwrap();
|
||||
s_ret.unwrap();
|
||||
|
||||
assert_eq!(c_recorder.sent.lock().unwrap().len(), 1);
|
||||
assert_eq!(c_recorder.received.lock().unwrap().len(), 1);
|
||||
|
||||
assert_eq!(s_recorder.sent.lock().unwrap().len(), 1);
|
||||
assert_eq!(s_recorder.received.lock().unwrap().len(), 1);
|
||||
|
||||
assert_eq!(c_peer.get_peer_id(), s_peer_id);
|
||||
assert_eq!(s_peer.get_peer_id(), c_peer_id);
|
||||
assert_eq!(c_peer.get_network_identity(), s_peer.get_network_identity());
|
||||
assert_eq!(c_peer.get_network_identity(), NetworkIdentity::default());
|
||||
}
|
||||
|
||||
async fn peer_conn_pingpong_test_common(drop_start: u32, drop_end: u32, conn_closed: bool) {
|
||||
use crate::tunnels::ring_tunnel::create_ring_tunnel_pair;
|
||||
let (c, s) = create_ring_tunnel_pair();
|
||||
|
||||
// drop 1-3 packets should not affect pingpong
|
||||
let c_recorder = Arc::new(DropSendTunnelFilter::new(drop_start, drop_end));
|
||||
let c = TunnelWithFilter::new(c, c_recorder.clone());
|
||||
|
||||
let c_peer_id = new_peer_id();
|
||||
let s_peer_id = new_peer_id();
|
||||
|
||||
let mut c_peer = PeerConn::new(c_peer_id, get_mock_global_ctx(), Box::new(c));
|
||||
let mut s_peer = PeerConn::new(s_peer_id, get_mock_global_ctx(), Box::new(s));
|
||||
|
||||
let (c_ret, s_ret) = tokio::join!(
|
||||
c_peer.do_handshake_as_client(),
|
||||
s_peer.do_handshake_as_server()
|
||||
);
|
||||
|
||||
s_peer.set_close_event_sender(tokio::sync::mpsc::channel(1).0);
|
||||
s_peer.start_recv_loop(tokio::sync::mpsc::channel(200).0);
|
||||
|
||||
assert!(c_ret.is_ok());
|
||||
assert!(s_ret.is_ok());
|
||||
|
||||
let (close_send, mut close_recv) = tokio::sync::mpsc::channel(1);
|
||||
c_peer.set_close_event_sender(close_send);
|
||||
c_peer.start_pingpong();
|
||||
c_peer.start_recv_loop(tokio::sync::mpsc::channel(200).0);
|
||||
|
||||
// wait 5s, conn should not be disconnected
|
||||
tokio::time::sleep(Duration::from_secs(15)).await;
|
||||
|
||||
if conn_closed {
|
||||
assert!(close_recv.try_recv().is_ok());
|
||||
} else {
|
||||
assert!(close_recv.try_recv().is_err());
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn peer_conn_pingpong_timeout() {
|
||||
peer_conn_pingpong_test_common(3, 5, false).await;
|
||||
peer_conn_pingpong_test_common(5, 12, true).await;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,641 @@
|
||||
use std::{
|
||||
fmt::Debug,
|
||||
net::Ipv4Addr,
|
||||
sync::{Arc, Weak},
|
||||
};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use futures::StreamExt;
|
||||
|
||||
use tokio::{
|
||||
sync::{
|
||||
mpsc::{self, UnboundedReceiver, UnboundedSender},
|
||||
Mutex, RwLock,
|
||||
},
|
||||
task::JoinSet,
|
||||
};
|
||||
use tokio_stream::wrappers::ReceiverStream;
|
||||
use tokio_util::bytes::{Bytes, BytesMut};
|
||||
|
||||
use crate::{
|
||||
common::{
|
||||
error::Error, global_ctx::ArcGlobalCtx, rkyv_util::extract_bytes_from_archived_string,
|
||||
PeerId,
|
||||
},
|
||||
peers::{
|
||||
packet, peer_conn::PeerConn, peer_rpc::PeerRpcManagerTransport,
|
||||
route_trait::RouteInterface, PeerPacketFilter,
|
||||
},
|
||||
tunnels::{SinkItem, Tunnel, TunnelConnector},
|
||||
};
|
||||
|
||||
use super::{
|
||||
foreign_network_client::ForeignNetworkClient,
|
||||
foreign_network_manager::ForeignNetworkManager,
|
||||
peer_conn::PeerConnId,
|
||||
peer_map::PeerMap,
|
||||
peer_ospf_route::PeerRoute,
|
||||
peer_rip_route::BasicRoute,
|
||||
peer_rpc::PeerRpcManager,
|
||||
route_trait::{ArcRoute, Route},
|
||||
BoxNicPacketFilter, BoxPeerPacketFilter,
|
||||
};
|
||||
|
||||
struct RpcTransport {
|
||||
my_peer_id: PeerId,
|
||||
peers: Weak<PeerMap>,
|
||||
foreign_peers: Mutex<Option<Weak<ForeignNetworkClient>>>,
|
||||
|
||||
packet_recv: Mutex<UnboundedReceiver<Bytes>>,
|
||||
peer_rpc_tspt_sender: UnboundedSender<Bytes>,
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl PeerRpcManagerTransport for RpcTransport {
|
||||
fn my_peer_id(&self) -> PeerId {
|
||||
self.my_peer_id
|
||||
}
|
||||
|
||||
async fn send(&self, msg: Bytes, dst_peer_id: PeerId) -> Result<(), Error> {
|
||||
let foreign_peers = self
|
||||
.foreign_peers
|
||||
.lock()
|
||||
.await
|
||||
.as_ref()
|
||||
.ok_or(Error::Unknown)?
|
||||
.upgrade()
|
||||
.ok_or(Error::Unknown)?;
|
||||
let peers = self.peers.upgrade().ok_or(Error::Unknown)?;
|
||||
|
||||
let ret = peers.send_msg(msg.clone(), dst_peer_id).await;
|
||||
|
||||
if matches!(ret, Err(Error::RouteError(..))) && foreign_peers.has_next_hop(dst_peer_id) {
|
||||
tracing::info!(
|
||||
?dst_peer_id,
|
||||
?self.my_peer_id,
|
||||
"failed to send msg to peer, try foreign network",
|
||||
);
|
||||
return foreign_peers.send_msg(msg, dst_peer_id).await;
|
||||
}
|
||||
|
||||
ret
|
||||
}
|
||||
|
||||
async fn recv(&self) -> Result<Bytes, Error> {
|
||||
if let Some(o) = self.packet_recv.lock().await.recv().await {
|
||||
Ok(o)
|
||||
} else {
|
||||
Err(Error::Unknown)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub enum RouteAlgoType {
|
||||
Rip,
|
||||
Ospf,
|
||||
None,
|
||||
}
|
||||
|
||||
enum RouteAlgoInst {
|
||||
Rip(Arc<BasicRoute>),
|
||||
Ospf(Arc<PeerRoute>),
|
||||
None,
|
||||
}
|
||||
|
||||
pub struct PeerManager {
|
||||
my_peer_id: PeerId,
|
||||
|
||||
global_ctx: ArcGlobalCtx,
|
||||
nic_channel: mpsc::Sender<SinkItem>,
|
||||
|
||||
tasks: Arc<Mutex<JoinSet<()>>>,
|
||||
|
||||
packet_recv: Arc<Mutex<Option<mpsc::Receiver<Bytes>>>>,
|
||||
|
||||
peers: Arc<PeerMap>,
|
||||
|
||||
peer_rpc_mgr: Arc<PeerRpcManager>,
|
||||
peer_rpc_tspt: Arc<RpcTransport>,
|
||||
|
||||
peer_packet_process_pipeline: Arc<RwLock<Vec<BoxPeerPacketFilter>>>,
|
||||
nic_packet_process_pipeline: Arc<RwLock<Vec<BoxNicPacketFilter>>>,
|
||||
|
||||
route_algo_inst: RouteAlgoInst,
|
||||
|
||||
foreign_network_manager: Arc<ForeignNetworkManager>,
|
||||
foreign_network_client: Arc<ForeignNetworkClient>,
|
||||
}
|
||||
|
||||
impl Debug for PeerManager {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("PeerManager")
|
||||
.field("my_peer_id", &self.my_peer_id())
|
||||
.field("instance_name", &self.global_ctx.inst_name)
|
||||
.field("net_ns", &self.global_ctx.net_ns.name())
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl PeerManager {
|
||||
pub fn new(
|
||||
route_algo: RouteAlgoType,
|
||||
global_ctx: ArcGlobalCtx,
|
||||
nic_channel: mpsc::Sender<SinkItem>,
|
||||
) -> Self {
|
||||
let my_peer_id = rand::random();
|
||||
|
||||
let (packet_send, packet_recv) = mpsc::channel(100);
|
||||
let peers = Arc::new(PeerMap::new(
|
||||
packet_send.clone(),
|
||||
global_ctx.clone(),
|
||||
my_peer_id,
|
||||
));
|
||||
|
||||
// TODO: remove these because we have impl pipeline processor.
|
||||
let (peer_rpc_tspt_sender, peer_rpc_tspt_recv) = mpsc::unbounded_channel();
|
||||
let rpc_tspt = Arc::new(RpcTransport {
|
||||
my_peer_id,
|
||||
peers: Arc::downgrade(&peers),
|
||||
foreign_peers: Mutex::new(None),
|
||||
packet_recv: Mutex::new(peer_rpc_tspt_recv),
|
||||
peer_rpc_tspt_sender,
|
||||
});
|
||||
let peer_rpc_mgr = Arc::new(PeerRpcManager::new(rpc_tspt.clone()));
|
||||
|
||||
let route_algo_inst = match route_algo {
|
||||
RouteAlgoType::Rip => {
|
||||
RouteAlgoInst::Rip(Arc::new(BasicRoute::new(my_peer_id, global_ctx.clone())))
|
||||
}
|
||||
RouteAlgoType::Ospf => RouteAlgoInst::Ospf(PeerRoute::new(
|
||||
my_peer_id,
|
||||
global_ctx.clone(),
|
||||
peer_rpc_mgr.clone(),
|
||||
)),
|
||||
RouteAlgoType::None => RouteAlgoInst::None,
|
||||
};
|
||||
|
||||
let foreign_network_manager = Arc::new(ForeignNetworkManager::new(
|
||||
my_peer_id,
|
||||
global_ctx.clone(),
|
||||
packet_send.clone(),
|
||||
));
|
||||
let foreign_network_client = Arc::new(ForeignNetworkClient::new(
|
||||
global_ctx.clone(),
|
||||
packet_send.clone(),
|
||||
peer_rpc_mgr.clone(),
|
||||
my_peer_id,
|
||||
));
|
||||
|
||||
PeerManager {
|
||||
my_peer_id,
|
||||
|
||||
global_ctx,
|
||||
nic_channel,
|
||||
|
||||
tasks: Arc::new(Mutex::new(JoinSet::new())),
|
||||
|
||||
packet_recv: Arc::new(Mutex::new(Some(packet_recv))),
|
||||
|
||||
peers: peers.clone(),
|
||||
|
||||
peer_rpc_mgr,
|
||||
peer_rpc_tspt: rpc_tspt,
|
||||
|
||||
peer_packet_process_pipeline: Arc::new(RwLock::new(Vec::new())),
|
||||
nic_packet_process_pipeline: Arc::new(RwLock::new(Vec::new())),
|
||||
|
||||
route_algo_inst,
|
||||
|
||||
foreign_network_manager,
|
||||
foreign_network_client,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn add_client_tunnel(
|
||||
&self,
|
||||
tunnel: Box<dyn Tunnel>,
|
||||
) -> Result<(PeerId, PeerConnId), Error> {
|
||||
let mut peer = PeerConn::new(self.my_peer_id, self.global_ctx.clone(), tunnel);
|
||||
peer.do_handshake_as_client().await?;
|
||||
let conn_id = peer.get_conn_id();
|
||||
let peer_id = peer.get_peer_id();
|
||||
if peer.get_network_identity() == self.global_ctx.get_network_identity() {
|
||||
self.peers.add_new_peer_conn(peer).await;
|
||||
} else {
|
||||
self.foreign_network_client.add_new_peer_conn(peer).await;
|
||||
}
|
||||
Ok((peer_id, conn_id))
|
||||
}
|
||||
|
||||
#[tracing::instrument]
|
||||
pub async fn try_connect<C>(&self, mut connector: C) -> Result<(PeerId, PeerConnId), Error>
|
||||
where
|
||||
C: TunnelConnector + Debug,
|
||||
{
|
||||
let ns = self.global_ctx.net_ns.clone();
|
||||
let t = ns
|
||||
.run_async(|| async move { connector.connect().await })
|
||||
.await?;
|
||||
self.add_client_tunnel(t).await
|
||||
}
|
||||
|
||||
#[tracing::instrument]
|
||||
pub async fn add_tunnel_as_server(&self, tunnel: Box<dyn Tunnel>) -> Result<(), Error> {
|
||||
tracing::info!("add tunnel as server start");
|
||||
let mut peer = PeerConn::new(self.my_peer_id, self.global_ctx.clone(), tunnel);
|
||||
peer.do_handshake_as_server().await?;
|
||||
if peer.get_network_identity() == self.global_ctx.get_network_identity() {
|
||||
self.peers.add_new_peer_conn(peer).await;
|
||||
} else {
|
||||
self.foreign_network_manager.add_peer_conn(peer).await?;
|
||||
}
|
||||
tracing::info!("add tunnel as server done");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn start_peer_recv(&self) {
|
||||
let mut recv = ReceiverStream::new(self.packet_recv.lock().await.take().unwrap());
|
||||
let my_peer_id = self.my_peer_id;
|
||||
let peers = self.peers.clone();
|
||||
let pipe_line = self.peer_packet_process_pipeline.clone();
|
||||
self.tasks.lock().await.spawn(async move {
|
||||
log::trace!("start_peer_recv");
|
||||
while let Some(ret) = recv.next().await {
|
||||
log::trace!("peer recv a packet...: {:?}", ret);
|
||||
let packet = packet::Packet::decode(&ret);
|
||||
let from_peer_id: PeerId = packet.from_peer.into();
|
||||
let to_peer_id: PeerId = packet.to_peer.into();
|
||||
if to_peer_id != my_peer_id {
|
||||
log::trace!(
|
||||
"need forward: to_peer_id: {:?}, my_peer_id: {:?}",
|
||||
to_peer_id,
|
||||
my_peer_id
|
||||
);
|
||||
let ret = peers.send_msg(ret.clone(), to_peer_id).await;
|
||||
if ret.is_err() {
|
||||
log::error!(
|
||||
"forward packet error: {:?}, dst: {:?}, from: {:?}",
|
||||
ret,
|
||||
to_peer_id,
|
||||
from_peer_id
|
||||
);
|
||||
}
|
||||
} else {
|
||||
let mut processed = false;
|
||||
for pipeline in pipe_line.read().await.iter().rev() {
|
||||
if let Some(_) = pipeline.try_process_packet_from_peer(&packet, &ret).await
|
||||
{
|
||||
processed = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if !processed {
|
||||
tracing::error!("unexpected packet: {:?}", ret);
|
||||
}
|
||||
}
|
||||
}
|
||||
panic!("done_peer_recv");
|
||||
});
|
||||
}
|
||||
|
||||
pub async fn add_packet_process_pipeline(&self, pipeline: BoxPeerPacketFilter) {
|
||||
// newest pipeline will be executed first
|
||||
self.peer_packet_process_pipeline
|
||||
.write()
|
||||
.await
|
||||
.push(pipeline);
|
||||
}
|
||||
|
||||
pub async fn add_nic_packet_process_pipeline(&self, pipeline: BoxNicPacketFilter) {
|
||||
// newest pipeline will be executed first
|
||||
self.nic_packet_process_pipeline
|
||||
.write()
|
||||
.await
|
||||
.push(pipeline);
|
||||
}
|
||||
|
||||
async fn init_packet_process_pipeline(&self) {
|
||||
// for tun/tap ip/eth packet.
|
||||
struct NicPacketProcessor {
|
||||
nic_channel: mpsc::Sender<SinkItem>,
|
||||
}
|
||||
#[async_trait::async_trait]
|
||||
impl PeerPacketFilter for NicPacketProcessor {
|
||||
async fn try_process_packet_from_peer(
|
||||
&self,
|
||||
packet: &packet::ArchivedPacket,
|
||||
data: &Bytes,
|
||||
) -> Option<()> {
|
||||
if packet.packet_type == packet::PacketType::Data {
|
||||
// TODO: use a function to get the body ref directly for zero copy
|
||||
self.nic_channel
|
||||
.send(extract_bytes_from_archived_string(data, &packet.payload))
|
||||
.await
|
||||
.unwrap();
|
||||
Some(())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
self.add_packet_process_pipeline(Box::new(NicPacketProcessor {
|
||||
nic_channel: self.nic_channel.clone(),
|
||||
}))
|
||||
.await;
|
||||
|
||||
// for peer rpc packet
|
||||
struct PeerRpcPacketProcessor {
|
||||
peer_rpc_tspt_sender: UnboundedSender<Bytes>,
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl PeerPacketFilter for PeerRpcPacketProcessor {
|
||||
async fn try_process_packet_from_peer(
|
||||
&self,
|
||||
packet: &packet::ArchivedPacket,
|
||||
data: &Bytes,
|
||||
) -> Option<()> {
|
||||
if packet.packet_type == packet::PacketType::TaRpc {
|
||||
self.peer_rpc_tspt_sender.send(data.clone()).unwrap();
|
||||
Some(())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
self.add_packet_process_pipeline(Box::new(PeerRpcPacketProcessor {
|
||||
peer_rpc_tspt_sender: self.peer_rpc_tspt.peer_rpc_tspt_sender.clone(),
|
||||
}))
|
||||
.await;
|
||||
}
|
||||
|
||||
pub async fn add_route<T>(&self, route: T)
|
||||
where
|
||||
T: Route + PeerPacketFilter + Send + Sync + Clone + 'static,
|
||||
{
|
||||
// for route
|
||||
self.add_packet_process_pipeline(Box::new(route.clone()))
|
||||
.await;
|
||||
|
||||
struct Interface {
|
||||
my_peer_id: PeerId,
|
||||
peers: Weak<PeerMap>,
|
||||
foreign_network_client: Weak<ForeignNetworkClient>,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl RouteInterface for Interface {
|
||||
async fn list_peers(&self) -> Vec<PeerId> {
|
||||
let Some(foreign_client) = self.foreign_network_client.upgrade() else {
|
||||
return vec![];
|
||||
};
|
||||
|
||||
let Some(peer_map) = self.peers.upgrade() else {
|
||||
return vec![];
|
||||
};
|
||||
|
||||
let mut peers = foreign_client.list_foreign_peers();
|
||||
peers.extend(peer_map.list_peers_with_conn().await);
|
||||
peers
|
||||
}
|
||||
async fn send_route_packet(
|
||||
&self,
|
||||
msg: Bytes,
|
||||
route_id: u8,
|
||||
dst_peer_id: PeerId,
|
||||
) -> Result<(), Error> {
|
||||
let foreign_client = self
|
||||
.foreign_network_client
|
||||
.upgrade()
|
||||
.ok_or(Error::Unknown)?;
|
||||
let peer_map = self.peers.upgrade().ok_or(Error::Unknown)?;
|
||||
|
||||
let packet_bytes: Bytes =
|
||||
packet::Packet::new_route_packet(self.my_peer_id, dst_peer_id, route_id, &msg)
|
||||
.into();
|
||||
if foreign_client.has_next_hop(dst_peer_id) {
|
||||
return foreign_client.send_msg(packet_bytes, dst_peer_id).await;
|
||||
}
|
||||
|
||||
peer_map.send_msg_directly(packet_bytes, dst_peer_id).await
|
||||
}
|
||||
fn my_peer_id(&self) -> PeerId {
|
||||
self.my_peer_id
|
||||
}
|
||||
}
|
||||
|
||||
let my_peer_id = self.my_peer_id;
|
||||
let _route_id = route
|
||||
.open(Box::new(Interface {
|
||||
my_peer_id,
|
||||
peers: Arc::downgrade(&self.peers),
|
||||
foreign_network_client: Arc::downgrade(&self.foreign_network_client),
|
||||
}))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let arc_route: ArcRoute = Arc::new(Box::new(route));
|
||||
self.peers.add_route(arc_route).await;
|
||||
}
|
||||
|
||||
pub fn get_route(&self) -> Box<dyn Route + Send + Sync + 'static> {
|
||||
match &self.route_algo_inst {
|
||||
RouteAlgoInst::Rip(route) => Box::new(route.clone()),
|
||||
RouteAlgoInst::Ospf(route) => Box::new(route.clone()),
|
||||
RouteAlgoInst::None => panic!("no route"),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn list_routes(&self) -> Vec<crate::rpc::Route> {
|
||||
self.get_route().list_routes().await
|
||||
}
|
||||
|
||||
async fn run_nic_packet_process_pipeline(&self, mut data: BytesMut) -> BytesMut {
|
||||
for pipeline in self.nic_packet_process_pipeline.read().await.iter().rev() {
|
||||
data = pipeline.try_process_packet_from_nic(data).await;
|
||||
}
|
||||
data
|
||||
}
|
||||
|
||||
pub async fn send_msg(&self, msg: Bytes, dst_peer_id: PeerId) -> Result<(), Error> {
|
||||
self.peers.send_msg(msg, dst_peer_id).await
|
||||
}
|
||||
|
||||
pub async fn send_msg_ipv4(&self, msg: BytesMut, ipv4_addr: Ipv4Addr) -> Result<(), Error> {
|
||||
log::trace!(
|
||||
"do send_msg in peer manager, msg: {:?}, ipv4_addr: {}",
|
||||
msg,
|
||||
ipv4_addr
|
||||
);
|
||||
|
||||
let mut dst_peers = vec![];
|
||||
// NOTE: currently we only support ipv4 and cidr is 24
|
||||
if ipv4_addr.is_broadcast() || ipv4_addr.is_multicast() || ipv4_addr.octets()[3] == 255 {
|
||||
dst_peers.extend(
|
||||
self.peers
|
||||
.list_routes()
|
||||
.await
|
||||
.iter()
|
||||
.map(|x| x.key().clone()),
|
||||
);
|
||||
} else if let Some(peer_id) = self.peers.get_peer_id_by_ipv4(&ipv4_addr).await {
|
||||
dst_peers.push(peer_id);
|
||||
}
|
||||
|
||||
if dst_peers.is_empty() {
|
||||
tracing::info!("no peer id for ipv4: {}", ipv4_addr);
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let msg = self.run_nic_packet_process_pipeline(msg).await;
|
||||
let mut errs: Vec<Error> = vec![];
|
||||
|
||||
for peer_id in dst_peers.iter() {
|
||||
let msg: Bytes =
|
||||
packet::Packet::new_data_packet(self.my_peer_id, peer_id.clone(), &msg).into();
|
||||
let send_ret = self.peers.send_msg(msg.clone(), *peer_id).await;
|
||||
|
||||
if matches!(send_ret, Err(Error::RouteError(..)))
|
||||
&& self.foreign_network_client.has_next_hop(*peer_id)
|
||||
{
|
||||
let foreign_send_ret = self.foreign_network_client.send_msg(msg, *peer_id).await;
|
||||
if foreign_send_ret.is_ok() {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
if let Err(send_ret) = send_ret {
|
||||
errs.push(send_ret);
|
||||
}
|
||||
}
|
||||
|
||||
tracing::trace!(?dst_peers, "do send_msg in peer manager done");
|
||||
|
||||
if errs.is_empty() {
|
||||
Ok(())
|
||||
} else {
|
||||
tracing::error!(?errs, "send_msg has error");
|
||||
Err(anyhow::anyhow!("send_msg has error: {:?}", errs).into())
|
||||
}
|
||||
}
|
||||
|
||||
async fn run_clean_peer_without_conn_routine(&self) {
|
||||
let peer_map = self.peers.clone();
|
||||
self.tasks.lock().await.spawn(async move {
|
||||
loop {
|
||||
peer_map.clean_peer_without_conn().await;
|
||||
tokio::time::sleep(std::time::Duration::from_secs(3)).await;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
async fn run_foriegn_network(&self) {
|
||||
self.peer_rpc_tspt
|
||||
.foreign_peers
|
||||
.lock()
|
||||
.await
|
||||
.replace(Arc::downgrade(&self.foreign_network_client));
|
||||
|
||||
self.foreign_network_manager.run().await;
|
||||
self.foreign_network_client.run().await;
|
||||
}
|
||||
|
||||
pub async fn run(&self) -> Result<(), Error> {
|
||||
match &self.route_algo_inst {
|
||||
RouteAlgoInst::Ospf(route) => self.add_route(route.clone()).await,
|
||||
RouteAlgoInst::Rip(route) => self.add_route(route.clone()).await,
|
||||
RouteAlgoInst::None => {}
|
||||
};
|
||||
|
||||
self.init_packet_process_pipeline().await;
|
||||
self.peer_rpc_mgr.run();
|
||||
|
||||
self.start_peer_recv().await;
|
||||
self.run_clean_peer_without_conn_routine().await;
|
||||
|
||||
self.run_foriegn_network().await;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn get_peer_map(&self) -> Arc<PeerMap> {
|
||||
self.peers.clone()
|
||||
}
|
||||
|
||||
pub fn get_peer_rpc_mgr(&self) -> Arc<PeerRpcManager> {
|
||||
self.peer_rpc_mgr.clone()
|
||||
}
|
||||
|
||||
pub fn my_node_id(&self) -> uuid::Uuid {
|
||||
self.global_ctx.get_id()
|
||||
}
|
||||
|
||||
pub fn my_peer_id(&self) -> PeerId {
|
||||
self.my_peer_id
|
||||
}
|
||||
|
||||
pub fn get_global_ctx(&self) -> ArcGlobalCtx {
|
||||
self.global_ctx.clone()
|
||||
}
|
||||
|
||||
pub fn get_nic_channel(&self) -> mpsc::Sender<SinkItem> {
|
||||
self.nic_channel.clone()
|
||||
}
|
||||
|
||||
pub fn get_basic_route(&self) -> Arc<BasicRoute> {
|
||||
match &self.route_algo_inst {
|
||||
RouteAlgoInst::Rip(route) => route.clone(),
|
||||
_ => panic!("not rip route"),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_foreign_network_manager(&self) -> Arc<ForeignNetworkManager> {
|
||||
self.foreign_network_manager.clone()
|
||||
}
|
||||
|
||||
pub fn get_foreign_network_client(&self) -> Arc<ForeignNetworkClient> {
|
||||
self.foreign_network_client.clone()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
|
||||
use crate::{
|
||||
connector::udp_hole_punch::tests::create_mock_peer_manager_with_mock_stun,
|
||||
peers::tests::{connect_peer_manager, wait_for_condition, wait_route_appear},
|
||||
rpc::NatType,
|
||||
};
|
||||
|
||||
#[tokio::test]
|
||||
async fn drop_peer_manager() {
|
||||
let peer_mgr_a = create_mock_peer_manager_with_mock_stun(NatType::Unknown).await;
|
||||
let peer_mgr_b = create_mock_peer_manager_with_mock_stun(NatType::Unknown).await;
|
||||
let peer_mgr_c = create_mock_peer_manager_with_mock_stun(NatType::Unknown).await;
|
||||
connect_peer_manager(peer_mgr_a.clone(), peer_mgr_b.clone()).await;
|
||||
connect_peer_manager(peer_mgr_b.clone(), peer_mgr_c.clone()).await;
|
||||
connect_peer_manager(peer_mgr_a.clone(), peer_mgr_c.clone()).await;
|
||||
|
||||
wait_route_appear(peer_mgr_a.clone(), peer_mgr_b.clone())
|
||||
.await
|
||||
.unwrap();
|
||||
wait_route_appear(peer_mgr_a.clone(), peer_mgr_c.clone())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// wait mgr_a have 2 peers
|
||||
wait_for_condition(
|
||||
|| async { peer_mgr_a.get_peer_map().list_peers_with_conn().await.len() == 2 },
|
||||
std::time::Duration::from_secs(5),
|
||||
)
|
||||
.await;
|
||||
|
||||
drop(peer_mgr_b);
|
||||
|
||||
wait_for_condition(
|
||||
|| async { peer_mgr_a.get_peer_map().list_peers_with_conn().await.len() == 1 },
|
||||
std::time::Duration::from_secs(5),
|
||||
)
|
||||
.await;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,234 @@
|
||||
use std::{net::Ipv4Addr, sync::Arc};
|
||||
|
||||
use anyhow::Context;
|
||||
use dashmap::DashMap;
|
||||
use tokio::sync::{mpsc, RwLock};
|
||||
use tokio_util::bytes::Bytes;
|
||||
|
||||
use crate::{
|
||||
common::{
|
||||
error::Error,
|
||||
global_ctx::{ArcGlobalCtx, GlobalCtxEvent},
|
||||
PeerId,
|
||||
},
|
||||
rpc::PeerConnInfo,
|
||||
tunnels::TunnelError,
|
||||
};
|
||||
|
||||
use super::{
|
||||
peer::Peer,
|
||||
peer_conn::{PeerConn, PeerConnId},
|
||||
route_trait::ArcRoute,
|
||||
};
|
||||
|
||||
pub struct PeerMap {
|
||||
global_ctx: ArcGlobalCtx,
|
||||
my_peer_id: PeerId,
|
||||
peer_map: DashMap<PeerId, Arc<Peer>>,
|
||||
packet_send: mpsc::Sender<Bytes>,
|
||||
routes: RwLock<Vec<ArcRoute>>,
|
||||
}
|
||||
|
||||
impl PeerMap {
|
||||
pub fn new(
|
||||
packet_send: mpsc::Sender<Bytes>,
|
||||
global_ctx: ArcGlobalCtx,
|
||||
my_peer_id: PeerId,
|
||||
) -> Self {
|
||||
PeerMap {
|
||||
global_ctx,
|
||||
my_peer_id,
|
||||
peer_map: DashMap::new(),
|
||||
packet_send,
|
||||
routes: RwLock::new(Vec::new()),
|
||||
}
|
||||
}
|
||||
|
||||
async fn add_new_peer(&self, peer: Peer) {
|
||||
let peer_id = peer.peer_node_id.clone();
|
||||
self.peer_map.insert(peer_id.clone(), Arc::new(peer));
|
||||
self.global_ctx
|
||||
.issue_event(GlobalCtxEvent::PeerAdded(peer_id));
|
||||
}
|
||||
|
||||
pub async fn add_new_peer_conn(&self, peer_conn: PeerConn) {
|
||||
let peer_id = peer_conn.get_peer_id();
|
||||
let no_entry = self.peer_map.get(&peer_id).is_none();
|
||||
if no_entry {
|
||||
let new_peer = Peer::new(peer_id, self.packet_send.clone(), self.global_ctx.clone());
|
||||
new_peer.add_peer_conn(peer_conn).await;
|
||||
self.add_new_peer(new_peer).await;
|
||||
} else {
|
||||
let peer = self.peer_map.get(&peer_id).unwrap().clone();
|
||||
peer.add_peer_conn(peer_conn).await;
|
||||
}
|
||||
}
|
||||
|
||||
fn get_peer_by_id(&self, peer_id: PeerId) -> Option<Arc<Peer>> {
|
||||
self.peer_map.get(&peer_id).map(|v| v.clone())
|
||||
}
|
||||
|
||||
pub fn has_peer(&self, peer_id: PeerId) -> bool {
|
||||
self.peer_map.contains_key(&peer_id)
|
||||
}
|
||||
|
||||
pub async fn send_msg_directly(&self, msg: Bytes, dst_peer_id: PeerId) -> Result<(), Error> {
|
||||
if dst_peer_id == self.my_peer_id {
|
||||
return Ok(self
|
||||
.packet_send
|
||||
.send(msg)
|
||||
.await
|
||||
.with_context(|| "send msg to self failed")?);
|
||||
}
|
||||
|
||||
match self.get_peer_by_id(dst_peer_id) {
|
||||
Some(peer) => {
|
||||
peer.send_msg(msg).await?;
|
||||
}
|
||||
None => {
|
||||
log::error!("no peer for dst_peer_id: {}", dst_peer_id);
|
||||
return Err(Error::RouteError(None));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn send_msg(&self, msg: Bytes, dst_peer_id: PeerId) -> Result<(), Error> {
|
||||
if dst_peer_id == self.my_peer_id {
|
||||
return Ok(self
|
||||
.packet_send
|
||||
.send(msg)
|
||||
.await
|
||||
.with_context(|| "send msg to self failed")?);
|
||||
}
|
||||
|
||||
// get route info
|
||||
let mut gateway_peer_id = None;
|
||||
for route in self.routes.read().await.iter() {
|
||||
gateway_peer_id = route.get_next_hop(dst_peer_id).await;
|
||||
if gateway_peer_id.is_none() {
|
||||
continue;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if gateway_peer_id.is_none() && self.has_peer(dst_peer_id) {
|
||||
gateway_peer_id = Some(dst_peer_id);
|
||||
}
|
||||
|
||||
let Some(gateway_peer_id) = gateway_peer_id else {
|
||||
tracing::trace!(
|
||||
"no gateway for dst_peer_id: {}, peers: {:?}, my_peer_id: {}",
|
||||
dst_peer_id,
|
||||
self.peer_map.iter().map(|v| *v.key()).collect::<Vec<_>>(),
|
||||
self.my_peer_id
|
||||
);
|
||||
return Err(Error::RouteError(None));
|
||||
};
|
||||
|
||||
self.send_msg_directly(msg.clone(), gateway_peer_id).await?;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
pub async fn get_peer_id_by_ipv4(&self, ipv4: &Ipv4Addr) -> Option<PeerId> {
|
||||
for route in self.routes.read().await.iter() {
|
||||
let peer_id = route.get_peer_id_by_ipv4(ipv4).await;
|
||||
if peer_id.is_some() {
|
||||
return peer_id;
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.peer_map.is_empty()
|
||||
}
|
||||
|
||||
pub async fn list_peers(&self) -> Vec<PeerId> {
|
||||
let mut ret = Vec::new();
|
||||
for item in self.peer_map.iter() {
|
||||
let peer_id = item.key();
|
||||
ret.push(*peer_id);
|
||||
}
|
||||
ret
|
||||
}
|
||||
|
||||
pub async fn list_peers_with_conn(&self) -> Vec<PeerId> {
|
||||
let mut ret = Vec::new();
|
||||
let peers = self.list_peers().await;
|
||||
for peer_id in peers.iter() {
|
||||
let Some(peer) = self.get_peer_by_id(*peer_id) else {
|
||||
continue;
|
||||
};
|
||||
if peer.list_peer_conns().await.len() > 0 {
|
||||
ret.push(*peer_id);
|
||||
}
|
||||
}
|
||||
ret
|
||||
}
|
||||
|
||||
pub async fn list_peer_conns(&self, peer_id: PeerId) -> Option<Vec<PeerConnInfo>> {
|
||||
if let Some(p) = self.get_peer_by_id(peer_id) {
|
||||
Some(p.list_peer_conns().await)
|
||||
} else {
|
||||
return None;
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn close_peer_conn(
|
||||
&self,
|
||||
peer_id: PeerId,
|
||||
conn_id: &PeerConnId,
|
||||
) -> Result<(), Error> {
|
||||
if let Some(p) = self.get_peer_by_id(peer_id) {
|
||||
p.close_peer_conn(conn_id).await
|
||||
} else {
|
||||
return Err(Error::NotFound);
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn close_peer(&self, peer_id: PeerId) -> Result<(), TunnelError> {
|
||||
let remove_ret = self.peer_map.remove(&peer_id);
|
||||
self.global_ctx
|
||||
.issue_event(GlobalCtxEvent::PeerRemoved(peer_id));
|
||||
tracing::info!(
|
||||
?peer_id,
|
||||
has_old_value = ?remove_ret.is_some(),
|
||||
peer_ref_counter = ?remove_ret.map(|v| Arc::strong_count(&v.1)),
|
||||
"peer is closed"
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn add_route(&self, route: ArcRoute) {
|
||||
let mut routes = self.routes.write().await;
|
||||
routes.insert(0, route);
|
||||
}
|
||||
|
||||
pub async fn clean_peer_without_conn(&self) {
|
||||
let mut to_remove = vec![];
|
||||
|
||||
for peer_id in self.list_peers().await {
|
||||
let conns = self.list_peer_conns(peer_id).await;
|
||||
if conns.is_none() || conns.as_ref().unwrap().is_empty() {
|
||||
to_remove.push(peer_id);
|
||||
}
|
||||
}
|
||||
|
||||
for peer_id in to_remove {
|
||||
self.close_peer(peer_id).await.unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn list_routes(&self) -> DashMap<PeerId, PeerId> {
|
||||
let route_map = DashMap::new();
|
||||
for route in self.routes.read().await.iter() {
|
||||
for item in route.list_routes().await.iter() {
|
||||
route_map.insert(item.peer_id, item.next_hop_peer_id);
|
||||
}
|
||||
}
|
||||
route_map
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,770 @@
|
||||
use std::{
|
||||
net::Ipv4Addr,
|
||||
sync::{atomic::AtomicU32, Arc},
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use dashmap::DashMap;
|
||||
use tokio::{
|
||||
sync::{Mutex, RwLock},
|
||||
task::JoinSet,
|
||||
};
|
||||
use tokio_util::bytes::Bytes;
|
||||
use tracing::Instrument;
|
||||
|
||||
use crate::{
|
||||
common::{error::Error, global_ctx::ArcGlobalCtx, stun::StunInfoCollectorTrait, PeerId},
|
||||
peers::{
|
||||
packet,
|
||||
route_trait::{Route, RouteInterfaceBox},
|
||||
},
|
||||
rpc::{NatType, StunInfo},
|
||||
};
|
||||
|
||||
use super::{packet::CtrlPacketPayload, PeerPacketFilter};
|
||||
|
||||
const SEND_ROUTE_PERIOD_SEC: u64 = 60;
|
||||
const SEND_ROUTE_FAST_REPLY_SEC: u64 = 5;
|
||||
const ROUTE_EXPIRED_SEC: u64 = 70;
|
||||
|
||||
type Version = u32;
|
||||
|
||||
#[derive(serde::Deserialize, serde::Serialize, Clone, Debug, PartialEq)]
|
||||
// Derives can be passed through to the generated type:
|
||||
pub struct SyncPeerInfo {
|
||||
// means next hop in route table.
|
||||
pub peer_id: PeerId,
|
||||
pub cost: u32,
|
||||
pub ipv4_addr: Option<Ipv4Addr>,
|
||||
pub proxy_cidrs: Vec<String>,
|
||||
pub hostname: Option<String>,
|
||||
pub udp_stun_info: i8,
|
||||
}
|
||||
|
||||
impl SyncPeerInfo {
|
||||
pub fn new_self(from_peer: PeerId, global_ctx: &ArcGlobalCtx) -> Self {
|
||||
SyncPeerInfo {
|
||||
peer_id: from_peer,
|
||||
cost: 0,
|
||||
ipv4_addr: global_ctx.get_ipv4(),
|
||||
proxy_cidrs: global_ctx
|
||||
.get_proxy_cidrs()
|
||||
.iter()
|
||||
.map(|x| x.to_string())
|
||||
.chain(global_ctx.get_vpn_portal_cidr().map(|x| x.to_string()))
|
||||
.collect(),
|
||||
hostname: global_ctx.get_hostname(),
|
||||
udp_stun_info: global_ctx
|
||||
.get_stun_info_collector()
|
||||
.get_stun_info()
|
||||
.udp_nat_type as i8,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn clone_for_route_table(&self, next_hop: PeerId, cost: u32, from: &Self) -> Self {
|
||||
SyncPeerInfo {
|
||||
peer_id: next_hop,
|
||||
cost,
|
||||
ipv4_addr: from.ipv4_addr.clone(),
|
||||
proxy_cidrs: from.proxy_cidrs.clone(),
|
||||
hostname: from.hostname.clone(),
|
||||
udp_stun_info: from.udp_stun_info,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(serde::Deserialize, serde::Serialize, Clone, Debug)]
|
||||
pub struct SyncPeer {
|
||||
pub myself: SyncPeerInfo,
|
||||
pub neighbors: Vec<SyncPeerInfo>,
|
||||
// the route table version of myself
|
||||
pub version: Version,
|
||||
// the route table version of peer that we have received last time
|
||||
pub peer_version: Option<Version>,
|
||||
// if we do not have latest peer version, need_reply is true
|
||||
pub need_reply: bool,
|
||||
}
|
||||
|
||||
impl SyncPeer {
|
||||
pub fn new(
|
||||
from_peer: PeerId,
|
||||
_to_peer: PeerId,
|
||||
neighbors: Vec<SyncPeerInfo>,
|
||||
global_ctx: ArcGlobalCtx,
|
||||
version: Version,
|
||||
peer_version: Option<Version>,
|
||||
need_reply: bool,
|
||||
) -> Self {
|
||||
SyncPeer {
|
||||
myself: SyncPeerInfo::new_self(from_peer, &global_ctx),
|
||||
neighbors,
|
||||
version,
|
||||
peer_version,
|
||||
need_reply,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct SyncPeerFromRemote {
|
||||
packet: SyncPeer,
|
||||
last_update: std::time::Instant,
|
||||
}
|
||||
|
||||
type SyncPeerFromRemoteMap = Arc<DashMap<PeerId, SyncPeerFromRemote>>;
|
||||
|
||||
#[derive(Debug)]
|
||||
struct RouteTable {
|
||||
route_info: DashMap<PeerId, SyncPeerInfo>,
|
||||
ipv4_peer_id_map: DashMap<Ipv4Addr, PeerId>,
|
||||
cidr_peer_id_map: DashMap<cidr::IpCidr, PeerId>,
|
||||
}
|
||||
|
||||
impl RouteTable {
|
||||
fn new() -> Self {
|
||||
RouteTable {
|
||||
route_info: DashMap::new(),
|
||||
ipv4_peer_id_map: DashMap::new(),
|
||||
cidr_peer_id_map: DashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn copy_from(&self, other: &Self) {
|
||||
self.route_info.clear();
|
||||
for item in other.route_info.iter() {
|
||||
let (k, v) = item.pair();
|
||||
self.route_info.insert(*k, v.clone());
|
||||
}
|
||||
|
||||
self.ipv4_peer_id_map.clear();
|
||||
for item in other.ipv4_peer_id_map.iter() {
|
||||
let (k, v) = item.pair();
|
||||
self.ipv4_peer_id_map.insert(*k, *v);
|
||||
}
|
||||
|
||||
self.cidr_peer_id_map.clear();
|
||||
for item in other.cidr_peer_id_map.iter() {
|
||||
let (k, v) = item.pair();
|
||||
self.cidr_peer_id_map.insert(*k, *v);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct RouteVersion(Arc<AtomicU32>);
|
||||
|
||||
impl RouteVersion {
|
||||
fn new() -> Self {
|
||||
// RouteVersion(Arc::new(AtomicU32::new(rand::random())))
|
||||
RouteVersion(Arc::new(AtomicU32::new(0)))
|
||||
}
|
||||
|
||||
fn get(&self) -> Version {
|
||||
self.0.load(std::sync::atomic::Ordering::Relaxed)
|
||||
}
|
||||
|
||||
fn inc(&self) {
|
||||
self.0.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
|
||||
pub struct BasicRoute {
|
||||
my_peer_id: PeerId,
|
||||
global_ctx: ArcGlobalCtx,
|
||||
interface: Arc<Mutex<Option<RouteInterfaceBox>>>,
|
||||
|
||||
route_table: Arc<RouteTable>,
|
||||
|
||||
sync_peer_from_remote: SyncPeerFromRemoteMap,
|
||||
|
||||
tasks: Mutex<JoinSet<()>>,
|
||||
|
||||
need_sync_notifier: Arc<tokio::sync::Notify>,
|
||||
|
||||
version: RouteVersion,
|
||||
myself: Arc<RwLock<SyncPeerInfo>>,
|
||||
last_send_time_map: Arc<DashMap<PeerId, (Version, Option<Version>, Instant)>>,
|
||||
}
|
||||
|
||||
impl BasicRoute {
|
||||
pub fn new(my_peer_id: PeerId, global_ctx: ArcGlobalCtx) -> Self {
|
||||
BasicRoute {
|
||||
my_peer_id,
|
||||
global_ctx: global_ctx.clone(),
|
||||
interface: Arc::new(Mutex::new(None)),
|
||||
|
||||
route_table: Arc::new(RouteTable::new()),
|
||||
|
||||
sync_peer_from_remote: Arc::new(DashMap::new()),
|
||||
tasks: Mutex::new(JoinSet::new()),
|
||||
|
||||
need_sync_notifier: Arc::new(tokio::sync::Notify::new()),
|
||||
|
||||
version: RouteVersion::new(),
|
||||
myself: Arc::new(RwLock::new(SyncPeerInfo::new_self(
|
||||
my_peer_id.into(),
|
||||
&global_ctx,
|
||||
))),
|
||||
last_send_time_map: Arc::new(DashMap::new()),
|
||||
}
|
||||
}
|
||||
|
||||
fn update_route_table(
|
||||
my_id: PeerId,
|
||||
sync_peer_reqs: SyncPeerFromRemoteMap,
|
||||
route_table: Arc<RouteTable>,
|
||||
) {
|
||||
tracing::trace!(my_id = ?my_id, route_table = ?route_table, "update route table");
|
||||
|
||||
let new_route_table = Arc::new(RouteTable::new());
|
||||
for item in sync_peer_reqs.iter() {
|
||||
Self::update_route_table_with_req(my_id, &item.value().packet, new_route_table.clone());
|
||||
}
|
||||
|
||||
route_table.copy_from(&new_route_table);
|
||||
}
|
||||
|
||||
async fn update_myself(
|
||||
my_peer_id: PeerId,
|
||||
myself: &Arc<RwLock<SyncPeerInfo>>,
|
||||
global_ctx: &ArcGlobalCtx,
|
||||
) -> bool {
|
||||
let new_myself = SyncPeerInfo::new_self(my_peer_id, &global_ctx);
|
||||
if *myself.read().await != new_myself {
|
||||
*myself.write().await = new_myself;
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
fn update_route_table_with_req(my_id: PeerId, packet: &SyncPeer, route_table: Arc<RouteTable>) {
|
||||
let peer_id = packet.myself.peer_id.clone();
|
||||
let update = |cost: u32, peer_info: &SyncPeerInfo| {
|
||||
let node_id: PeerId = peer_info.peer_id.into();
|
||||
let ret = route_table
|
||||
.route_info
|
||||
.entry(node_id.clone().into())
|
||||
.and_modify(|info| {
|
||||
if info.cost > cost {
|
||||
*info = info.clone_for_route_table(peer_id, cost, &peer_info);
|
||||
}
|
||||
})
|
||||
.or_insert(
|
||||
peer_info
|
||||
.clone()
|
||||
.clone_for_route_table(peer_id, cost, &peer_info),
|
||||
)
|
||||
.value()
|
||||
.clone();
|
||||
|
||||
if ret.cost > 6 {
|
||||
log::error!(
|
||||
"cost too large: {}, may lost connection, remove it",
|
||||
ret.cost
|
||||
);
|
||||
route_table.route_info.remove(&node_id);
|
||||
}
|
||||
|
||||
log::trace!(
|
||||
"update route info, to: {:?}, gateway: {:?}, cost: {}, peer: {:?}",
|
||||
node_id,
|
||||
peer_id,
|
||||
cost,
|
||||
&peer_info
|
||||
);
|
||||
|
||||
if let Some(ipv4) = peer_info.ipv4_addr {
|
||||
route_table
|
||||
.ipv4_peer_id_map
|
||||
.insert(ipv4.clone(), node_id.clone().into());
|
||||
}
|
||||
|
||||
for cidr in peer_info.proxy_cidrs.iter() {
|
||||
let cidr: cidr::IpCidr = cidr.parse().unwrap();
|
||||
route_table
|
||||
.cidr_peer_id_map
|
||||
.insert(cidr, node_id.clone().into());
|
||||
}
|
||||
};
|
||||
|
||||
for neighbor in packet.neighbors.iter() {
|
||||
if neighbor.peer_id == my_id {
|
||||
continue;
|
||||
}
|
||||
update(neighbor.cost + 1, &neighbor);
|
||||
log::trace!("route info: {:?}", neighbor);
|
||||
}
|
||||
|
||||
// add the sender peer to route info
|
||||
update(1, &packet.myself);
|
||||
|
||||
log::trace!("my_id: {:?}, current route table: {:?}", my_id, route_table);
|
||||
}
|
||||
|
||||
async fn send_sync_peer_request(
|
||||
interface: &RouteInterfaceBox,
|
||||
my_peer_id: PeerId,
|
||||
global_ctx: ArcGlobalCtx,
|
||||
peer_id: PeerId,
|
||||
route_table: Arc<RouteTable>,
|
||||
my_version: Version,
|
||||
peer_version: Option<Version>,
|
||||
need_reply: bool,
|
||||
) -> Result<(), Error> {
|
||||
let mut route_info_copy: Vec<SyncPeerInfo> = Vec::new();
|
||||
// copy the route info
|
||||
for item in route_table.route_info.iter() {
|
||||
let (k, v) = item.pair();
|
||||
route_info_copy.push(v.clone().clone_for_route_table(*k, v.cost, &v));
|
||||
}
|
||||
let msg = SyncPeer::new(
|
||||
my_peer_id,
|
||||
peer_id,
|
||||
route_info_copy,
|
||||
global_ctx,
|
||||
my_version,
|
||||
peer_version,
|
||||
need_reply,
|
||||
);
|
||||
// TODO: this may exceed the MTU of the tunnel
|
||||
interface
|
||||
.send_route_packet(postcard::to_allocvec(&msg).unwrap().into(), 1, peer_id)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn sync_peer_periodically(&self) {
|
||||
let route_table = self.route_table.clone();
|
||||
let global_ctx = self.global_ctx.clone();
|
||||
let my_peer_id = self.my_peer_id.clone();
|
||||
let interface = self.interface.clone();
|
||||
let notifier = self.need_sync_notifier.clone();
|
||||
let sync_peer_from_remote = self.sync_peer_from_remote.clone();
|
||||
let myself = self.myself.clone();
|
||||
let version = self.version.clone();
|
||||
let last_send_time_map = self.last_send_time_map.clone();
|
||||
self.tasks.lock().await.spawn(
|
||||
async move {
|
||||
loop {
|
||||
if Self::update_myself(my_peer_id,&myself, &global_ctx).await {
|
||||
version.inc();
|
||||
tracing::info!(
|
||||
my_id = ?my_peer_id,
|
||||
version = version.get(),
|
||||
"update route table version when myself changed"
|
||||
);
|
||||
}
|
||||
|
||||
let lockd_interface = interface.lock().await;
|
||||
let interface = lockd_interface.as_ref().unwrap();
|
||||
let last_send_time_map_new = DashMap::new();
|
||||
let peers = interface.list_peers().await;
|
||||
for peer in peers.iter() {
|
||||
let last_send_time = last_send_time_map.get(peer).map(|v| *v).unwrap_or((0, None, Instant::now() - Duration::from_secs(3600)));
|
||||
let my_version_peer_saved = sync_peer_from_remote.get(peer).and_then(|v| v.packet.peer_version);
|
||||
let peer_have_latest_version = my_version_peer_saved == Some(version.get());
|
||||
if peer_have_latest_version && last_send_time.2.elapsed().as_secs() < SEND_ROUTE_PERIOD_SEC {
|
||||
last_send_time_map_new.insert(*peer, last_send_time);
|
||||
continue;
|
||||
}
|
||||
|
||||
tracing::trace!(
|
||||
my_id = ?my_peer_id,
|
||||
dst_peer_id = ?peer,
|
||||
version = version.get(),
|
||||
?my_version_peer_saved,
|
||||
last_send_version = ?last_send_time.0,
|
||||
last_send_peer_version = ?last_send_time.1,
|
||||
last_send_elapse = ?last_send_time.2.elapsed().as_secs(),
|
||||
"need send route info"
|
||||
);
|
||||
let peer_version_we_saved = sync_peer_from_remote.get(&peer).and_then(|v| Some(v.packet.version));
|
||||
last_send_time_map_new.insert(*peer, (version.get(), peer_version_we_saved, Instant::now()));
|
||||
let ret = Self::send_sync_peer_request(
|
||||
interface,
|
||||
my_peer_id.clone(),
|
||||
global_ctx.clone(),
|
||||
*peer,
|
||||
route_table.clone(),
|
||||
version.get(),
|
||||
peer_version_we_saved,
|
||||
!peer_have_latest_version,
|
||||
)
|
||||
.await;
|
||||
|
||||
match &ret {
|
||||
Ok(_) => {
|
||||
log::trace!("send sync peer request to peer: {}", peer);
|
||||
}
|
||||
Err(Error::PeerNoConnectionError(_)) => {
|
||||
log::trace!("peer {} no connection", peer);
|
||||
}
|
||||
Err(e) => {
|
||||
log::error!(
|
||||
"send sync peer request to peer: {} error: {:?}",
|
||||
peer,
|
||||
e
|
||||
);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
last_send_time_map.clear();
|
||||
for item in last_send_time_map_new.iter() {
|
||||
let (k, v) = item.pair();
|
||||
last_send_time_map.insert(*k, *v);
|
||||
}
|
||||
|
||||
tokio::select! {
|
||||
_ = notifier.notified() => {
|
||||
log::trace!("sync peer request triggered by notifier");
|
||||
}
|
||||
_ = tokio::time::sleep(Duration::from_secs(1)) => {
|
||||
log::trace!("sync peer request triggered by timeout");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
.instrument(
|
||||
tracing::info_span!("sync_peer_periodically", my_id = ?self.my_peer_id, global_ctx = ?self.global_ctx),
|
||||
),
|
||||
);
|
||||
}
|
||||
|
||||
async fn check_expired_sync_peer_from_remote(&self) {
|
||||
let route_table = self.route_table.clone();
|
||||
let my_peer_id = self.my_peer_id.clone();
|
||||
let sync_peer_from_remote = self.sync_peer_from_remote.clone();
|
||||
let notifier = self.need_sync_notifier.clone();
|
||||
let interface = self.interface.clone();
|
||||
let version = self.version.clone();
|
||||
self.tasks.lock().await.spawn(async move {
|
||||
loop {
|
||||
let mut need_update_route = false;
|
||||
let now = std::time::Instant::now();
|
||||
let mut need_remove = Vec::new();
|
||||
let connected_peers = interface.lock().await.as_ref().unwrap().list_peers().await;
|
||||
for item in sync_peer_from_remote.iter() {
|
||||
let (k, v) = item.pair();
|
||||
if now.duration_since(v.last_update).as_secs() > ROUTE_EXPIRED_SEC
|
||||
|| !connected_peers.contains(k)
|
||||
{
|
||||
need_update_route = true;
|
||||
need_remove.insert(0, k.clone());
|
||||
}
|
||||
}
|
||||
|
||||
for k in need_remove.iter() {
|
||||
log::warn!("remove expired sync peer: {:?}", k);
|
||||
sync_peer_from_remote.remove(k);
|
||||
}
|
||||
|
||||
if need_update_route {
|
||||
Self::update_route_table(
|
||||
my_peer_id,
|
||||
sync_peer_from_remote.clone(),
|
||||
route_table.clone(),
|
||||
);
|
||||
version.inc();
|
||||
tracing::info!(
|
||||
my_id = ?my_peer_id,
|
||||
version = version.get(),
|
||||
"update route table when check expired peer"
|
||||
);
|
||||
notifier.notify_one();
|
||||
}
|
||||
|
||||
tokio::time::sleep(Duration::from_secs(1)).await;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
fn get_peer_id_for_proxy(&self, ipv4: &Ipv4Addr) -> Option<PeerId> {
|
||||
let ipv4 = std::net::IpAddr::V4(*ipv4);
|
||||
for item in self.route_table.cidr_peer_id_map.iter() {
|
||||
let (k, v) = item.pair();
|
||||
if k.contains(&ipv4) {
|
||||
return Some(*v);
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self, packet), fields(my_id = ?self.my_peer_id, ctx = ?self.global_ctx))]
|
||||
async fn handle_route_packet(&self, src_peer_id: PeerId, packet: Bytes) {
|
||||
let packet = postcard::from_bytes::<SyncPeer>(&packet).unwrap();
|
||||
let p = &packet;
|
||||
let mut updated = true;
|
||||
assert_eq!(packet.myself.peer_id, src_peer_id);
|
||||
self.sync_peer_from_remote
|
||||
.entry(packet.myself.peer_id.into())
|
||||
.and_modify(|v| {
|
||||
if v.packet.myself == p.myself && v.packet.neighbors == p.neighbors {
|
||||
updated = false;
|
||||
} else {
|
||||
v.packet = p.clone();
|
||||
}
|
||||
v.packet.version = p.version;
|
||||
v.packet.peer_version = p.peer_version;
|
||||
v.last_update = std::time::Instant::now();
|
||||
})
|
||||
.or_insert(SyncPeerFromRemote {
|
||||
packet: p.clone(),
|
||||
last_update: std::time::Instant::now(),
|
||||
});
|
||||
|
||||
if updated {
|
||||
Self::update_route_table(
|
||||
self.my_peer_id.clone(),
|
||||
self.sync_peer_from_remote.clone(),
|
||||
self.route_table.clone(),
|
||||
);
|
||||
self.version.inc();
|
||||
tracing::info!(
|
||||
my_id = ?self.my_peer_id,
|
||||
?p,
|
||||
version = self.version.get(),
|
||||
"update route table when receive route packet"
|
||||
);
|
||||
}
|
||||
|
||||
if packet.need_reply {
|
||||
self.last_send_time_map
|
||||
.entry(packet.myself.peer_id.into())
|
||||
.and_modify(|v| {
|
||||
const FAST_REPLY_DURATION: u64 =
|
||||
SEND_ROUTE_PERIOD_SEC - SEND_ROUTE_FAST_REPLY_SEC;
|
||||
if v.0 != self.version.get() || v.1 != Some(p.version) {
|
||||
v.2 = Instant::now() - Duration::from_secs(3600);
|
||||
} else if v.2.elapsed().as_secs() < FAST_REPLY_DURATION {
|
||||
// do not send same version route info too frequently
|
||||
v.2 = Instant::now() - Duration::from_secs(FAST_REPLY_DURATION);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
if updated || packet.need_reply {
|
||||
self.need_sync_notifier.notify_one();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Route for BasicRoute {
|
||||
async fn open(&self, interface: RouteInterfaceBox) -> Result<u8, ()> {
|
||||
*self.interface.lock().await = Some(interface);
|
||||
self.sync_peer_periodically().await;
|
||||
self.check_expired_sync_peer_from_remote().await;
|
||||
Ok(1)
|
||||
}
|
||||
|
||||
async fn close(&self) {}
|
||||
|
||||
async fn get_next_hop(&self, dst_peer_id: PeerId) -> Option<PeerId> {
|
||||
match self.route_table.route_info.get(&dst_peer_id) {
|
||||
Some(info) => {
|
||||
return Some(info.peer_id.clone().into());
|
||||
}
|
||||
None => {
|
||||
log::error!("no route info for dst_peer_id: {}", dst_peer_id);
|
||||
return None;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn list_routes(&self) -> Vec<crate::rpc::Route> {
|
||||
let mut routes = Vec::new();
|
||||
|
||||
let parse_route_info = |real_peer_id: PeerId, route_info: &SyncPeerInfo| {
|
||||
let mut route = crate::rpc::Route::default();
|
||||
route.ipv4_addr = if let Some(ipv4_addr) = route_info.ipv4_addr {
|
||||
ipv4_addr.to_string()
|
||||
} else {
|
||||
"".to_string()
|
||||
};
|
||||
route.peer_id = real_peer_id;
|
||||
route.next_hop_peer_id = route_info.peer_id;
|
||||
route.cost = route_info.cost as i32;
|
||||
route.proxy_cidrs = route_info.proxy_cidrs.clone();
|
||||
route.hostname = if let Some(hostname) = &route_info.hostname {
|
||||
hostname.clone()
|
||||
} else {
|
||||
"".to_string()
|
||||
};
|
||||
|
||||
let mut stun_info = StunInfo::default();
|
||||
if let Ok(udp_nat_type) = NatType::try_from(route_info.udp_stun_info as i32) {
|
||||
stun_info.set_udp_nat_type(udp_nat_type);
|
||||
}
|
||||
route.stun_info = Some(stun_info);
|
||||
|
||||
route
|
||||
};
|
||||
|
||||
self.route_table.route_info.iter().for_each(|item| {
|
||||
routes.push(parse_route_info(*item.key(), item.value()));
|
||||
});
|
||||
|
||||
routes
|
||||
}
|
||||
|
||||
async fn get_peer_id_by_ipv4(&self, ipv4_addr: &Ipv4Addr) -> Option<PeerId> {
|
||||
if let Some(peer_id) = self.route_table.ipv4_peer_id_map.get(ipv4_addr) {
|
||||
return Some(*peer_id);
|
||||
}
|
||||
|
||||
if let Some(peer_id) = self.get_peer_id_for_proxy(ipv4_addr) {
|
||||
return Some(peer_id);
|
||||
}
|
||||
|
||||
log::info!("no peer id for ipv4: {}", ipv4_addr);
|
||||
return None;
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl PeerPacketFilter for BasicRoute {
|
||||
async fn try_process_packet_from_peer(
|
||||
&self,
|
||||
packet: &packet::ArchivedPacket,
|
||||
_data: &Bytes,
|
||||
) -> Option<()> {
|
||||
if packet.packet_type == packet::PacketType::RoutePacket {
|
||||
let CtrlPacketPayload::RoutePacket(route_packet) =
|
||||
CtrlPacketPayload::from_packet(packet)
|
||||
else {
|
||||
return None;
|
||||
};
|
||||
|
||||
self.handle_route_packet(
|
||||
packet.from_peer.into(),
|
||||
route_packet.body.into_boxed_slice().into(),
|
||||
)
|
||||
.await;
|
||||
Some(())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::{
|
||||
common::{global_ctx::tests::get_mock_global_ctx, PeerId},
|
||||
connector::udp_hole_punch::tests::replace_stun_info_collector,
|
||||
peers::{
|
||||
peer_manager::{PeerManager, RouteAlgoType},
|
||||
peer_rip_route::Version,
|
||||
tests::{connect_peer_manager, wait_route_appear},
|
||||
},
|
||||
rpc::NatType,
|
||||
};
|
||||
|
||||
async fn create_mock_pmgr() -> Arc<PeerManager> {
|
||||
let (s, _r) = tokio::sync::mpsc::channel(1000);
|
||||
let peer_mgr = Arc::new(PeerManager::new(
|
||||
RouteAlgoType::Rip,
|
||||
get_mock_global_ctx(),
|
||||
s,
|
||||
));
|
||||
replace_stun_info_collector(peer_mgr.clone(), NatType::Unknown);
|
||||
peer_mgr.run().await.unwrap();
|
||||
peer_mgr
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_rip_route() {
|
||||
let peer_mgr_a = create_mock_pmgr().await;
|
||||
let peer_mgr_b = create_mock_pmgr().await;
|
||||
let peer_mgr_c = create_mock_pmgr().await;
|
||||
connect_peer_manager(peer_mgr_a.clone(), peer_mgr_b.clone()).await;
|
||||
connect_peer_manager(peer_mgr_b.clone(), peer_mgr_c.clone()).await;
|
||||
wait_route_appear(peer_mgr_a.clone(), peer_mgr_b.clone())
|
||||
.await
|
||||
.unwrap();
|
||||
wait_route_appear(peer_mgr_a.clone(), peer_mgr_c.clone())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let mgrs = vec![peer_mgr_a.clone(), peer_mgr_b.clone(), peer_mgr_c.clone()];
|
||||
|
||||
tokio::time::sleep(tokio::time::Duration::from_secs(4)).await;
|
||||
|
||||
let check_version = |version: Version, peer_id: PeerId, mgrs: &Vec<Arc<PeerManager>>| {
|
||||
for mgr in mgrs.iter() {
|
||||
tracing::warn!(
|
||||
"check version: {:?}, {:?}, {:?}, {:?}",
|
||||
version,
|
||||
peer_id,
|
||||
mgr,
|
||||
mgr.get_basic_route().sync_peer_from_remote
|
||||
);
|
||||
assert_eq!(
|
||||
version,
|
||||
mgr.get_basic_route()
|
||||
.sync_peer_from_remote
|
||||
.get(&peer_id)
|
||||
.unwrap()
|
||||
.packet
|
||||
.version,
|
||||
);
|
||||
assert_eq!(
|
||||
mgr.get_basic_route()
|
||||
.sync_peer_from_remote
|
||||
.get(&peer_id)
|
||||
.unwrap()
|
||||
.packet
|
||||
.peer_version
|
||||
.unwrap(),
|
||||
mgr.get_basic_route().version.get()
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
let check_sanity = || {
|
||||
// check peer version in other peer mgr are correct.
|
||||
check_version(
|
||||
peer_mgr_b.get_basic_route().version.get(),
|
||||
peer_mgr_b.my_peer_id(),
|
||||
&vec![peer_mgr_a.clone(), peer_mgr_c.clone()],
|
||||
);
|
||||
|
||||
check_version(
|
||||
peer_mgr_a.get_basic_route().version.get(),
|
||||
peer_mgr_a.my_peer_id(),
|
||||
&vec![peer_mgr_b.clone()],
|
||||
);
|
||||
|
||||
check_version(
|
||||
peer_mgr_c.get_basic_route().version.get(),
|
||||
peer_mgr_c.my_peer_id(),
|
||||
&vec![peer_mgr_b.clone()],
|
||||
);
|
||||
};
|
||||
|
||||
check_sanity();
|
||||
|
||||
let versions = mgrs
|
||||
.iter()
|
||||
.map(|x| x.get_basic_route().version.get())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
|
||||
|
||||
let versions2 = mgrs
|
||||
.iter()
|
||||
.map(|x| x.get_basic_route().version.get())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
assert_eq!(versions, versions2);
|
||||
check_sanity();
|
||||
|
||||
assert!(peer_mgr_a.get_basic_route().version.get() <= 3);
|
||||
assert!(peer_mgr_b.get_basic_route().version.get() <= 6);
|
||||
assert!(peer_mgr_c.get_basic_route().version.get() <= 3);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,581 @@
|
||||
use std::sync::{atomic::AtomicU32, Arc};
|
||||
|
||||
use dashmap::DashMap;
|
||||
use futures::{SinkExt, StreamExt};
|
||||
use rkyv::Deserialize;
|
||||
use tarpc::{server::Channel, transport::channel::UnboundedChannel};
|
||||
use tokio::{
|
||||
sync::mpsc::{self, UnboundedSender},
|
||||
task::JoinSet,
|
||||
};
|
||||
use tokio_util::bytes::Bytes;
|
||||
use tracing::Instrument;
|
||||
|
||||
use crate::{
|
||||
common::{error::Error, PeerId},
|
||||
peers::packet::Packet,
|
||||
};
|
||||
|
||||
use super::packet::CtrlPacketPayload;
|
||||
|
||||
type PeerRpcServiceId = u32;
|
||||
type PeerRpcTransactId = u32;
|
||||
|
||||
#[async_trait::async_trait]
|
||||
#[auto_impl::auto_impl(Arc)]
|
||||
pub trait PeerRpcManagerTransport: Send + Sync + 'static {
|
||||
fn my_peer_id(&self) -> PeerId;
|
||||
async fn send(&self, msg: Bytes, dst_peer_id: PeerId) -> Result<(), Error>;
|
||||
async fn recv(&self) -> Result<Bytes, Error>;
|
||||
}
|
||||
|
||||
type PacketSender = UnboundedSender<Packet>;
|
||||
|
||||
struct PeerRpcEndPoint {
|
||||
peer_id: PeerId,
|
||||
packet_sender: PacketSender,
|
||||
tasks: JoinSet<()>,
|
||||
}
|
||||
|
||||
type PeerRpcEndPointCreator = Box<dyn Fn(PeerId) -> PeerRpcEndPoint + Send + Sync + 'static>;
|
||||
#[derive(Hash, Eq, PartialEq, Clone)]
|
||||
struct PeerRpcClientCtxKey(PeerId, PeerRpcServiceId, PeerRpcTransactId);
|
||||
|
||||
// handle rpc request from one peer
|
||||
pub struct PeerRpcManager {
|
||||
service_map: Arc<DashMap<PeerRpcServiceId, PacketSender>>,
|
||||
tasks: JoinSet<()>,
|
||||
tspt: Arc<Box<dyn PeerRpcManagerTransport>>,
|
||||
|
||||
service_registry: Arc<DashMap<PeerRpcServiceId, PeerRpcEndPointCreator>>,
|
||||
peer_rpc_endpoints: Arc<DashMap<(PeerId, PeerRpcServiceId), PeerRpcEndPoint>>,
|
||||
|
||||
client_resp_receivers: Arc<DashMap<PeerRpcClientCtxKey, PacketSender>>,
|
||||
|
||||
transact_id: AtomicU32,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for PeerRpcManager {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("PeerRpcManager")
|
||||
.field("node_id", &self.tspt.my_peer_id())
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct TaRpcPacketInfo {
|
||||
from_peer: PeerId,
|
||||
to_peer: PeerId,
|
||||
service_id: PeerRpcServiceId,
|
||||
transact_id: PeerRpcTransactId,
|
||||
is_req: bool,
|
||||
content: Vec<u8>,
|
||||
}
|
||||
|
||||
impl PeerRpcManager {
|
||||
pub fn new(tspt: impl PeerRpcManagerTransport) -> Self {
|
||||
Self {
|
||||
service_map: Arc::new(DashMap::new()),
|
||||
tasks: JoinSet::new(),
|
||||
tspt: Arc::new(Box::new(tspt)),
|
||||
|
||||
service_registry: Arc::new(DashMap::new()),
|
||||
peer_rpc_endpoints: Arc::new(DashMap::new()),
|
||||
|
||||
client_resp_receivers: Arc::new(DashMap::new()),
|
||||
|
||||
transact_id: AtomicU32::new(0),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn run_service<S, Req>(self: &Self, service_id: PeerRpcServiceId, s: S) -> ()
|
||||
where
|
||||
S: tarpc::server::Serve<Req> + Clone + Send + Sync + 'static,
|
||||
Req: Send + 'static + serde::Serialize + for<'a> serde::Deserialize<'a>,
|
||||
S::Resp:
|
||||
Send + std::fmt::Debug + 'static + serde::Serialize + for<'a> serde::Deserialize<'a>,
|
||||
S::Fut: Send + 'static,
|
||||
{
|
||||
let tspt = self.tspt.clone();
|
||||
let creator = Box::new(move |peer_id: PeerId| {
|
||||
let mut tasks = JoinSet::new();
|
||||
let (packet_sender, mut packet_receiver) = mpsc::unbounded_channel::<Packet>();
|
||||
let (mut client_transport, server_transport) = tarpc::transport::channel::unbounded();
|
||||
let server = tarpc::server::BaseChannel::with_defaults(server_transport);
|
||||
|
||||
let my_peer_id_clone = tspt.my_peer_id();
|
||||
let peer_id_clone = peer_id.clone();
|
||||
|
||||
let o = server.execute(s.clone());
|
||||
tasks.spawn(o);
|
||||
|
||||
let tspt = tspt.clone();
|
||||
tasks.spawn(async move {
|
||||
let mut cur_req_peer_id = None;
|
||||
let mut cur_transact_id = 0;
|
||||
loop {
|
||||
tokio::select! {
|
||||
Some(resp) = client_transport.next() => {
|
||||
let Some(cur_req_peer_id) = cur_req_peer_id.take() else {
|
||||
tracing::error!("[PEER RPC MGR] cur_req_peer_id is none, ignore this resp");
|
||||
continue;
|
||||
};
|
||||
|
||||
tracing::trace!(resp = ?resp, "recv packet from client");
|
||||
if resp.is_err() {
|
||||
tracing::warn!(err = ?resp.err(),
|
||||
"[PEER RPC MGR] client_transport in server side got channel error, ignore it.");
|
||||
continue;
|
||||
}
|
||||
let resp = resp.unwrap();
|
||||
|
||||
let serialized_resp = postcard::to_allocvec(&resp);
|
||||
if serialized_resp.is_err() {
|
||||
tracing::error!(error = ?serialized_resp.err(), "serialize resp failed");
|
||||
continue;
|
||||
}
|
||||
|
||||
let msg = Packet::new_tarpc_packet(
|
||||
tspt.my_peer_id(),
|
||||
cur_req_peer_id,
|
||||
service_id,
|
||||
cur_transact_id,
|
||||
false,
|
||||
serialized_resp.unwrap(),
|
||||
);
|
||||
|
||||
if let Err(e) = tspt.send(msg.into(), peer_id).await {
|
||||
tracing::error!(error = ?e, peer_id = ?peer_id, service_id = ?service_id, "send resp to peer failed");
|
||||
}
|
||||
}
|
||||
Some(packet) = packet_receiver.recv() => {
|
||||
let info = Self::parse_rpc_packet(&packet);
|
||||
if let Err(e) = info {
|
||||
tracing::error!(error = ?e, packet = ?packet, "parse rpc packet failed");
|
||||
continue;
|
||||
}
|
||||
let info = info.unwrap();
|
||||
|
||||
if info.from_peer != peer_id {
|
||||
tracing::warn!("recv packet from peer, but peer_id not match, ignore it");
|
||||
continue;
|
||||
}
|
||||
|
||||
if cur_req_peer_id.is_some() {
|
||||
tracing::warn!("cur_req_peer_id is not none, ignore this packet");
|
||||
continue;
|
||||
}
|
||||
|
||||
assert_eq!(info.service_id, service_id);
|
||||
cur_req_peer_id = Some(packet.from_peer.clone().into());
|
||||
cur_transact_id = info.transact_id;
|
||||
|
||||
tracing::trace!("recv packet from peer, packet: {:?}", packet);
|
||||
|
||||
let decoded_ret = postcard::from_bytes(&info.content.as_slice());
|
||||
if let Err(e) = decoded_ret {
|
||||
tracing::error!(error = ?e, "decode rpc packet failed");
|
||||
continue;
|
||||
}
|
||||
let decoded: tarpc::ClientMessage<Req> = decoded_ret.unwrap();
|
||||
|
||||
if let Err(e) = client_transport.send(decoded).await {
|
||||
tracing::error!(error = ?e, "send to req to client transport failed");
|
||||
}
|
||||
}
|
||||
else => {
|
||||
tracing::warn!("[PEER RPC MGR] service runner destroy, peer_id: {}, service_id: {}", peer_id, service_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
}.instrument(tracing::info_span!("service_runner", my_id = ?my_peer_id_clone, peer_id = ?peer_id_clone, service_id = ?service_id)));
|
||||
|
||||
tracing::info!(
|
||||
"[PEER RPC MGR] create new service endpoint for peer {}, service {}",
|
||||
peer_id,
|
||||
service_id
|
||||
);
|
||||
|
||||
return PeerRpcEndPoint {
|
||||
peer_id,
|
||||
packet_sender,
|
||||
tasks,
|
||||
};
|
||||
// let resp = client_transport.next().await;
|
||||
});
|
||||
|
||||
if let Some(_) = self.service_registry.insert(service_id, creator) {
|
||||
panic!(
|
||||
"[PEER RPC MGR] service {} is already registered",
|
||||
service_id
|
||||
);
|
||||
}
|
||||
|
||||
log::info!(
|
||||
"[PEER RPC MGR] register service {} succeed, my_node_id {}",
|
||||
service_id,
|
||||
self.tspt.my_peer_id()
|
||||
)
|
||||
}
|
||||
|
||||
fn parse_rpc_packet(packet: &Packet) -> Result<TaRpcPacketInfo, Error> {
|
||||
let ctrl_packet_payload = CtrlPacketPayload::from_packet2(&packet);
|
||||
match &ctrl_packet_payload {
|
||||
CtrlPacketPayload::TaRpc(id, tid, is_req, body) => Ok(TaRpcPacketInfo {
|
||||
from_peer: packet.from_peer.into(),
|
||||
to_peer: packet.to_peer.into(),
|
||||
service_id: *id,
|
||||
transact_id: *tid,
|
||||
is_req: *is_req,
|
||||
content: body.clone(),
|
||||
}),
|
||||
_ => Err(Error::ShellCommandError("invalid packet".to_owned())),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn run(&self) {
|
||||
let tspt = self.tspt.clone();
|
||||
let service_registry = self.service_registry.clone();
|
||||
let peer_rpc_endpoints = self.peer_rpc_endpoints.clone();
|
||||
let client_resp_receivers = self.client_resp_receivers.clone();
|
||||
tokio::spawn(async move {
|
||||
loop {
|
||||
let Ok(o) = tspt.recv().await else {
|
||||
tracing::warn!("peer rpc transport read aborted, exiting");
|
||||
break;
|
||||
};
|
||||
let packet = Packet::decode(&o);
|
||||
let packet: Packet = packet.deserialize(&mut rkyv::Infallible).unwrap();
|
||||
let info = Self::parse_rpc_packet(&packet).unwrap();
|
||||
|
||||
if info.is_req {
|
||||
if !service_registry.contains_key(&info.service_id) {
|
||||
log::warn!(
|
||||
"service {} not found, my_node_id: {}",
|
||||
info.service_id,
|
||||
tspt.my_peer_id()
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
let endpoint = peer_rpc_endpoints
|
||||
.entry((info.from_peer, info.service_id))
|
||||
.or_insert_with(|| {
|
||||
service_registry.get(&info.service_id).unwrap()(info.from_peer)
|
||||
});
|
||||
|
||||
endpoint.packet_sender.send(packet).unwrap();
|
||||
} else {
|
||||
if let Some(a) = client_resp_receivers.get(&PeerRpcClientCtxKey(
|
||||
info.from_peer,
|
||||
info.service_id,
|
||||
info.transact_id,
|
||||
)) {
|
||||
log::trace!("recv resp: {:?}", packet);
|
||||
if let Err(e) = a.send(packet) {
|
||||
tracing::error!(error = ?e, "send resp to client failed");
|
||||
}
|
||||
} else {
|
||||
log::warn!("client resp receiver not found, info: {:?}", info);
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(f))]
|
||||
pub async fn do_client_rpc_scoped<CM, Req, RpcRet, Fut>(
|
||||
&self,
|
||||
service_id: PeerRpcServiceId,
|
||||
dst_peer_id: PeerId,
|
||||
f: impl FnOnce(UnboundedChannel<CM, Req>) -> Fut,
|
||||
) -> RpcRet
|
||||
where
|
||||
CM: serde::Serialize + for<'a> serde::Deserialize<'a> + Send + Sync + 'static,
|
||||
Req: serde::Serialize + for<'a> serde::Deserialize<'a> + Send + Sync + 'static,
|
||||
Fut: std::future::Future<Output = RpcRet>,
|
||||
{
|
||||
let mut tasks = JoinSet::new();
|
||||
let (packet_sender, mut packet_receiver) = mpsc::unbounded_channel::<Packet>();
|
||||
let (client_transport, server_transport) =
|
||||
tarpc::transport::channel::unbounded::<CM, Req>();
|
||||
|
||||
let (mut server_s, mut server_r) = server_transport.split();
|
||||
|
||||
let transact_id = self
|
||||
.transact_id
|
||||
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
|
||||
|
||||
let tspt = self.tspt.clone();
|
||||
tasks.spawn(async move {
|
||||
while let Some(a) = server_r.next().await {
|
||||
if a.is_err() {
|
||||
tracing::error!(error = ?a.err(), "channel error");
|
||||
continue;
|
||||
}
|
||||
|
||||
let a = postcard::to_allocvec(&a.unwrap());
|
||||
if a.is_err() {
|
||||
tracing::error!(error = ?a.err(), "bincode serialize failed");
|
||||
continue;
|
||||
}
|
||||
|
||||
let a = Packet::new_tarpc_packet(
|
||||
tspt.my_peer_id(),
|
||||
dst_peer_id,
|
||||
service_id,
|
||||
transact_id,
|
||||
true,
|
||||
a.unwrap(),
|
||||
);
|
||||
|
||||
if let Err(e) = tspt.send(a.into(), dst_peer_id).await {
|
||||
tracing::error!(error = ?e, dst_peer_id = ?dst_peer_id, "send to peer failed");
|
||||
}
|
||||
}
|
||||
|
||||
tracing::warn!("[PEER RPC MGR] server trasport read aborted");
|
||||
});
|
||||
|
||||
tasks.spawn(async move {
|
||||
while let Some(packet) = packet_receiver.recv().await {
|
||||
tracing::trace!("tunnel recv: {:?}", packet);
|
||||
|
||||
let info = PeerRpcManager::parse_rpc_packet(&packet);
|
||||
if let Err(e) = info {
|
||||
tracing::error!(error = ?e, "parse rpc packet failed");
|
||||
continue;
|
||||
}
|
||||
|
||||
let decoded = postcard::from_bytes(&info.unwrap().content.as_slice());
|
||||
if let Err(e) = decoded {
|
||||
tracing::error!(error = ?e, "decode rpc packet failed");
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Err(e) = server_s.send(decoded.unwrap()).await {
|
||||
tracing::error!(error = ?e, "send to rpc server channel failed");
|
||||
}
|
||||
}
|
||||
|
||||
tracing::warn!("[PEER RPC MGR] server packet read aborted");
|
||||
});
|
||||
|
||||
let key = PeerRpcClientCtxKey(dst_peer_id, service_id, transact_id);
|
||||
let _insert_ret = self
|
||||
.client_resp_receivers
|
||||
.insert(key.clone(), packet_sender);
|
||||
|
||||
let ret = f(client_transport).await;
|
||||
|
||||
self.client_resp_receivers.remove(&key);
|
||||
|
||||
ret
|
||||
}
|
||||
|
||||
pub fn my_peer_id(&self) -> PeerId {
|
||||
self.tspt.my_peer_id()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use futures::{SinkExt, StreamExt};
|
||||
use tokio_util::bytes::Bytes;
|
||||
|
||||
use crate::{
|
||||
common::{error::Error, new_peer_id, PeerId},
|
||||
peers::{
|
||||
peer_rpc::PeerRpcManager,
|
||||
tests::{connect_peer_manager, create_mock_peer_manager, wait_route_appear},
|
||||
},
|
||||
tunnels::{self, ring_tunnel::create_ring_tunnel_pair},
|
||||
};
|
||||
|
||||
use super::PeerRpcManagerTransport;
|
||||
|
||||
#[tarpc::service]
|
||||
pub trait TestRpcService {
|
||||
async fn hello(s: String) -> String;
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct MockService {
|
||||
prefix: String,
|
||||
}
|
||||
|
||||
#[tarpc::server]
|
||||
impl TestRpcService for MockService {
|
||||
async fn hello(self, _: tarpc::context::Context, s: String) -> String {
|
||||
format!("{} {}", self.prefix, s)
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn peer_rpc_basic_test() {
|
||||
struct MockTransport {
|
||||
tunnel: Box<dyn tunnels::Tunnel>,
|
||||
my_peer_id: PeerId,
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl PeerRpcManagerTransport for MockTransport {
|
||||
fn my_peer_id(&self) -> PeerId {
|
||||
self.my_peer_id
|
||||
}
|
||||
async fn send(&self, msg: Bytes, _dst_peer_id: PeerId) -> Result<(), Error> {
|
||||
println!("rpc mgr send: {:?}", msg);
|
||||
self.tunnel.pin_sink().send(msg).await.unwrap();
|
||||
Ok(())
|
||||
}
|
||||
async fn recv(&self) -> Result<Bytes, Error> {
|
||||
let ret = self.tunnel.pin_stream().next().await.unwrap();
|
||||
println!("rpc mgr recv: {:?}", ret);
|
||||
return ret.map(|v| v.freeze()).map_err(|_| Error::Unknown);
|
||||
}
|
||||
}
|
||||
|
||||
let (ct, st) = create_ring_tunnel_pair();
|
||||
|
||||
let server_rpc_mgr = PeerRpcManager::new(MockTransport {
|
||||
tunnel: st,
|
||||
my_peer_id: new_peer_id(),
|
||||
});
|
||||
server_rpc_mgr.run();
|
||||
let s = MockService {
|
||||
prefix: "hello".to_owned(),
|
||||
};
|
||||
server_rpc_mgr.run_service(1, s.serve());
|
||||
|
||||
let client_rpc_mgr = PeerRpcManager::new(MockTransport {
|
||||
tunnel: ct,
|
||||
my_peer_id: new_peer_id(),
|
||||
});
|
||||
client_rpc_mgr.run();
|
||||
|
||||
let ret = client_rpc_mgr
|
||||
.do_client_rpc_scoped(1, server_rpc_mgr.my_peer_id(), |c| async {
|
||||
let c = TestRpcServiceClient::new(tarpc::client::Config::default(), c).spawn();
|
||||
let ret = c.hello(tarpc::context::current(), "abc".to_owned()).await;
|
||||
ret
|
||||
})
|
||||
.await;
|
||||
|
||||
println!("ret: {:?}", ret);
|
||||
assert_eq!(ret.unwrap(), "hello abc");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_rpc_with_peer_manager() {
|
||||
let peer_mgr_a = create_mock_peer_manager().await;
|
||||
let peer_mgr_b = create_mock_peer_manager().await;
|
||||
let peer_mgr_c = create_mock_peer_manager().await;
|
||||
connect_peer_manager(peer_mgr_a.clone(), peer_mgr_b.clone()).await;
|
||||
connect_peer_manager(peer_mgr_b.clone(), peer_mgr_c.clone()).await;
|
||||
wait_route_appear(peer_mgr_a.clone(), peer_mgr_b.clone())
|
||||
.await
|
||||
.unwrap();
|
||||
wait_route_appear(peer_mgr_a.clone(), peer_mgr_c.clone())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(peer_mgr_a.get_peer_map().list_peers().await.len(), 1);
|
||||
assert_eq!(
|
||||
peer_mgr_a.get_peer_map().list_peers().await[0],
|
||||
peer_mgr_b.my_peer_id()
|
||||
);
|
||||
|
||||
assert_eq!(peer_mgr_c.get_peer_map().list_peers().await.len(), 1);
|
||||
assert_eq!(
|
||||
peer_mgr_c.get_peer_map().list_peers().await[0],
|
||||
peer_mgr_b.my_peer_id()
|
||||
);
|
||||
|
||||
let s = MockService {
|
||||
prefix: "hello".to_owned(),
|
||||
};
|
||||
peer_mgr_b.get_peer_rpc_mgr().run_service(1, s.serve());
|
||||
|
||||
let ip_list = peer_mgr_a
|
||||
.get_peer_rpc_mgr()
|
||||
.do_client_rpc_scoped(1, peer_mgr_b.my_peer_id(), |c| async {
|
||||
let c = TestRpcServiceClient::new(tarpc::client::Config::default(), c).spawn();
|
||||
let ret = c.hello(tarpc::context::current(), "abc".to_owned()).await;
|
||||
ret
|
||||
})
|
||||
.await;
|
||||
println!("ip_list: {:?}", ip_list);
|
||||
assert_eq!(ip_list.as_ref().unwrap(), "hello abc");
|
||||
|
||||
// call again
|
||||
let ip_list = peer_mgr_a
|
||||
.get_peer_rpc_mgr()
|
||||
.do_client_rpc_scoped(1, peer_mgr_b.my_peer_id(), |c| async {
|
||||
let c = TestRpcServiceClient::new(tarpc::client::Config::default(), c).spawn();
|
||||
let ret = c.hello(tarpc::context::current(), "abcd".to_owned()).await;
|
||||
ret
|
||||
})
|
||||
.await;
|
||||
println!("ip_list: {:?}", ip_list);
|
||||
assert_eq!(ip_list.as_ref().unwrap(), "hello abcd");
|
||||
|
||||
let ip_list = peer_mgr_c
|
||||
.get_peer_rpc_mgr()
|
||||
.do_client_rpc_scoped(1, peer_mgr_b.my_peer_id(), |c| async {
|
||||
let c = TestRpcServiceClient::new(tarpc::client::Config::default(), c).spawn();
|
||||
let ret = c.hello(tarpc::context::current(), "bcd".to_owned()).await;
|
||||
ret
|
||||
})
|
||||
.await;
|
||||
println!("ip_list: {:?}", ip_list);
|
||||
assert_eq!(ip_list.as_ref().unwrap(), "hello bcd");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_multi_service_with_peer_manager() {
|
||||
let peer_mgr_a = create_mock_peer_manager().await;
|
||||
let peer_mgr_b = create_mock_peer_manager().await;
|
||||
connect_peer_manager(peer_mgr_a.clone(), peer_mgr_b.clone()).await;
|
||||
wait_route_appear(peer_mgr_a.clone(), peer_mgr_b.clone())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(peer_mgr_a.get_peer_map().list_peers().await.len(), 1);
|
||||
assert_eq!(
|
||||
peer_mgr_a.get_peer_map().list_peers().await[0],
|
||||
peer_mgr_b.my_peer_id()
|
||||
);
|
||||
|
||||
let s = MockService {
|
||||
prefix: "hello_a".to_owned(),
|
||||
};
|
||||
peer_mgr_b.get_peer_rpc_mgr().run_service(1, s.serve());
|
||||
let b = MockService {
|
||||
prefix: "hello_b".to_owned(),
|
||||
};
|
||||
peer_mgr_b.get_peer_rpc_mgr().run_service(2, b.serve());
|
||||
|
||||
let ip_list = peer_mgr_a
|
||||
.get_peer_rpc_mgr()
|
||||
.do_client_rpc_scoped(1, peer_mgr_b.my_peer_id(), |c| async {
|
||||
let c = TestRpcServiceClient::new(tarpc::client::Config::default(), c).spawn();
|
||||
let ret = c.hello(tarpc::context::current(), "abc".to_owned()).await;
|
||||
ret
|
||||
})
|
||||
.await;
|
||||
|
||||
assert_eq!(ip_list.as_ref().unwrap(), "hello_a abc");
|
||||
|
||||
let ip_list = peer_mgr_a
|
||||
.get_peer_rpc_mgr()
|
||||
.do_client_rpc_scoped(2, peer_mgr_b.my_peer_id(), |c| async {
|
||||
let c = TestRpcServiceClient::new(tarpc::client::Config::default(), c).spawn();
|
||||
let ret = c.hello(tarpc::context::current(), "abc".to_owned()).await;
|
||||
ret
|
||||
})
|
||||
.await;
|
||||
|
||||
assert_eq!(ip_list.as_ref().unwrap(), "hello_b abc");
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,36 @@
|
||||
use std::{net::Ipv4Addr, sync::Arc};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use tokio_util::bytes::Bytes;
|
||||
|
||||
use crate::common::{error::Error, PeerId};
|
||||
|
||||
#[async_trait]
|
||||
pub trait RouteInterface {
|
||||
async fn list_peers(&self) -> Vec<PeerId>;
|
||||
async fn send_route_packet(
|
||||
&self,
|
||||
msg: Bytes,
|
||||
route_id: u8,
|
||||
dst_peer_id: PeerId,
|
||||
) -> Result<(), Error>;
|
||||
fn my_peer_id(&self) -> PeerId;
|
||||
}
|
||||
|
||||
pub type RouteInterfaceBox = Box<dyn RouteInterface + Send + Sync>;
|
||||
|
||||
#[async_trait]
|
||||
#[auto_impl::auto_impl(Box, Arc)]
|
||||
pub trait Route {
|
||||
async fn open(&self, interface: RouteInterfaceBox) -> Result<u8, ()>;
|
||||
async fn close(&self);
|
||||
|
||||
async fn get_next_hop(&self, peer_id: PeerId) -> Option<PeerId>;
|
||||
async fn list_routes(&self) -> Vec<crate::rpc::Route>;
|
||||
|
||||
async fn get_peer_id_by_ipv4(&self, _ipv4: &Ipv4Addr) -> Option<PeerId> {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
pub type ArcRoute = Arc<Box<dyn Route + Send + Sync>>;
|
||||
@@ -0,0 +1,63 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::rpc::{
|
||||
cli::PeerInfo,
|
||||
peer_manage_rpc_server::PeerManageRpc,
|
||||
{ListPeerRequest, ListPeerResponse, ListRouteRequest, ListRouteResponse},
|
||||
};
|
||||
use tonic::{Request, Response, Status};
|
||||
|
||||
use super::peer_manager::PeerManager;
|
||||
|
||||
pub struct PeerManagerRpcService {
|
||||
peer_manager: Arc<PeerManager>,
|
||||
}
|
||||
|
||||
impl PeerManagerRpcService {
|
||||
pub fn new(peer_manager: Arc<PeerManager>) -> Self {
|
||||
PeerManagerRpcService { peer_manager }
|
||||
}
|
||||
|
||||
pub async fn list_peers(&self) -> Vec<PeerInfo> {
|
||||
let peers = self.peer_manager.get_peer_map().list_peers().await;
|
||||
let mut peer_infos = Vec::new();
|
||||
for peer in peers {
|
||||
let mut peer_info = PeerInfo::default();
|
||||
peer_info.peer_id = peer;
|
||||
|
||||
if let Some(conns) = self.peer_manager.get_peer_map().list_peer_conns(peer).await {
|
||||
peer_info.conns = conns;
|
||||
}
|
||||
|
||||
peer_infos.push(peer_info);
|
||||
}
|
||||
|
||||
peer_infos
|
||||
}
|
||||
}
|
||||
|
||||
#[tonic::async_trait]
|
||||
impl PeerManageRpc for PeerManagerRpcService {
|
||||
async fn list_peer(
|
||||
&self,
|
||||
_request: Request<ListPeerRequest>, // Accept request of type HelloRequest
|
||||
) -> Result<Response<ListPeerResponse>, Status> {
|
||||
let mut reply = ListPeerResponse::default();
|
||||
|
||||
let peers = self.list_peers().await;
|
||||
for peer in peers {
|
||||
reply.peer_infos.push(peer);
|
||||
}
|
||||
|
||||
Ok(Response::new(reply))
|
||||
}
|
||||
|
||||
async fn list_route(
|
||||
&self,
|
||||
_request: Request<ListRouteRequest>, // Accept request of type HelloRequest
|
||||
) -> Result<Response<ListRouteResponse>, Status> {
|
||||
let mut reply = ListRouteResponse::default();
|
||||
reply.routes = self.peer_manager.list_routes().await;
|
||||
Ok(Response::new(reply))
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,75 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use futures::Future;
|
||||
|
||||
use crate::{
|
||||
common::{error::Error, global_ctx::tests::get_mock_global_ctx, PeerId},
|
||||
tunnels::ring_tunnel::create_ring_tunnel_pair,
|
||||
};
|
||||
|
||||
use super::peer_manager::{PeerManager, RouteAlgoType};
|
||||
|
||||
pub async fn create_mock_peer_manager() -> Arc<PeerManager> {
|
||||
let (s, _r) = tokio::sync::mpsc::channel(1000);
|
||||
let peer_mgr = Arc::new(PeerManager::new(
|
||||
RouteAlgoType::Ospf,
|
||||
get_mock_global_ctx(),
|
||||
s,
|
||||
));
|
||||
peer_mgr.run().await.unwrap();
|
||||
peer_mgr
|
||||
}
|
||||
|
||||
pub async fn connect_peer_manager(client: Arc<PeerManager>, server: Arc<PeerManager>) {
|
||||
let (a_ring, b_ring) = create_ring_tunnel_pair();
|
||||
let a_mgr_copy = client.clone();
|
||||
tokio::spawn(async move {
|
||||
a_mgr_copy.add_client_tunnel(a_ring).await.unwrap();
|
||||
});
|
||||
let b_mgr_copy = server.clone();
|
||||
tokio::spawn(async move {
|
||||
b_mgr_copy.add_tunnel_as_server(b_ring).await.unwrap();
|
||||
});
|
||||
}
|
||||
|
||||
pub async fn wait_route_appear_with_cost(
|
||||
peer_mgr: Arc<PeerManager>,
|
||||
node_id: PeerId,
|
||||
cost: Option<i32>,
|
||||
) -> Result<(), Error> {
|
||||
let now = std::time::Instant::now();
|
||||
while now.elapsed().as_secs() < 5 {
|
||||
let route = peer_mgr.list_routes().await;
|
||||
if route
|
||||
.iter()
|
||||
.any(|r| r.peer_id == node_id && (cost.is_none() || r.cost == cost.unwrap()))
|
||||
{
|
||||
return Ok(());
|
||||
}
|
||||
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
|
||||
}
|
||||
return Err(Error::NotFound);
|
||||
}
|
||||
|
||||
pub async fn wait_route_appear(
|
||||
peer_mgr: Arc<PeerManager>,
|
||||
target_peer: Arc<PeerManager>,
|
||||
) -> Result<(), Error> {
|
||||
wait_route_appear_with_cost(peer_mgr.clone(), target_peer.my_peer_id(), None).await?;
|
||||
wait_route_appear_with_cost(target_peer, peer_mgr.my_peer_id(), None).await
|
||||
}
|
||||
|
||||
pub async fn wait_for_condition<F, FRet>(mut condition: F, timeout: std::time::Duration) -> ()
|
||||
where
|
||||
F: FnMut() -> FRet + Send,
|
||||
FRet: Future<Output = bool>,
|
||||
{
|
||||
let now = std::time::Instant::now();
|
||||
while now.elapsed() < timeout {
|
||||
if condition().await {
|
||||
return;
|
||||
}
|
||||
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
|
||||
}
|
||||
assert!(condition().await, "Timeout")
|
||||
}
|
||||
@@ -0,0 +1 @@
|
||||
tonic::include_proto!("cli"); // The string specified here must match the proto package name
|
||||
@@ -0,0 +1,4 @@
|
||||
pub mod cli;
|
||||
pub use cli::*;
|
||||
|
||||
pub mod peer;
|
||||
@@ -0,0 +1,22 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Clone, PartialEq, Debug, Serialize, Deserialize)]
|
||||
pub struct GetIpListResponse {
|
||||
pub public_ipv4: String,
|
||||
pub interface_ipv4s: Vec<String>,
|
||||
pub public_ipv6: String,
|
||||
pub interface_ipv6s: Vec<String>,
|
||||
pub listeners: Vec<url::Url>,
|
||||
}
|
||||
|
||||
impl GetIpListResponse {
|
||||
pub fn new() -> Self {
|
||||
GetIpListResponse {
|
||||
public_ipv4: "".to_string(),
|
||||
interface_ipv4s: vec![],
|
||||
public_ipv6: "".to_string(),
|
||||
interface_ipv6s: vec![],
|
||||
listeners: vec![],
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,181 @@
|
||||
use crate::common::PeerId;
|
||||
|
||||
mod three_node;
|
||||
|
||||
pub fn get_guest_veth_name(net_ns: &str) -> &str {
|
||||
Box::leak(format!("veth_{}_g", net_ns).into_boxed_str())
|
||||
}
|
||||
|
||||
pub fn get_host_veth_name(net_ns: &str) -> &str {
|
||||
Box::leak(format!("veth_{}_h", net_ns).into_boxed_str())
|
||||
}
|
||||
|
||||
pub fn del_netns(name: &str) {
|
||||
// del veth host
|
||||
let _ = std::process::Command::new("ip")
|
||||
.args(&["link", "del", get_host_veth_name(name)])
|
||||
.output();
|
||||
|
||||
let _ = std::process::Command::new("ip")
|
||||
.args(&["netns", "del", name])
|
||||
.output();
|
||||
}
|
||||
|
||||
pub fn create_netns(name: &str, ipv4: &str) {
|
||||
// create netns
|
||||
let _ = std::process::Command::new("ip")
|
||||
.args(&["netns", "add", name])
|
||||
.output()
|
||||
.unwrap();
|
||||
|
||||
// set lo up
|
||||
let _ = std::process::Command::new("ip")
|
||||
.args(&["netns", "exec", name, "ip", "link", "set", "lo", "up"])
|
||||
.output()
|
||||
.unwrap();
|
||||
|
||||
let _ = std::process::Command::new("ip")
|
||||
.args(&[
|
||||
"link",
|
||||
"add",
|
||||
get_host_veth_name(name),
|
||||
"type",
|
||||
"veth",
|
||||
"peer",
|
||||
"name",
|
||||
get_guest_veth_name(name),
|
||||
])
|
||||
.output()
|
||||
.unwrap();
|
||||
|
||||
let _ = std::process::Command::new("ip")
|
||||
.args(&["link", "set", get_guest_veth_name(name), "netns", name])
|
||||
.output()
|
||||
.unwrap();
|
||||
|
||||
let _ = std::process::Command::new("ip")
|
||||
.args(&[
|
||||
"netns",
|
||||
"exec",
|
||||
name,
|
||||
"ip",
|
||||
"link",
|
||||
"set",
|
||||
get_guest_veth_name(name),
|
||||
"up",
|
||||
])
|
||||
.output()
|
||||
.unwrap();
|
||||
|
||||
let _ = std::process::Command::new("ip")
|
||||
.args(&["link", "set", get_host_veth_name(name), "up"])
|
||||
.output()
|
||||
.unwrap();
|
||||
|
||||
let _ = std::process::Command::new("ip")
|
||||
.args(&[
|
||||
"netns",
|
||||
"exec",
|
||||
name,
|
||||
"ip",
|
||||
"addr",
|
||||
"add",
|
||||
ipv4,
|
||||
"dev",
|
||||
get_guest_veth_name(name),
|
||||
])
|
||||
.output()
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
pub fn prepare_bridge(name: &str) {
|
||||
// del bridge with brctl
|
||||
let _ = std::process::Command::new("brctl")
|
||||
.args(&["delbr", name])
|
||||
.output();
|
||||
|
||||
// create new br
|
||||
let _ = std::process::Command::new("brctl")
|
||||
.args(&["addbr", name])
|
||||
.output();
|
||||
}
|
||||
|
||||
pub fn add_ns_to_bridge(br_name: &str, ns_name: &str) {
|
||||
// use brctl to add ns to bridge
|
||||
let _ = std::process::Command::new("brctl")
|
||||
.args(&["addif", br_name, get_host_veth_name(ns_name)])
|
||||
.output()
|
||||
.unwrap();
|
||||
|
||||
// set bridge up
|
||||
let _ = std::process::Command::new("ip")
|
||||
.args(&["link", "set", br_name, "up"])
|
||||
.output()
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
pub fn enable_log() {
|
||||
let filter = tracing_subscriber::EnvFilter::builder()
|
||||
.with_default_directive(tracing::level_filters::LevelFilter::INFO.into())
|
||||
.from_env()
|
||||
.unwrap()
|
||||
.add_directive("tarpc=error".parse().unwrap());
|
||||
tracing_subscriber::fmt::fmt()
|
||||
.pretty()
|
||||
.with_env_filter(filter)
|
||||
.init();
|
||||
}
|
||||
|
||||
fn check_route(ipv4: &str, dst_peer_id: PeerId, routes: Vec<crate::rpc::Route>) {
|
||||
let mut found = false;
|
||||
for r in routes.iter() {
|
||||
if r.ipv4_addr == ipv4.to_string() {
|
||||
found = true;
|
||||
assert_eq!(r.peer_id, dst_peer_id, "{:?}", routes);
|
||||
}
|
||||
}
|
||||
assert!(
|
||||
found,
|
||||
"routes: {:?}, dst_peer_id: {}, ipv4: {}",
|
||||
routes, dst_peer_id, ipv4
|
||||
);
|
||||
}
|
||||
|
||||
async fn wait_proxy_route_appear(
|
||||
mgr: &std::sync::Arc<crate::peers::peer_manager::PeerManager>,
|
||||
ipv4: &str,
|
||||
dst_peer_id: PeerId,
|
||||
proxy_cidr: &str,
|
||||
) {
|
||||
let now = std::time::Instant::now();
|
||||
loop {
|
||||
for r in mgr.list_routes().await.iter() {
|
||||
let r = r;
|
||||
if r.proxy_cidrs.contains(&proxy_cidr.to_owned()) {
|
||||
assert_eq!(r.peer_id, dst_peer_id);
|
||||
assert_eq!(r.ipv4_addr, ipv4);
|
||||
return;
|
||||
}
|
||||
}
|
||||
if now.elapsed().as_secs() > 5 {
|
||||
panic!("wait proxy route appear timeout");
|
||||
}
|
||||
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
|
||||
}
|
||||
}
|
||||
|
||||
fn set_link_status(net_ns: &str, up: bool) {
|
||||
let _ = std::process::Command::new("ip")
|
||||
.args(&[
|
||||
"netns",
|
||||
"exec",
|
||||
net_ns,
|
||||
"ip",
|
||||
"link",
|
||||
"set",
|
||||
get_guest_veth_name(net_ns),
|
||||
if up { "up" } else { "down" },
|
||||
])
|
||||
.output()
|
||||
.unwrap();
|
||||
}
|
||||
@@ -0,0 +1,409 @@
|
||||
use std::{
|
||||
sync::{atomic::AtomicU32, Arc},
|
||||
time::Duration,
|
||||
};
|
||||
|
||||
use tokio::{net::UdpSocket, task::JoinSet};
|
||||
|
||||
use super::*;
|
||||
|
||||
use crate::{
|
||||
common::{
|
||||
config::{ConfigLoader, NetworkIdentity, TomlConfigLoader},
|
||||
netns::{NetNS, ROOT_NETNS_NAME},
|
||||
},
|
||||
instance::instance::Instance,
|
||||
peers::tests::wait_for_condition,
|
||||
tunnels::{
|
||||
common::tests::_tunnel_pingpong_netns,
|
||||
ring_tunnel::RingTunnelConnector,
|
||||
tcp_tunnel::{TcpTunnelConnector, TcpTunnelListener},
|
||||
udp_tunnel::{UdpTunnelConnector, UdpTunnelListener},
|
||||
wireguard::{WgConfig, WgTunnelConnector},
|
||||
},
|
||||
};
|
||||
|
||||
pub fn prepare_linux_namespaces() {
|
||||
del_netns("net_a");
|
||||
del_netns("net_b");
|
||||
del_netns("net_c");
|
||||
del_netns("net_d");
|
||||
|
||||
create_netns("net_a", "10.1.1.1/24");
|
||||
create_netns("net_b", "10.1.1.2/24");
|
||||
create_netns("net_c", "10.1.2.3/24");
|
||||
create_netns("net_d", "10.1.2.4/24");
|
||||
|
||||
prepare_bridge("br_a");
|
||||
prepare_bridge("br_b");
|
||||
|
||||
add_ns_to_bridge("br_a", "net_a");
|
||||
add_ns_to_bridge("br_a", "net_b");
|
||||
add_ns_to_bridge("br_b", "net_c");
|
||||
add_ns_to_bridge("br_b", "net_d");
|
||||
}
|
||||
|
||||
pub fn get_inst_config(inst_name: &str, ns: Option<&str>, ipv4: &str) -> TomlConfigLoader {
|
||||
let config = TomlConfigLoader::default();
|
||||
config.set_inst_name(inst_name.to_owned());
|
||||
config.set_netns(ns.map(|s| s.to_owned()));
|
||||
config.set_ipv4(ipv4.parse().unwrap());
|
||||
config.set_listeners(vec![
|
||||
"tcp://0.0.0.0:11010".parse().unwrap(),
|
||||
"udp://0.0.0.0:11010".parse().unwrap(),
|
||||
"wg://0.0.0.0:11011".parse().unwrap(),
|
||||
]);
|
||||
config
|
||||
}
|
||||
|
||||
pub async fn init_three_node(proto: &str) -> Vec<Instance> {
|
||||
log::set_max_level(log::LevelFilter::Info);
|
||||
prepare_linux_namespaces();
|
||||
|
||||
let mut inst1 = Instance::new(get_inst_config("inst1", Some("net_a"), "10.144.144.1"));
|
||||
let mut inst2 = Instance::new(get_inst_config("inst2", Some("net_b"), "10.144.144.2"));
|
||||
let mut inst3 = Instance::new(get_inst_config("inst3", Some("net_c"), "10.144.144.3"));
|
||||
|
||||
inst1.run().await.unwrap();
|
||||
inst2.run().await.unwrap();
|
||||
inst3.run().await.unwrap();
|
||||
|
||||
if proto == "tcp" {
|
||||
inst2
|
||||
.get_conn_manager()
|
||||
.add_connector(TcpTunnelConnector::new(
|
||||
"tcp://10.1.1.1:11010".parse().unwrap(),
|
||||
));
|
||||
} else if proto == "udp" {
|
||||
inst2
|
||||
.get_conn_manager()
|
||||
.add_connector(UdpTunnelConnector::new(
|
||||
"udp://10.1.1.1:11010".parse().unwrap(),
|
||||
));
|
||||
} else if proto == "wg" {
|
||||
inst2
|
||||
.get_conn_manager()
|
||||
.add_connector(WgTunnelConnector::new(
|
||||
"wg://10.1.1.1:11011".parse().unwrap(),
|
||||
WgConfig::new_from_network_identity(
|
||||
&inst1.get_global_ctx().get_network_identity().network_name,
|
||||
&inst1.get_global_ctx().get_network_identity().network_secret,
|
||||
),
|
||||
));
|
||||
}
|
||||
|
||||
inst2
|
||||
.get_conn_manager()
|
||||
.add_connector(RingTunnelConnector::new(
|
||||
format!("ring://{}", inst3.id()).parse().unwrap(),
|
||||
));
|
||||
|
||||
// wait inst2 have two route.
|
||||
let now = std::time::Instant::now();
|
||||
loop {
|
||||
if inst2.get_peer_manager().list_routes().await.len() == 2 {
|
||||
break;
|
||||
}
|
||||
if now.elapsed().as_secs() > 5 {
|
||||
panic!("wait inst2 have two route timeout");
|
||||
}
|
||||
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
|
||||
}
|
||||
|
||||
vec![inst1, inst2, inst3]
|
||||
}
|
||||
|
||||
#[rstest::rstest]
|
||||
#[tokio::test]
|
||||
#[serial_test::serial]
|
||||
pub async fn basic_three_node_test(#[values("tcp", "udp", "wg")] proto: &str) {
|
||||
let insts = init_three_node(proto).await;
|
||||
|
||||
check_route(
|
||||
"10.144.144.2",
|
||||
insts[1].peer_id(),
|
||||
insts[0].get_peer_manager().list_routes().await,
|
||||
);
|
||||
|
||||
check_route(
|
||||
"10.144.144.3",
|
||||
insts[2].peer_id(),
|
||||
insts[0].get_peer_manager().list_routes().await,
|
||||
);
|
||||
}
|
||||
|
||||
#[rstest::rstest]
|
||||
#[tokio::test]
|
||||
#[serial_test::serial]
|
||||
pub async fn tcp_proxy_three_node_test(#[values("tcp", "udp", "wg")] proto: &str) {
|
||||
let insts = init_three_node(proto).await;
|
||||
|
||||
insts[2]
|
||||
.get_global_ctx()
|
||||
.add_proxy_cidr("10.1.2.0/24".parse().unwrap())
|
||||
.unwrap();
|
||||
assert_eq!(insts[2].get_global_ctx().get_proxy_cidrs().len(), 1);
|
||||
|
||||
wait_proxy_route_appear(
|
||||
&insts[0].get_peer_manager(),
|
||||
"10.144.144.3",
|
||||
insts[2].peer_id(),
|
||||
"10.1.2.0/24",
|
||||
)
|
||||
.await;
|
||||
|
||||
// wait updater
|
||||
tokio::time::sleep(tokio::time::Duration::from_secs(6)).await;
|
||||
|
||||
let tcp_listener = TcpTunnelListener::new("tcp://10.1.2.4:22223".parse().unwrap());
|
||||
let tcp_connector = TcpTunnelConnector::new("tcp://10.1.2.4:22223".parse().unwrap());
|
||||
|
||||
_tunnel_pingpong_netns(
|
||||
tcp_listener,
|
||||
tcp_connector,
|
||||
NetNS::new(Some("net_d".into())),
|
||||
NetNS::new(Some("net_a".into())),
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
#[rstest::rstest]
|
||||
#[tokio::test]
|
||||
#[serial_test::serial]
|
||||
pub async fn icmp_proxy_three_node_test(#[values("tcp", "udp", "wg")] proto: &str) {
|
||||
let insts = init_three_node(proto).await;
|
||||
|
||||
insts[2]
|
||||
.get_global_ctx()
|
||||
.add_proxy_cidr("10.1.2.0/24".parse().unwrap())
|
||||
.unwrap();
|
||||
assert_eq!(insts[2].get_global_ctx().get_proxy_cidrs().len(), 1);
|
||||
|
||||
wait_proxy_route_appear(
|
||||
&insts[0].get_peer_manager(),
|
||||
"10.144.144.3",
|
||||
insts[2].peer_id(),
|
||||
"10.1.2.0/24",
|
||||
)
|
||||
.await;
|
||||
|
||||
// wait updater
|
||||
tokio::time::sleep(tokio::time::Duration::from_secs(6)).await;
|
||||
|
||||
// send ping with shell in net_a to net_d
|
||||
let _g = NetNS::new(Some(ROOT_NETNS_NAME.to_owned())).guard();
|
||||
let code = tokio::process::Command::new("ip")
|
||||
.args(&[
|
||||
"netns", "exec", "net_a", "ping", "-c", "1", "-W", "1", "10.1.2.4",
|
||||
])
|
||||
.status()
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(code.code().unwrap(), 0);
|
||||
}
|
||||
|
||||
#[rstest::rstest]
|
||||
#[tokio::test]
|
||||
#[serial_test::serial]
|
||||
pub async fn proxy_three_node_disconnect_test(#[values("tcp", "wg")] proto: &str) {
|
||||
let insts = init_three_node(proto).await;
|
||||
let mut inst4 = Instance::new(get_inst_config("inst4", Some("net_d"), "10.144.144.4"));
|
||||
if proto == "tcp" {
|
||||
inst4
|
||||
.get_conn_manager()
|
||||
.add_connector(TcpTunnelConnector::new(
|
||||
"tcp://10.1.2.3:11010".parse().unwrap(),
|
||||
));
|
||||
} else if proto == "wg" {
|
||||
inst4
|
||||
.get_conn_manager()
|
||||
.add_connector(WgTunnelConnector::new(
|
||||
"wg://10.1.2.3:11011".parse().unwrap(),
|
||||
WgConfig::new_from_network_identity(
|
||||
&inst4.get_global_ctx().get_network_identity().network_name,
|
||||
&inst4.get_global_ctx().get_network_identity().network_secret,
|
||||
),
|
||||
));
|
||||
} else {
|
||||
unreachable!("not support");
|
||||
}
|
||||
inst4.run().await.unwrap();
|
||||
|
||||
let task = tokio::spawn(async move {
|
||||
for _ in 1..=2 {
|
||||
tokio::time::sleep(tokio::time::Duration::from_secs(8)).await;
|
||||
// inst4 should be in inst1's route list
|
||||
let routes = insts[0].get_peer_manager().list_routes().await;
|
||||
assert!(
|
||||
routes
|
||||
.iter()
|
||||
.find(|r| r.peer_id == inst4.peer_id())
|
||||
.is_some(),
|
||||
"inst4 should be in inst1's route list, {:?}",
|
||||
routes
|
||||
);
|
||||
|
||||
set_link_status("net_d", false);
|
||||
tokio::time::sleep(tokio::time::Duration::from_secs(8)).await;
|
||||
let routes = insts[0].get_peer_manager().list_routes().await;
|
||||
assert!(
|
||||
routes
|
||||
.iter()
|
||||
.find(|r| r.peer_id == inst4.peer_id())
|
||||
.is_none(),
|
||||
"inst4 should not be in inst1's route list, {:?}",
|
||||
routes
|
||||
);
|
||||
set_link_status("net_d", true);
|
||||
}
|
||||
});
|
||||
|
||||
let (ret,) = tokio::join!(task);
|
||||
assert!(ret.is_ok());
|
||||
}
|
||||
|
||||
#[rstest::rstest]
|
||||
#[tokio::test]
|
||||
#[serial_test::serial]
|
||||
pub async fn udp_proxy_three_node_test(#[values("tcp", "udp", "wg")] proto: &str) {
|
||||
let insts = init_three_node(proto).await;
|
||||
|
||||
insts[2]
|
||||
.get_global_ctx()
|
||||
.add_proxy_cidr("10.1.2.0/24".parse().unwrap())
|
||||
.unwrap();
|
||||
assert_eq!(insts[2].get_global_ctx().get_proxy_cidrs().len(), 1);
|
||||
|
||||
wait_proxy_route_appear(
|
||||
&insts[0].get_peer_manager(),
|
||||
"10.144.144.3",
|
||||
insts[2].peer_id(),
|
||||
"10.1.2.0/24",
|
||||
)
|
||||
.await;
|
||||
|
||||
// wait updater
|
||||
tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
|
||||
|
||||
let tcp_listener = UdpTunnelListener::new("udp://10.1.2.4:22233".parse().unwrap());
|
||||
let tcp_connector = UdpTunnelConnector::new("udp://10.1.2.4:22233".parse().unwrap());
|
||||
|
||||
_tunnel_pingpong_netns(
|
||||
tcp_listener,
|
||||
tcp_connector,
|
||||
NetNS::new(Some("net_d".into())),
|
||||
NetNS::new(Some("net_a".into())),
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[serial_test::serial]
|
||||
pub async fn udp_broadcast_test() {
|
||||
let _insts = init_three_node("tcp").await;
|
||||
|
||||
let udp_broadcast_responder = |net_ns: NetNS, counter: Arc<AtomicU32>| async move {
|
||||
let _g = net_ns.guard();
|
||||
let socket: UdpSocket = UdpSocket::bind("0.0.0.0:22111").await.unwrap();
|
||||
socket.set_broadcast(true).unwrap();
|
||||
|
||||
println!("Awaiting responses..."); // self.recv_buff is a [u8; 8092]
|
||||
let mut recv_buff = [0; 8092];
|
||||
while let Ok((n, addr)) = socket.recv_from(&mut recv_buff).await {
|
||||
println!("{} bytes response from {:?}", n, addr);
|
||||
counter.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
|
||||
// Remaining code not directly relevant to the question
|
||||
}
|
||||
};
|
||||
|
||||
let mut tasks = JoinSet::new();
|
||||
let counter = Arc::new(AtomicU32::new(0));
|
||||
tasks.spawn(udp_broadcast_responder(
|
||||
NetNS::new(Some("net_b".into())),
|
||||
counter.clone(),
|
||||
));
|
||||
tasks.spawn(udp_broadcast_responder(
|
||||
NetNS::new(Some("net_c".into())),
|
||||
counter.clone(),
|
||||
));
|
||||
|
||||
tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
|
||||
|
||||
// send broadcast
|
||||
let net_ns = NetNS::new(Some("net_a".into()));
|
||||
let _g = net_ns.guard();
|
||||
let socket: UdpSocket = UdpSocket::bind("0.0.0.0:0").await.unwrap();
|
||||
socket.set_broadcast(true).unwrap();
|
||||
// socket.connect(("10.144.144.255", 22111)).await.unwrap();
|
||||
let call: Vec<u8> = vec![1; 1024];
|
||||
println!("Sending call, {} bytes", call.len());
|
||||
match socket.send_to(&call, "10.144.144.255:22111").await {
|
||||
Err(e) => panic!("Error sending call: {:?}", e),
|
||||
_ => {}
|
||||
}
|
||||
|
||||
tokio::time::sleep(tokio::time::Duration::from_secs(2)).await;
|
||||
assert_eq!(counter.load(std::sync::atomic::Ordering::Relaxed), 2);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[serial_test::serial]
|
||||
pub async fn foreign_network_forward_nic_data() {
|
||||
prepare_linux_namespaces();
|
||||
|
||||
let center_node_config = get_inst_config("inst1", Some("net_a"), "10.144.144.1");
|
||||
center_node_config.set_network_identity(NetworkIdentity {
|
||||
network_name: "center".to_string(),
|
||||
network_secret: "".to_string(),
|
||||
});
|
||||
let mut center_inst = Instance::new(center_node_config);
|
||||
|
||||
let mut inst1 = Instance::new(get_inst_config("inst1", Some("net_b"), "10.144.145.1"));
|
||||
let mut inst2 = Instance::new(get_inst_config("inst2", Some("net_c"), "10.144.145.2"));
|
||||
|
||||
center_inst.run().await.unwrap();
|
||||
inst1.run().await.unwrap();
|
||||
inst2.run().await.unwrap();
|
||||
|
||||
assert_ne!(inst1.id(), center_inst.id());
|
||||
assert_ne!(inst2.id(), center_inst.id());
|
||||
|
||||
inst1
|
||||
.get_conn_manager()
|
||||
.add_connector(RingTunnelConnector::new(
|
||||
format!("ring://{}", center_inst.id()).parse().unwrap(),
|
||||
));
|
||||
|
||||
inst2
|
||||
.get_conn_manager()
|
||||
.add_connector(RingTunnelConnector::new(
|
||||
format!("ring://{}", center_inst.id()).parse().unwrap(),
|
||||
));
|
||||
|
||||
wait_for_condition(
|
||||
|| async {
|
||||
inst1.get_peer_manager().list_routes().await.len() == 1
|
||||
&& inst2.get_peer_manager().list_routes().await.len() == 1
|
||||
},
|
||||
Duration::from_secs(5),
|
||||
)
|
||||
.await;
|
||||
|
||||
let _g = NetNS::new(Some(ROOT_NETNS_NAME.to_owned())).guard();
|
||||
let code = tokio::process::Command::new("ip")
|
||||
.args(&[
|
||||
"netns",
|
||||
"exec",
|
||||
"net_b",
|
||||
"ping",
|
||||
"-c",
|
||||
"1",
|
||||
"-W",
|
||||
"1",
|
||||
"10.144.145.2",
|
||||
])
|
||||
.status()
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(code.code().unwrap(), 0);
|
||||
}
|
||||
@@ -0,0 +1,54 @@
|
||||
use std::result::Result;
|
||||
use tokio::io;
|
||||
use tokio_util::{
|
||||
bytes::{BufMut, Bytes, BytesMut},
|
||||
codec::{Decoder, Encoder},
|
||||
};
|
||||
|
||||
#[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash, Default)]
|
||||
pub struct BytesCodec {
|
||||
capacity: usize,
|
||||
}
|
||||
|
||||
impl BytesCodec {
|
||||
/// Creates a new `BytesCodec` for shipping around raw bytes.
|
||||
pub fn new(capacity: usize) -> BytesCodec {
|
||||
BytesCodec { capacity }
|
||||
}
|
||||
}
|
||||
|
||||
impl Decoder for BytesCodec {
|
||||
type Item = BytesMut;
|
||||
type Error = io::Error;
|
||||
|
||||
fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<BytesMut>, io::Error> {
|
||||
if !buf.is_empty() {
|
||||
let len = buf.len();
|
||||
let ret = Some(buf.split_to(len));
|
||||
buf.reserve(self.capacity);
|
||||
Ok(ret)
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Encoder<Bytes> for BytesCodec {
|
||||
type Error = io::Error;
|
||||
|
||||
fn encode(&mut self, data: Bytes, buf: &mut BytesMut) -> Result<(), io::Error> {
|
||||
buf.reserve(data.len());
|
||||
buf.put(data);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl Encoder<BytesMut> for BytesCodec {
|
||||
type Error = io::Error;
|
||||
|
||||
fn encode(&mut self, data: BytesMut, buf: &mut BytesMut) -> Result<(), io::Error> {
|
||||
buf.reserve(data.len());
|
||||
buf.put(data);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,460 @@
|
||||
use std::{
|
||||
collections::VecDeque,
|
||||
net::{IpAddr, SocketAddr},
|
||||
sync::Arc,
|
||||
task::{ready, Context, Poll},
|
||||
};
|
||||
|
||||
use async_stream::stream;
|
||||
use futures::{Future, FutureExt, Sink, SinkExt, Stream, StreamExt};
|
||||
use network_interface::NetworkInterfaceConfig;
|
||||
use tokio::{sync::Mutex, time::error::Elapsed};
|
||||
|
||||
use std::pin::Pin;
|
||||
|
||||
use crate::tunnels::{SinkError, TunnelError};
|
||||
|
||||
use super::{DatagramSink, DatagramStream, SinkItem, StreamT, Tunnel, TunnelInfo};
|
||||
|
||||
pub struct FramedTunnel<R, W> {
|
||||
read: Arc<Mutex<R>>,
|
||||
write: Arc<Mutex<W>>,
|
||||
|
||||
info: Option<TunnelInfo>,
|
||||
}
|
||||
|
||||
impl<R, RE, W, WE> FramedTunnel<R, W>
|
||||
where
|
||||
R: Stream<Item = Result<StreamT, RE>> + Send + Sync + Unpin + 'static,
|
||||
W: Sink<SinkItem, Error = WE> + Send + Sync + Unpin + 'static,
|
||||
RE: std::error::Error + std::fmt::Debug + Send + Sync + 'static,
|
||||
WE: std::error::Error + std::fmt::Debug + Send + Sync + 'static + From<Elapsed>,
|
||||
{
|
||||
pub fn new(read: R, write: W, info: Option<TunnelInfo>) -> Self {
|
||||
FramedTunnel {
|
||||
read: Arc::new(Mutex::new(read)),
|
||||
write: Arc::new(Mutex::new(write)),
|
||||
info,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn new_tunnel_with_info(read: R, write: W, info: TunnelInfo) -> Box<dyn Tunnel> {
|
||||
Box::new(FramedTunnel::new(read, write, Some(info)))
|
||||
}
|
||||
|
||||
pub fn recv_stream(&self) -> impl DatagramStream {
|
||||
let read = self.read.clone();
|
||||
let info = self.info.clone();
|
||||
stream! {
|
||||
loop {
|
||||
let read_ret = read.lock().await.next().await;
|
||||
if read_ret.is_none() {
|
||||
tracing::info!(?info, "read_ret is none");
|
||||
yield Err(TunnelError::CommonError("recv stream closed".to_string()));
|
||||
} else {
|
||||
let read_ret = read_ret.unwrap();
|
||||
if read_ret.is_err() {
|
||||
let err = read_ret.err().unwrap();
|
||||
tracing::info!(?info, "recv stream read error");
|
||||
yield Err(TunnelError::CommonError(err.to_string()));
|
||||
} else {
|
||||
yield Ok(read_ret.unwrap());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn send_sink(&self) -> impl DatagramSink {
|
||||
struct SendSink<W, WE> {
|
||||
write: Arc<Mutex<W>>,
|
||||
max_buffer_size: usize,
|
||||
sending_buffers: Option<VecDeque<SinkItem>>,
|
||||
send_task:
|
||||
Option<Pin<Box<dyn Future<Output = Result<(), WE>> + Send + Sync + 'static>>>,
|
||||
close_task:
|
||||
Option<Pin<Box<dyn Future<Output = Result<(), WE>> + Send + Sync + 'static>>>,
|
||||
}
|
||||
|
||||
impl<W, WE> SendSink<W, WE>
|
||||
where
|
||||
W: Sink<SinkItem, Error = WE> + Send + Sync + Unpin + 'static,
|
||||
WE: std::error::Error + std::fmt::Debug + Send + Sync + From<Elapsed>,
|
||||
{
|
||||
fn try_send_buffser(
|
||||
&mut self,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<std::result::Result<(), WE>> {
|
||||
if self.send_task.is_none() {
|
||||
let mut buffers = self.sending_buffers.take().unwrap();
|
||||
let tun = self.write.clone();
|
||||
let send_task = async move {
|
||||
if buffers.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let mut locked_tun = tun.lock_owned().await;
|
||||
while let Some(buf) = buffers.front() {
|
||||
log::trace!(
|
||||
"try_send buffer, len: {:?}, buf: {:?}",
|
||||
buffers.len(),
|
||||
&buf
|
||||
);
|
||||
let timeout_task = tokio::time::timeout(
|
||||
std::time::Duration::from_secs(1),
|
||||
locked_tun.send(buf.clone()),
|
||||
);
|
||||
let send_res = timeout_task.await;
|
||||
let Ok(send_res) = send_res else {
|
||||
// panic!("send timeout");
|
||||
let err = send_res.err().unwrap();
|
||||
return Err(err.into());
|
||||
};
|
||||
let Ok(_) = send_res else {
|
||||
let err = send_res.err().unwrap();
|
||||
println!("send error: {:?}", err);
|
||||
return Err(err);
|
||||
};
|
||||
buffers.pop_front();
|
||||
}
|
||||
return Ok(());
|
||||
};
|
||||
self.send_task = Some(Box::pin(send_task));
|
||||
}
|
||||
|
||||
let ret = ready!(self.send_task.as_mut().unwrap().poll_unpin(cx));
|
||||
self.send_task = None;
|
||||
self.sending_buffers = Some(VecDeque::new());
|
||||
return Poll::Ready(ret);
|
||||
}
|
||||
}
|
||||
|
||||
impl<W, WE> Sink<SinkItem> for SendSink<W, WE>
|
||||
where
|
||||
W: Sink<SinkItem, Error = WE> + Send + Sync + Unpin + 'static,
|
||||
WE: std::error::Error + std::fmt::Debug + Send + Sync + From<Elapsed>,
|
||||
{
|
||||
type Error = SinkError;
|
||||
|
||||
fn poll_ready(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> Poll<Result<(), Self::Error>> {
|
||||
let self_mut = self.get_mut();
|
||||
let sending_buf = self_mut.sending_buffers.as_ref();
|
||||
// if sending_buffers is None, must already be doing flush
|
||||
if sending_buf.is_none() || sending_buf.unwrap().len() > self_mut.max_buffer_size {
|
||||
return self_mut.poll_flush_unpin(cx);
|
||||
} else {
|
||||
return Poll::Ready(Ok(()));
|
||||
}
|
||||
}
|
||||
|
||||
fn start_send(self: Pin<&mut Self>, item: SinkItem) -> Result<(), Self::Error> {
|
||||
assert!(self.send_task.is_none());
|
||||
let self_mut = self.get_mut();
|
||||
self_mut.sending_buffers.as_mut().unwrap().push_back(item);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn poll_flush(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> Poll<Result<(), Self::Error>> {
|
||||
let self_mut = self.get_mut();
|
||||
let ret = self_mut.try_send_buffser(cx);
|
||||
match ret {
|
||||
Poll::Ready(Ok(())) => Poll::Ready(Ok(())),
|
||||
Poll::Ready(Err(e)) => Poll::Ready(Err(SinkError::CommonError(e.to_string()))),
|
||||
Poll::Pending => {
|
||||
return Poll::Pending;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_close(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> Poll<Result<(), Self::Error>> {
|
||||
let self_mut = self.get_mut();
|
||||
if self_mut.close_task.is_none() {
|
||||
let tun = self_mut.write.clone();
|
||||
let close_task = async move {
|
||||
let mut locked_tun = tun.lock_owned().await;
|
||||
return locked_tun.close().await;
|
||||
};
|
||||
self_mut.close_task = Some(Box::pin(close_task));
|
||||
}
|
||||
|
||||
let ret = ready!(self_mut.close_task.as_mut().unwrap().poll_unpin(cx));
|
||||
self_mut.close_task = None;
|
||||
|
||||
if ret.is_err() {
|
||||
return Poll::Ready(Err(SinkError::CommonError(
|
||||
ret.err().unwrap().to_string(),
|
||||
)));
|
||||
} else {
|
||||
return Poll::Ready(Ok(()));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
SendSink {
|
||||
write: self.write.clone(),
|
||||
max_buffer_size: 1000,
|
||||
sending_buffers: Some(VecDeque::new()),
|
||||
send_task: None,
|
||||
close_task: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<R, RE, W, WE> Tunnel for FramedTunnel<R, W>
|
||||
where
|
||||
R: Stream<Item = Result<StreamT, RE>> + Send + Sync + Unpin + 'static,
|
||||
W: Sink<SinkItem, Error = WE> + Send + Sync + Unpin + 'static,
|
||||
RE: std::error::Error + std::fmt::Debug + Send + Sync + 'static,
|
||||
WE: std::error::Error + std::fmt::Debug + Send + Sync + 'static + From<Elapsed>,
|
||||
{
|
||||
fn stream(&self) -> Box<dyn DatagramStream> {
|
||||
Box::new(self.recv_stream())
|
||||
}
|
||||
|
||||
fn sink(&self) -> Box<dyn DatagramSink> {
|
||||
Box::new(self.send_sink())
|
||||
}
|
||||
|
||||
fn info(&self) -> Option<TunnelInfo> {
|
||||
if self.info.is_none() {
|
||||
None
|
||||
} else {
|
||||
Some(self.info.clone().unwrap())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct TunnelWithCustomInfo {
|
||||
tunnel: Box<dyn Tunnel>,
|
||||
info: TunnelInfo,
|
||||
}
|
||||
|
||||
impl TunnelWithCustomInfo {
|
||||
pub fn new(tunnel: Box<dyn Tunnel>, info: TunnelInfo) -> Self {
|
||||
TunnelWithCustomInfo { tunnel, info }
|
||||
}
|
||||
}
|
||||
|
||||
impl Tunnel for TunnelWithCustomInfo {
|
||||
fn stream(&self) -> Box<dyn DatagramStream> {
|
||||
self.tunnel.stream()
|
||||
}
|
||||
|
||||
fn sink(&self) -> Box<dyn DatagramSink> {
|
||||
self.tunnel.sink()
|
||||
}
|
||||
|
||||
fn info(&self) -> Option<TunnelInfo> {
|
||||
Some(self.info.clone())
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn get_interface_name_by_ip(local_ip: &IpAddr) -> Option<String> {
|
||||
if local_ip.is_unspecified() || local_ip.is_multicast() {
|
||||
return None;
|
||||
}
|
||||
let ifaces = network_interface::NetworkInterface::show().ok()?;
|
||||
for iface in ifaces {
|
||||
for addr in iface.addr {
|
||||
if addr.ip() == *local_ip {
|
||||
return Some(iface.name);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tracing::error!(?local_ip, "can not find interface name by ip");
|
||||
None
|
||||
}
|
||||
|
||||
pub(crate) fn setup_sokcet2_ext(
|
||||
socket2_socket: &socket2::Socket,
|
||||
bind_addr: &SocketAddr,
|
||||
bind_dev: Option<String>,
|
||||
) -> Result<(), TunnelError> {
|
||||
#[cfg(target_os = "windows")]
|
||||
{
|
||||
let is_udp = matches!(socket2_socket.r#type()?, socket2::Type::DGRAM);
|
||||
crate::arch::windows::setup_socket_for_win(socket2_socket, bind_addr, bind_dev, is_udp)?;
|
||||
}
|
||||
|
||||
socket2_socket.set_nonblocking(true)?;
|
||||
socket2_socket.set_reuse_address(true)?;
|
||||
socket2_socket.bind(&socket2::SockAddr::from(*bind_addr))?;
|
||||
|
||||
// #[cfg(all(unix, not(target_os = "solaris"), not(target_os = "illumos")))]
|
||||
// socket2_socket.set_reuse_port(true)?;
|
||||
|
||||
if bind_addr.ip().is_unspecified() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// linux/mac does not use interface of bind_addr to send packet, so we need to bind device
|
||||
// win can handle this with bind correctly
|
||||
#[cfg(any(target_os = "ios", target_os = "macos"))]
|
||||
if let Some(dev_name) = bind_dev {
|
||||
// use IP_BOUND_IF to bind device
|
||||
unsafe {
|
||||
let dev_idx = nix::libc::if_nametoindex(dev_name.as_str().as_ptr() as *const i8);
|
||||
tracing::warn!(?dev_idx, ?dev_name, "bind device");
|
||||
socket2_socket.bind_device_by_index_v4(std::num::NonZeroU32::new(dev_idx))?;
|
||||
tracing::warn!(?dev_idx, ?dev_name, "bind device doen");
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
|
||||
if let Some(dev_name) = bind_dev {
|
||||
tracing::trace!(dev_name = ?dev_name, "bind device");
|
||||
socket2_socket.bind_device(Some(dev_name.as_bytes()))?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn setup_sokcet2(
|
||||
socket2_socket: &socket2::Socket,
|
||||
bind_addr: &SocketAddr,
|
||||
) -> Result<(), TunnelError> {
|
||||
setup_sokcet2_ext(
|
||||
socket2_socket,
|
||||
bind_addr,
|
||||
super::common::get_interface_name_by_ip(&bind_addr.ip()),
|
||||
)
|
||||
}
|
||||
|
||||
pub mod tests {
|
||||
use std::time::Instant;
|
||||
|
||||
use futures::SinkExt;
|
||||
use tokio_stream::StreamExt;
|
||||
use tokio_util::bytes::{BufMut, Bytes, BytesMut};
|
||||
|
||||
use crate::{
|
||||
common::netns::NetNS,
|
||||
tunnels::{close_tunnel, TunnelConnector, TunnelListener},
|
||||
};
|
||||
|
||||
pub async fn _tunnel_echo_server(tunnel: Box<dyn super::Tunnel>, once: bool) {
|
||||
let mut recv = Box::into_pin(tunnel.stream());
|
||||
let mut send = Box::into_pin(tunnel.sink());
|
||||
|
||||
while let Some(ret) = recv.next().await {
|
||||
if ret.is_err() {
|
||||
log::trace!("recv error: {:?}", ret.err().unwrap());
|
||||
break;
|
||||
}
|
||||
let res = ret.unwrap();
|
||||
log::trace!("recv a msg, try echo back: {:?}", res);
|
||||
send.send(Bytes::from(res)).await.unwrap();
|
||||
if once {
|
||||
break;
|
||||
}
|
||||
}
|
||||
log::warn!("echo server exit...");
|
||||
}
|
||||
|
||||
pub(crate) async fn _tunnel_pingpong<L, C>(listener: L, connector: C)
|
||||
where
|
||||
L: TunnelListener + Send + Sync + 'static,
|
||||
C: TunnelConnector + Send + Sync + 'static,
|
||||
{
|
||||
_tunnel_pingpong_netns(listener, connector, NetNS::new(None), NetNS::new(None)).await
|
||||
}
|
||||
|
||||
pub(crate) async fn _tunnel_pingpong_netns<L, C>(
|
||||
mut listener: L,
|
||||
mut connector: C,
|
||||
l_netns: NetNS,
|
||||
c_netns: NetNS,
|
||||
) where
|
||||
L: TunnelListener + Send + Sync + 'static,
|
||||
C: TunnelConnector + Send + Sync + 'static,
|
||||
{
|
||||
l_netns
|
||||
.run_async(|| async {
|
||||
listener.listen().await.unwrap();
|
||||
})
|
||||
.await;
|
||||
|
||||
let lis = tokio::spawn(async move {
|
||||
let ret = listener.accept().await.unwrap();
|
||||
assert_eq!(
|
||||
ret.info().unwrap().local_addr,
|
||||
listener.local_url().to_string()
|
||||
);
|
||||
_tunnel_echo_server(ret, false).await
|
||||
});
|
||||
|
||||
let tunnel = c_netns.run_async(|| connector.connect()).await.unwrap();
|
||||
|
||||
assert_eq!(
|
||||
tunnel.info().unwrap().remote_addr,
|
||||
connector.remote_url().to_string()
|
||||
);
|
||||
|
||||
let mut send = tunnel.pin_sink();
|
||||
let mut recv = tunnel.pin_stream();
|
||||
let send_data = Bytes::from("12345678abcdefg");
|
||||
send.send(send_data).await.unwrap();
|
||||
let ret = tokio::time::timeout(tokio::time::Duration::from_secs(1), recv.next())
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
println!("echo back: {:?}", ret);
|
||||
assert_eq!(ret, Bytes::from("12345678abcdefg"));
|
||||
|
||||
close_tunnel(&tunnel).await.unwrap();
|
||||
|
||||
if ["udp", "wg"].contains(&connector.remote_url().scheme()) {
|
||||
lis.abort();
|
||||
} else {
|
||||
// lis should finish in 1 second
|
||||
let ret = tokio::time::timeout(tokio::time::Duration::from_secs(1), lis).await;
|
||||
assert!(ret.is_ok());
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn _tunnel_bench<L, C>(mut listener: L, mut connector: C)
|
||||
where
|
||||
L: TunnelListener + Send + Sync + 'static,
|
||||
C: TunnelConnector + Send + Sync + 'static,
|
||||
{
|
||||
listener.listen().await.unwrap();
|
||||
|
||||
let lis = tokio::spawn(async move {
|
||||
let ret = listener.accept().await.unwrap();
|
||||
_tunnel_echo_server(ret, false).await
|
||||
});
|
||||
|
||||
let tunnel = connector.connect().await.unwrap();
|
||||
|
||||
let mut send = tunnel.pin_sink();
|
||||
let mut recv = tunnel.pin_stream();
|
||||
|
||||
// prepare a 4k buffer with random data
|
||||
let mut send_buf = BytesMut::new();
|
||||
for _ in 0..64 {
|
||||
send_buf.put_i128(rand::random::<i128>());
|
||||
}
|
||||
|
||||
let now = Instant::now();
|
||||
let mut count = 0;
|
||||
while now.elapsed().as_secs() < 3 {
|
||||
send.send(send_buf.clone().freeze()).await.unwrap();
|
||||
let _ = recv.next().await.unwrap().unwrap();
|
||||
count += 1;
|
||||
}
|
||||
println!("bps: {}", (count / 1024) * 4 / now.elapsed().as_secs());
|
||||
|
||||
lis.abort();
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,192 @@
|
||||
pub mod codec;
|
||||
pub mod common;
|
||||
pub mod ring_tunnel;
|
||||
pub mod stats;
|
||||
pub mod tcp_tunnel;
|
||||
pub mod tunnel_filter;
|
||||
pub mod udp_tunnel;
|
||||
pub mod wireguard;
|
||||
|
||||
use std::{fmt::Debug, net::SocketAddr, pin::Pin, sync::Arc};
|
||||
|
||||
use crate::rpc::TunnelInfo;
|
||||
use async_trait::async_trait;
|
||||
use futures::{Sink, SinkExt, Stream};
|
||||
|
||||
use thiserror::Error;
|
||||
use tokio_util::bytes::{Bytes, BytesMut};
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
pub enum TunnelError {
|
||||
#[error("Error: {0}")]
|
||||
CommonError(String),
|
||||
#[error("io error")]
|
||||
IOError(#[from] std::io::Error),
|
||||
#[error("wait resp error")]
|
||||
WaitRespError(String),
|
||||
#[error("Connect Error: {0}")]
|
||||
ConnectError(String),
|
||||
#[error("Invalid Protocol: {0}")]
|
||||
InvalidProtocol(String),
|
||||
#[error("Invalid Addr: {0}")]
|
||||
InvalidAddr(String),
|
||||
#[error("Tun Error: {0}")]
|
||||
TunError(String),
|
||||
#[error("timeout")]
|
||||
Timeout(#[from] tokio::time::error::Elapsed),
|
||||
}
|
||||
|
||||
pub type StreamT = BytesMut;
|
||||
pub type StreamItem = Result<StreamT, TunnelError>;
|
||||
pub type SinkItem = Bytes;
|
||||
pub type SinkError = TunnelError;
|
||||
|
||||
pub trait DatagramStream: Stream<Item = StreamItem> + Send + Sync {}
|
||||
impl<T> DatagramStream for T where T: Stream<Item = StreamItem> + Send + Sync {}
|
||||
pub trait DatagramSink: Sink<SinkItem, Error = SinkError> + Send + Sync {}
|
||||
impl<T> DatagramSink for T where T: Sink<SinkItem, Error = SinkError> + Send + Sync {}
|
||||
|
||||
#[auto_impl::auto_impl(Box, Arc)]
|
||||
pub trait Tunnel: Send + Sync {
|
||||
fn stream(&self) -> Box<dyn DatagramStream>;
|
||||
fn sink(&self) -> Box<dyn DatagramSink>;
|
||||
|
||||
fn pin_stream(&self) -> Pin<Box<dyn DatagramStream>> {
|
||||
Box::into_pin(self.stream())
|
||||
}
|
||||
|
||||
fn pin_sink(&self) -> Pin<Box<dyn DatagramSink>> {
|
||||
Box::into_pin(self.sink())
|
||||
}
|
||||
|
||||
fn info(&self) -> Option<TunnelInfo>;
|
||||
}
|
||||
|
||||
pub async fn close_tunnel(t: &Box<dyn Tunnel>) -> Result<(), TunnelError> {
|
||||
t.pin_sink().close().await
|
||||
}
|
||||
|
||||
#[auto_impl::auto_impl(Arc)]
|
||||
pub trait TunnelConnCounter: 'static + Send + Sync + Debug {
|
||||
fn get(&self) -> u32;
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
#[auto_impl::auto_impl(Box)]
|
||||
pub trait TunnelListener: Send + Sync {
|
||||
async fn listen(&mut self) -> Result<(), TunnelError>;
|
||||
async fn accept(&mut self) -> Result<Box<dyn Tunnel>, TunnelError>;
|
||||
fn local_url(&self) -> url::Url;
|
||||
fn get_conn_counter(&self) -> Arc<Box<dyn TunnelConnCounter>> {
|
||||
#[derive(Debug)]
|
||||
struct FakeTunnelConnCounter {}
|
||||
impl TunnelConnCounter for FakeTunnelConnCounter {
|
||||
fn get(&self) -> u32 {
|
||||
0
|
||||
}
|
||||
}
|
||||
Arc::new(Box::new(FakeTunnelConnCounter {}))
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
#[auto_impl::auto_impl(Box)]
|
||||
pub trait TunnelConnector {
|
||||
async fn connect(&mut self) -> Result<Box<dyn Tunnel>, TunnelError>;
|
||||
fn remote_url(&self) -> url::Url;
|
||||
fn set_bind_addrs(&mut self, _addrs: Vec<SocketAddr>) {}
|
||||
}
|
||||
|
||||
pub fn build_url_from_socket_addr(addr: &String, scheme: &str) -> url::Url {
|
||||
url::Url::parse(format!("{}://{}", scheme, addr).as_str()).unwrap()
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for dyn Tunnel {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("Tunnel")
|
||||
.field("info", &self.info())
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for dyn TunnelConnector + Sync + Send {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("TunnelConnector")
|
||||
.field("remote_url", &self.remote_url())
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for dyn TunnelListener {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("TunnelListener")
|
||||
.field("local_url", &self.local_url())
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) trait FromUrl {
|
||||
fn from_url(url: url::Url) -> Result<Self, TunnelError>
|
||||
where
|
||||
Self: Sized;
|
||||
}
|
||||
|
||||
pub(crate) fn check_scheme_and_get_socket_addr<T>(
|
||||
url: &url::Url,
|
||||
scheme: &str,
|
||||
) -> Result<T, TunnelError>
|
||||
where
|
||||
T: FromUrl,
|
||||
{
|
||||
if url.scheme() != scheme {
|
||||
return Err(TunnelError::InvalidProtocol(url.scheme().to_string()));
|
||||
}
|
||||
|
||||
Ok(T::from_url(url.clone())?)
|
||||
}
|
||||
|
||||
impl FromUrl for SocketAddr {
|
||||
fn from_url(url: url::Url) -> Result<Self, TunnelError> {
|
||||
Ok(url.socket_addrs(|| None)?.pop().unwrap())
|
||||
}
|
||||
}
|
||||
|
||||
impl FromUrl for uuid::Uuid {
|
||||
fn from_url(url: url::Url) -> Result<Self, TunnelError> {
|
||||
let o = url.host_str().unwrap();
|
||||
let o = uuid::Uuid::parse_str(o).map_err(|e| TunnelError::InvalidAddr(e.to_string()))?;
|
||||
Ok(o)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct TunnelUrl {
|
||||
inner: url::Url,
|
||||
}
|
||||
|
||||
impl From<url::Url> for TunnelUrl {
|
||||
fn from(url: url::Url) -> Self {
|
||||
TunnelUrl { inner: url }
|
||||
}
|
||||
}
|
||||
|
||||
impl From<TunnelUrl> for url::Url {
|
||||
fn from(url: TunnelUrl) -> Self {
|
||||
url.into_inner()
|
||||
}
|
||||
}
|
||||
|
||||
impl TunnelUrl {
|
||||
pub fn into_inner(self) -> url::Url {
|
||||
self.inner
|
||||
}
|
||||
|
||||
pub fn bind_dev(&self) -> Option<String> {
|
||||
self.inner.path().strip_prefix("/").and_then(|s| {
|
||||
if s.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(String::from_utf8(percent_encoding::percent_decode_str(&s).collect()).unwrap())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,396 @@
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
sync::{atomic::AtomicBool, Arc},
|
||||
task::Poll,
|
||||
};
|
||||
|
||||
use async_stream::stream;
|
||||
use crossbeam_queue::ArrayQueue;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use futures::Sink;
|
||||
use once_cell::sync::Lazy;
|
||||
use tokio::sync::{
|
||||
mpsc::{UnboundedReceiver, UnboundedSender},
|
||||
Mutex, Notify,
|
||||
};
|
||||
|
||||
use futures::FutureExt;
|
||||
use tokio_util::bytes::BytesMut;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::tunnels::{SinkError, SinkItem};
|
||||
|
||||
use super::{
|
||||
build_url_from_socket_addr, check_scheme_and_get_socket_addr, DatagramSink, DatagramStream,
|
||||
Tunnel, TunnelConnector, TunnelError, TunnelInfo, TunnelListener,
|
||||
};
|
||||
|
||||
static RING_TUNNEL_CAP: usize = 1000;
|
||||
|
||||
pub struct RingTunnel {
|
||||
id: Uuid,
|
||||
ring: Arc<ArrayQueue<SinkItem>>,
|
||||
consume_notify: Arc<Notify>,
|
||||
produce_notify: Arc<Notify>,
|
||||
closed: Arc<AtomicBool>,
|
||||
}
|
||||
|
||||
impl RingTunnel {
|
||||
pub fn new(cap: usize) -> Self {
|
||||
RingTunnel {
|
||||
id: Uuid::new_v4(),
|
||||
ring: Arc::new(ArrayQueue::new(cap)),
|
||||
consume_notify: Arc::new(Notify::new()),
|
||||
produce_notify: Arc::new(Notify::new()),
|
||||
closed: Arc::new(AtomicBool::new(false)),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn new_with_id(id: Uuid, cap: usize) -> Self {
|
||||
let mut ret = Self::new(cap);
|
||||
ret.id = id;
|
||||
ret
|
||||
}
|
||||
|
||||
fn recv_stream(&self) -> impl DatagramStream {
|
||||
let ring = self.ring.clone();
|
||||
let produce_notify = self.produce_notify.clone();
|
||||
let consume_notify = self.consume_notify.clone();
|
||||
let closed = self.closed.clone();
|
||||
let id = self.id;
|
||||
stream! {
|
||||
loop {
|
||||
if closed.load(std::sync::atomic::Ordering::Relaxed) {
|
||||
log::warn!("ring recv tunnel {:?} closed", id);
|
||||
yield Err(TunnelError::CommonError("Closed".to_owned()));
|
||||
}
|
||||
match ring.pop() {
|
||||
Some(v) => {
|
||||
let mut out = BytesMut::new();
|
||||
out.extend_from_slice(&v);
|
||||
consume_notify.notify_one();
|
||||
log::trace!("id: {}, recv buffer, len: {:?}, buf: {:?}", id, v.len(), &v);
|
||||
yield Ok(out);
|
||||
},
|
||||
None => {
|
||||
log::trace!("waiting recv buffer, id: {}", id);
|
||||
produce_notify.notified().await;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn send_sink(&self) -> impl DatagramSink {
|
||||
let ring = self.ring.clone();
|
||||
let produce_notify = self.produce_notify.clone();
|
||||
let consume_notify = self.consume_notify.clone();
|
||||
let closed = self.closed.clone();
|
||||
let id = self.id;
|
||||
|
||||
// type T = RingTunnel;
|
||||
|
||||
use tokio::task::JoinHandle;
|
||||
|
||||
struct T {
|
||||
ring: RingTunnel,
|
||||
wait_consume_task: Option<JoinHandle<()>>,
|
||||
}
|
||||
|
||||
impl T {
|
||||
fn wait_ring_consume(
|
||||
self: std::pin::Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
expected_size: usize,
|
||||
) -> std::task::Poll<()> {
|
||||
let self_mut = self.get_mut();
|
||||
if self_mut.ring.ring.len() <= expected_size {
|
||||
return Poll::Ready(());
|
||||
}
|
||||
if self_mut.wait_consume_task.is_none() {
|
||||
let id = self_mut.ring.id;
|
||||
let consume_notify = self_mut.ring.consume_notify.clone();
|
||||
let ring = self_mut.ring.ring.clone();
|
||||
let task = async move {
|
||||
log::trace!(
|
||||
"waiting ring consume done, expected_size: {}, id: {}",
|
||||
expected_size,
|
||||
id
|
||||
);
|
||||
while ring.len() > expected_size {
|
||||
consume_notify.notified().await;
|
||||
}
|
||||
log::trace!(
|
||||
"ring consume done, expected_size: {}, id: {}",
|
||||
expected_size,
|
||||
id
|
||||
);
|
||||
};
|
||||
self_mut.wait_consume_task = Some(tokio::spawn(task));
|
||||
}
|
||||
let task = self_mut.wait_consume_task.as_mut().unwrap();
|
||||
match task.poll_unpin(cx) {
|
||||
Poll::Ready(_) => {
|
||||
self_mut.wait_consume_task = None;
|
||||
Poll::Ready(())
|
||||
}
|
||||
Poll::Pending => Poll::Pending,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Sink<SinkItem> for T {
|
||||
type Error = SinkError;
|
||||
|
||||
fn poll_ready(
|
||||
self: std::pin::Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> std::task::Poll<Result<(), Self::Error>> {
|
||||
let expected_size = self.ring.ring.capacity() - 1;
|
||||
match self.wait_ring_consume(cx, expected_size) {
|
||||
Poll::Ready(_) => Poll::Ready(Ok(())),
|
||||
Poll::Pending => Poll::Pending,
|
||||
}
|
||||
}
|
||||
|
||||
fn start_send(
|
||||
self: std::pin::Pin<&mut Self>,
|
||||
item: SinkItem,
|
||||
) -> Result<(), Self::Error> {
|
||||
log::trace!("id: {}, send buffer, buf: {:?}", self.ring.id, &item);
|
||||
self.ring.ring.push(item).unwrap();
|
||||
self.ring.produce_notify.notify_one();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn poll_flush(
|
||||
self: std::pin::Pin<&mut Self>,
|
||||
_cx: &mut std::task::Context<'_>,
|
||||
) -> std::task::Poll<Result<(), Self::Error>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
fn poll_close(
|
||||
self: std::pin::Pin<&mut Self>,
|
||||
_cx: &mut std::task::Context<'_>,
|
||||
) -> std::task::Poll<Result<(), Self::Error>> {
|
||||
self.ring
|
||||
.closed
|
||||
.store(true, std::sync::atomic::Ordering::Relaxed);
|
||||
log::warn!("ring tunnel send {:?} closed", self.ring.id);
|
||||
self.ring.produce_notify.notify_one();
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
}
|
||||
|
||||
T {
|
||||
ring: RingTunnel {
|
||||
id,
|
||||
ring,
|
||||
consume_notify,
|
||||
produce_notify,
|
||||
closed,
|
||||
},
|
||||
wait_consume_task: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct Connection {
|
||||
client: RingTunnel,
|
||||
server: RingTunnel,
|
||||
}
|
||||
|
||||
impl Tunnel for RingTunnel {
|
||||
fn stream(&self) -> Box<dyn DatagramStream> {
|
||||
Box::new(self.recv_stream())
|
||||
}
|
||||
|
||||
fn sink(&self) -> Box<dyn DatagramSink> {
|
||||
Box::new(self.send_sink())
|
||||
}
|
||||
|
||||
fn info(&self) -> Option<TunnelInfo> {
|
||||
None
|
||||
// Some(TunnelInfo {
|
||||
// tunnel_type: "ring".to_owned(),
|
||||
// local_addr: format!("ring://{}", self.id),
|
||||
// remote_addr: format!("ring://{}", self.id),
|
||||
// })
|
||||
}
|
||||
}
|
||||
|
||||
static CONNECTION_MAP: Lazy<Arc<Mutex<HashMap<uuid::Uuid, UnboundedSender<Arc<Connection>>>>>> =
|
||||
Lazy::new(|| Arc::new(Mutex::new(HashMap::new())));
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct RingTunnelListener {
|
||||
listerner_addr: url::Url,
|
||||
conn_sender: UnboundedSender<Arc<Connection>>,
|
||||
conn_receiver: UnboundedReceiver<Arc<Connection>>,
|
||||
}
|
||||
|
||||
impl RingTunnelListener {
|
||||
pub fn new(key: url::Url) -> Self {
|
||||
let (conn_sender, conn_receiver) = tokio::sync::mpsc::unbounded_channel();
|
||||
RingTunnelListener {
|
||||
listerner_addr: key,
|
||||
conn_sender,
|
||||
conn_receiver,
|
||||
}
|
||||
}
|
||||
}
|
||||
struct ConnectionForServer {
|
||||
conn: Arc<Connection>,
|
||||
}
|
||||
|
||||
impl Tunnel for ConnectionForServer {
|
||||
fn stream(&self) -> Box<dyn DatagramStream> {
|
||||
Box::new(self.conn.server.recv_stream())
|
||||
}
|
||||
|
||||
fn sink(&self) -> Box<dyn DatagramSink> {
|
||||
Box::new(self.conn.client.send_sink())
|
||||
}
|
||||
|
||||
fn info(&self) -> Option<TunnelInfo> {
|
||||
Some(TunnelInfo {
|
||||
tunnel_type: "ring".to_owned(),
|
||||
local_addr: build_url_from_socket_addr(&self.conn.server.id.into(), "ring").into(),
|
||||
remote_addr: build_url_from_socket_addr(&self.conn.client.id.into(), "ring").into(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
struct ConnectionForClient {
|
||||
conn: Arc<Connection>,
|
||||
}
|
||||
|
||||
impl Tunnel for ConnectionForClient {
|
||||
fn stream(&self) -> Box<dyn DatagramStream> {
|
||||
Box::new(self.conn.client.recv_stream())
|
||||
}
|
||||
|
||||
fn sink(&self) -> Box<dyn DatagramSink> {
|
||||
Box::new(self.conn.server.send_sink())
|
||||
}
|
||||
|
||||
fn info(&self) -> Option<TunnelInfo> {
|
||||
Some(TunnelInfo {
|
||||
tunnel_type: "ring".to_owned(),
|
||||
local_addr: build_url_from_socket_addr(&self.conn.client.id.into(), "ring").into(),
|
||||
remote_addr: build_url_from_socket_addr(&self.conn.server.id.into(), "ring").into(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl RingTunnelListener {
|
||||
fn get_addr(&self) -> Result<uuid::Uuid, TunnelError> {
|
||||
check_scheme_and_get_socket_addr::<Uuid>(&self.listerner_addr, "ring")
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl TunnelListener for RingTunnelListener {
|
||||
async fn listen(&mut self) -> Result<(), TunnelError> {
|
||||
log::info!("listen new conn of key: {}", self.listerner_addr);
|
||||
CONNECTION_MAP
|
||||
.lock()
|
||||
.await
|
||||
.insert(self.get_addr()?, self.conn_sender.clone());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn accept(&mut self) -> Result<Box<dyn Tunnel>, TunnelError> {
|
||||
log::info!("waiting accept new conn of key: {}", self.listerner_addr);
|
||||
let my_addr = self.get_addr()?;
|
||||
if let Some(conn) = self.conn_receiver.recv().await {
|
||||
if conn.server.id == my_addr {
|
||||
log::info!("accept new conn of key: {}", self.listerner_addr);
|
||||
return Ok(Box::new(ConnectionForServer { conn }));
|
||||
} else {
|
||||
tracing::error!(?conn.server.id, ?my_addr, "got new conn with wrong id");
|
||||
return Err(TunnelError::CommonError(
|
||||
"accept got wrong ring server id".to_owned(),
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
return Err(TunnelError::CommonError("conn receiver stopped".to_owned()));
|
||||
}
|
||||
|
||||
fn local_url(&self) -> url::Url {
|
||||
self.listerner_addr.clone()
|
||||
}
|
||||
}
|
||||
|
||||
pub struct RingTunnelConnector {
|
||||
remote_addr: url::Url,
|
||||
}
|
||||
|
||||
impl RingTunnelConnector {
|
||||
pub fn new(remote_addr: url::Url) -> Self {
|
||||
RingTunnelConnector { remote_addr }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl TunnelConnector for RingTunnelConnector {
|
||||
async fn connect(&mut self) -> Result<Box<dyn Tunnel>, super::TunnelError> {
|
||||
let remote_addr = check_scheme_and_get_socket_addr::<Uuid>(&self.remote_addr, "ring")?;
|
||||
let entry = CONNECTION_MAP
|
||||
.lock()
|
||||
.await
|
||||
.get(&remote_addr)
|
||||
.unwrap()
|
||||
.clone();
|
||||
log::info!("connecting");
|
||||
let conn = Arc::new(Connection {
|
||||
client: RingTunnel::new(RING_TUNNEL_CAP),
|
||||
server: RingTunnel::new_with_id(remote_addr.clone(), RING_TUNNEL_CAP),
|
||||
});
|
||||
entry
|
||||
.send(conn.clone())
|
||||
.map_err(|_| TunnelError::CommonError("send conn to listner failed".to_owned()))?;
|
||||
Ok(Box::new(ConnectionForClient { conn }))
|
||||
}
|
||||
|
||||
fn remote_url(&self) -> url::Url {
|
||||
self.remote_addr.clone()
|
||||
}
|
||||
}
|
||||
|
||||
pub fn create_ring_tunnel_pair() -> (Box<dyn Tunnel>, Box<dyn Tunnel>) {
|
||||
let conn = Arc::new(Connection {
|
||||
client: RingTunnel::new(RING_TUNNEL_CAP),
|
||||
server: RingTunnel::new(RING_TUNNEL_CAP),
|
||||
});
|
||||
(
|
||||
Box::new(ConnectionForServer { conn: conn.clone() }),
|
||||
Box::new(ConnectionForClient { conn }),
|
||||
)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::tunnels::common::tests::{_tunnel_bench, _tunnel_pingpong};
|
||||
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn ring_pingpong() {
|
||||
let id: url::Url = format!("ring://{}", Uuid::new_v4()).parse().unwrap();
|
||||
let listener = RingTunnelListener::new(id.clone());
|
||||
let connector = RingTunnelConnector::new(id.clone());
|
||||
_tunnel_pingpong(listener, connector).await
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn ring_bench() {
|
||||
let id: url::Url = format!("ring://{}", Uuid::new_v4()).parse().unwrap();
|
||||
let listener = RingTunnelListener::new(id.clone());
|
||||
let connector = RingTunnelConnector::new(id);
|
||||
_tunnel_bench(listener, connector).await
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,95 @@
|
||||
use std::sync::atomic::{AtomicU32, AtomicU64, Ordering::Relaxed};
|
||||
|
||||
pub struct WindowLatency {
|
||||
latency_us_window: Vec<AtomicU32>,
|
||||
latency_us_window_index: AtomicU32,
|
||||
latency_us_window_size: u32,
|
||||
|
||||
sum: AtomicU32,
|
||||
count: AtomicU32,
|
||||
}
|
||||
|
||||
impl WindowLatency {
|
||||
pub fn new(window_size: u32) -> Self {
|
||||
Self {
|
||||
latency_us_window: (0..window_size).map(|_| AtomicU32::new(0)).collect(),
|
||||
latency_us_window_index: AtomicU32::new(0),
|
||||
latency_us_window_size: window_size,
|
||||
|
||||
sum: AtomicU32::new(0),
|
||||
count: AtomicU32::new(0),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn record_latency(&self, latency_us: u32) {
|
||||
let index = self.latency_us_window_index.fetch_add(1, Relaxed);
|
||||
if self.count.load(Relaxed) < self.latency_us_window_size {
|
||||
self.count.fetch_add(1, Relaxed);
|
||||
}
|
||||
|
||||
let index = index % self.latency_us_window_size;
|
||||
let old_lat = self.latency_us_window[index as usize].swap(latency_us, Relaxed);
|
||||
|
||||
if old_lat < latency_us {
|
||||
self.sum.fetch_add(latency_us - old_lat, Relaxed);
|
||||
} else {
|
||||
self.sum.fetch_sub(old_lat - latency_us, Relaxed);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_latency_us<T: From<u32> + std::ops::Div<Output = T>>(&self) -> T {
|
||||
let count = self.count.load(Relaxed);
|
||||
let sum = self.sum.load(Relaxed);
|
||||
if count == 0 {
|
||||
0.into()
|
||||
} else {
|
||||
(T::from(sum)) / T::from(count)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Throughput {
|
||||
tx_bytes: AtomicU64,
|
||||
rx_bytes: AtomicU64,
|
||||
|
||||
tx_packets: AtomicU64,
|
||||
rx_packets: AtomicU64,
|
||||
}
|
||||
|
||||
impl Throughput {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
tx_bytes: AtomicU64::new(0),
|
||||
rx_bytes: AtomicU64::new(0),
|
||||
|
||||
tx_packets: AtomicU64::new(0),
|
||||
rx_packets: AtomicU64::new(0),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn tx_bytes(&self) -> u64 {
|
||||
self.tx_bytes.load(Relaxed)
|
||||
}
|
||||
|
||||
pub fn rx_bytes(&self) -> u64 {
|
||||
self.rx_bytes.load(Relaxed)
|
||||
}
|
||||
|
||||
pub fn tx_packets(&self) -> u64 {
|
||||
self.tx_packets.load(Relaxed)
|
||||
}
|
||||
|
||||
pub fn rx_packets(&self) -> u64 {
|
||||
self.rx_packets.load(Relaxed)
|
||||
}
|
||||
|
||||
pub fn record_tx_bytes(&self, bytes: u64) {
|
||||
self.tx_bytes.fetch_add(bytes, Relaxed);
|
||||
self.tx_packets.fetch_add(1, Relaxed);
|
||||
}
|
||||
|
||||
pub fn record_rx_bytes(&self, bytes: u64) {
|
||||
self.rx_bytes.fetch_add(bytes, Relaxed);
|
||||
self.rx_packets.fetch_add(1, Relaxed);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,273 @@
|
||||
use std::net::SocketAddr;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use futures::{stream::FuturesUnordered, StreamExt};
|
||||
use tokio::net::{TcpListener, TcpSocket, TcpStream};
|
||||
use tokio_util::codec::{FramedRead, FramedWrite, LengthDelimitedCodec};
|
||||
|
||||
use crate::tunnels::common::setup_sokcet2;
|
||||
|
||||
use super::{
|
||||
check_scheme_and_get_socket_addr, common::FramedTunnel, Tunnel, TunnelInfo, TunnelListener,
|
||||
};
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct TcpTunnelListener {
|
||||
addr: url::Url,
|
||||
listener: Option<TcpListener>,
|
||||
}
|
||||
|
||||
impl TcpTunnelListener {
|
||||
pub fn new(addr: url::Url) -> Self {
|
||||
TcpTunnelListener {
|
||||
addr,
|
||||
listener: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl TunnelListener for TcpTunnelListener {
|
||||
async fn listen(&mut self) -> Result<(), super::TunnelError> {
|
||||
let addr = check_scheme_and_get_socket_addr::<SocketAddr>(&self.addr, "tcp")?;
|
||||
|
||||
let socket = if addr.is_ipv4() {
|
||||
TcpSocket::new_v4()?
|
||||
} else {
|
||||
TcpSocket::new_v6()?
|
||||
};
|
||||
|
||||
socket.set_reuseaddr(true)?;
|
||||
// #[cfg(all(unix, not(target_os = "solaris"), not(target_os = "illumos")))]
|
||||
// socket.set_reuseport(true)?;
|
||||
socket.bind(addr)?;
|
||||
|
||||
self.listener = Some(socket.listen(1024)?);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn accept(&mut self) -> Result<Box<dyn Tunnel>, super::TunnelError> {
|
||||
let listener = self.listener.as_ref().unwrap();
|
||||
let (stream, _) = listener.accept().await?;
|
||||
stream.set_nodelay(true).unwrap();
|
||||
let info = TunnelInfo {
|
||||
tunnel_type: "tcp".to_owned(),
|
||||
local_addr: self.local_url().into(),
|
||||
remote_addr: super::build_url_from_socket_addr(&stream.peer_addr()?.to_string(), "tcp")
|
||||
.into(),
|
||||
};
|
||||
|
||||
let (r, w) = tokio::io::split(stream);
|
||||
Ok(FramedTunnel::new_tunnel_with_info(
|
||||
FramedRead::new(r, LengthDelimitedCodec::new()),
|
||||
FramedWrite::new(w, LengthDelimitedCodec::new()),
|
||||
info,
|
||||
))
|
||||
}
|
||||
|
||||
fn local_url(&self) -> url::Url {
|
||||
self.addr.clone()
|
||||
}
|
||||
}
|
||||
|
||||
fn get_tunnel_with_tcp_stream(
|
||||
stream: TcpStream,
|
||||
remote_url: url::Url,
|
||||
) -> Result<Box<dyn Tunnel>, super::TunnelError> {
|
||||
stream.set_nodelay(true).unwrap();
|
||||
|
||||
let info = TunnelInfo {
|
||||
tunnel_type: "tcp".to_owned(),
|
||||
local_addr: super::build_url_from_socket_addr(&stream.local_addr()?.to_string(), "tcp")
|
||||
.into(),
|
||||
remote_addr: remote_url.into(),
|
||||
};
|
||||
|
||||
let (r, w) = tokio::io::split(stream);
|
||||
Ok(Box::new(FramedTunnel::new_tunnel_with_info(
|
||||
FramedRead::new(r, LengthDelimitedCodec::new()),
|
||||
FramedWrite::new(w, LengthDelimitedCodec::new()),
|
||||
info,
|
||||
)))
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct TcpTunnelConnector {
|
||||
addr: url::Url,
|
||||
|
||||
bind_addrs: Vec<SocketAddr>,
|
||||
}
|
||||
|
||||
impl TcpTunnelConnector {
|
||||
pub fn new(addr: url::Url) -> Self {
|
||||
TcpTunnelConnector {
|
||||
addr,
|
||||
bind_addrs: vec![],
|
||||
}
|
||||
}
|
||||
|
||||
async fn connect_with_default_bind(&mut self) -> Result<Box<dyn Tunnel>, super::TunnelError> {
|
||||
tracing::info!(addr = ?self.addr, "connect tcp start");
|
||||
let addr = check_scheme_and_get_socket_addr::<SocketAddr>(&self.addr, "tcp")?;
|
||||
let stream = TcpStream::connect(addr).await?;
|
||||
tracing::info!(addr = ?self.addr, "connect tcp succ");
|
||||
return get_tunnel_with_tcp_stream(stream, self.addr.clone().into());
|
||||
}
|
||||
|
||||
async fn connect_with_custom_bind(&mut self) -> Result<Box<dyn Tunnel>, super::TunnelError> {
|
||||
let mut futures = FuturesUnordered::new();
|
||||
let dst_addr = check_scheme_and_get_socket_addr::<SocketAddr>(&self.addr, "tcp")?;
|
||||
|
||||
for bind_addr in self.bind_addrs.iter() {
|
||||
tracing::info!(bind_addr = ?bind_addr, ?dst_addr, "bind addr");
|
||||
|
||||
let socket2_socket = socket2::Socket::new(
|
||||
socket2::Domain::for_address(dst_addr),
|
||||
socket2::Type::STREAM,
|
||||
Some(socket2::Protocol::TCP),
|
||||
)?;
|
||||
setup_sokcet2(&socket2_socket, bind_addr)?;
|
||||
|
||||
let socket = TcpSocket::from_std_stream(socket2_socket.into());
|
||||
futures.push(socket.connect(dst_addr.clone()));
|
||||
}
|
||||
|
||||
let Some(ret) = futures.next().await else {
|
||||
return Err(super::TunnelError::CommonError(
|
||||
"join connect futures failed".to_owned(),
|
||||
));
|
||||
};
|
||||
|
||||
return get_tunnel_with_tcp_stream(ret?, self.addr.clone().into());
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl super::TunnelConnector for TcpTunnelConnector {
|
||||
async fn connect(&mut self) -> Result<Box<dyn Tunnel>, super::TunnelError> {
|
||||
if self.bind_addrs.is_empty() {
|
||||
self.connect_with_default_bind().await
|
||||
} else {
|
||||
self.connect_with_custom_bind().await
|
||||
}
|
||||
}
|
||||
|
||||
fn remote_url(&self) -> url::Url {
|
||||
self.addr.clone()
|
||||
}
|
||||
fn set_bind_addrs(&mut self, addrs: Vec<SocketAddr>) {
|
||||
self.bind_addrs = addrs;
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use futures::SinkExt;
|
||||
|
||||
use crate::tunnels::{
|
||||
common::tests::{_tunnel_bench, _tunnel_pingpong},
|
||||
TunnelConnector,
|
||||
};
|
||||
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn tcp_pingpong() {
|
||||
let listener = TcpTunnelListener::new("tcp://0.0.0.0:11011".parse().unwrap());
|
||||
let connector = TcpTunnelConnector::new("tcp://127.0.0.1:11011".parse().unwrap());
|
||||
_tunnel_pingpong(listener, connector).await
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn tcp_bench() {
|
||||
let listener = TcpTunnelListener::new("tcp://0.0.0.0:11012".parse().unwrap());
|
||||
let connector = TcpTunnelConnector::new("tcp://127.0.0.1:11012".parse().unwrap());
|
||||
_tunnel_bench(listener, connector).await
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn tcp_bench_with_bind() {
|
||||
let listener = TcpTunnelListener::new("tcp://127.0.0.1:11013".parse().unwrap());
|
||||
let mut connector = TcpTunnelConnector::new("tcp://127.0.0.1:11013".parse().unwrap());
|
||||
connector.set_bind_addrs(vec!["127.0.0.1:0".parse().unwrap()]);
|
||||
_tunnel_pingpong(listener, connector).await
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[should_panic]
|
||||
async fn tcp_bench_with_bind_fail() {
|
||||
let listener = TcpTunnelListener::new("tcp://127.0.0.1:11014".parse().unwrap());
|
||||
let mut connector = TcpTunnelConnector::new("tcp://127.0.0.1:11014".parse().unwrap());
|
||||
connector.set_bind_addrs(vec!["10.0.0.1:0".parse().unwrap()]);
|
||||
_tunnel_pingpong(listener, connector).await
|
||||
}
|
||||
|
||||
// test slow send lock in framed tunnel
|
||||
#[tokio::test]
|
||||
async fn tcp_multiple_sender_and_slow_receiver() {
|
||||
// console_subscriber::init();
|
||||
let mut listener = TcpTunnelListener::new("tcp://127.0.0.1:11014".parse().unwrap());
|
||||
let mut connector = TcpTunnelConnector::new("tcp://127.0.0.1:11014".parse().unwrap());
|
||||
|
||||
listener.listen().await.unwrap();
|
||||
let t1 = tokio::spawn(async move {
|
||||
let t = listener.accept().await.unwrap();
|
||||
let mut stream = t.pin_stream();
|
||||
|
||||
let now = tokio::time::Instant::now();
|
||||
|
||||
while let Some(Ok(_)) = stream.next().await {
|
||||
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
|
||||
if now.elapsed().as_secs() > 5 {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
tracing::info!("t1 exit");
|
||||
});
|
||||
|
||||
let tunnel = connector.connect().await.unwrap();
|
||||
let mut sink1 = tunnel.pin_sink();
|
||||
let t2 = tokio::spawn(async move {
|
||||
for i in 0..1000000 {
|
||||
let a = sink1.send(b"hello".to_vec().into()).await;
|
||||
if a.is_err() {
|
||||
tracing::info!(?a, "t2 exit with err");
|
||||
break;
|
||||
}
|
||||
|
||||
if i % 5000 == 0 {
|
||||
tracing::info!(i, "send2 1000");
|
||||
}
|
||||
}
|
||||
|
||||
tracing::info!("t2 exit");
|
||||
});
|
||||
|
||||
let mut sink2 = tunnel.pin_sink();
|
||||
let t3 = tokio::spawn(async move {
|
||||
for i in 0..1000000 {
|
||||
let a = sink2.send(b"hello".to_vec().into()).await;
|
||||
if a.is_err() {
|
||||
tracing::info!(?a, "t3 exit with err");
|
||||
break;
|
||||
}
|
||||
|
||||
if i % 5000 == 0 {
|
||||
tracing::info!(i, "send2 1000");
|
||||
}
|
||||
}
|
||||
|
||||
tracing::info!("t3 exit");
|
||||
});
|
||||
|
||||
let t4 = tokio::spawn(async move {
|
||||
tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
|
||||
tracing::info!("closing");
|
||||
let close_ret = tunnel.pin_sink().close().await;
|
||||
tracing::info!("closed {:?}", close_ret);
|
||||
});
|
||||
|
||||
let _ = tokio::join!(t1, t2, t3, t4);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,279 @@
|
||||
use std::{
|
||||
sync::Arc,
|
||||
task::{Context, Poll},
|
||||
};
|
||||
|
||||
use crate::rpc::TunnelInfo;
|
||||
use futures::{Sink, SinkExt, Stream, StreamExt};
|
||||
|
||||
use self::stats::Throughput;
|
||||
|
||||
use super::*;
|
||||
use crate::tunnels::{DatagramSink, DatagramStream, SinkError, SinkItem, StreamItem, Tunnel};
|
||||
|
||||
pub trait TunnelFilter {
|
||||
fn before_send(&self, data: SinkItem) -> Option<Result<SinkItem, SinkError>> {
|
||||
Some(Ok(data))
|
||||
}
|
||||
fn after_received(&self, data: StreamItem) -> Option<Result<BytesMut, TunnelError>> {
|
||||
match data {
|
||||
Ok(v) => Some(Ok(v)),
|
||||
Err(e) => Some(Err(e)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct TunnelWithFilter<T, F> {
|
||||
inner: T,
|
||||
filter: Arc<F>,
|
||||
}
|
||||
|
||||
impl<T, F> Tunnel for TunnelWithFilter<T, F>
|
||||
where
|
||||
T: Tunnel + Send + Sync + 'static,
|
||||
F: TunnelFilter + Send + Sync + 'static,
|
||||
{
|
||||
fn sink(&self) -> Box<dyn DatagramSink> {
|
||||
struct SinkWrapper<F> {
|
||||
sink: Pin<Box<dyn DatagramSink>>,
|
||||
filter: Arc<F>,
|
||||
}
|
||||
impl<F> Sink<SinkItem> for SinkWrapper<F>
|
||||
where
|
||||
F: TunnelFilter + Send + Sync + 'static,
|
||||
{
|
||||
type Error = SinkError;
|
||||
|
||||
fn poll_ready(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<Result<(), Self::Error>> {
|
||||
self.get_mut().sink.poll_ready_unpin(cx)
|
||||
}
|
||||
|
||||
fn start_send(self: Pin<&mut Self>, item: SinkItem) -> Result<(), Self::Error> {
|
||||
let Some(item) = self.filter.before_send(item) else {
|
||||
return Ok(());
|
||||
};
|
||||
self.get_mut().sink.start_send_unpin(item?)
|
||||
}
|
||||
|
||||
fn poll_flush(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<Result<(), Self::Error>> {
|
||||
self.get_mut().sink.poll_flush_unpin(cx)
|
||||
}
|
||||
|
||||
fn poll_close(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<Result<(), Self::Error>> {
|
||||
self.get_mut().sink.poll_close_unpin(cx)
|
||||
}
|
||||
}
|
||||
|
||||
Box::new(SinkWrapper {
|
||||
sink: self.inner.pin_sink(),
|
||||
filter: self.filter.clone(),
|
||||
})
|
||||
}
|
||||
|
||||
fn stream(&self) -> Box<dyn DatagramStream> {
|
||||
struct StreamWrapper<F> {
|
||||
stream: Pin<Box<dyn DatagramStream>>,
|
||||
filter: Arc<F>,
|
||||
}
|
||||
impl<F> Stream for StreamWrapper<F>
|
||||
where
|
||||
F: TunnelFilter + Send + Sync + 'static,
|
||||
{
|
||||
type Item = StreamItem;
|
||||
|
||||
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
||||
let self_mut = self.get_mut();
|
||||
loop {
|
||||
match self_mut.stream.poll_next_unpin(cx) {
|
||||
Poll::Ready(Some(ret)) => {
|
||||
let Some(ret) = self_mut.filter.after_received(ret) else {
|
||||
continue;
|
||||
};
|
||||
return Poll::Ready(Some(ret));
|
||||
}
|
||||
Poll::Ready(None) => {
|
||||
return Poll::Ready(None);
|
||||
}
|
||||
Poll::Pending => {
|
||||
return Poll::Pending;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Box::new(StreamWrapper {
|
||||
stream: self.inner.pin_stream(),
|
||||
filter: self.filter.clone(),
|
||||
})
|
||||
}
|
||||
|
||||
fn info(&self) -> Option<TunnelInfo> {
|
||||
self.inner.info()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, F> TunnelWithFilter<T, F>
|
||||
where
|
||||
T: Tunnel + Send + Sync + 'static,
|
||||
F: TunnelFilter + Send + Sync + 'static,
|
||||
{
|
||||
pub fn new(inner: T, filter: Arc<F>) -> Self {
|
||||
Self { inner, filter }
|
||||
}
|
||||
}
|
||||
|
||||
pub struct PacketRecorderTunnelFilter {
|
||||
pub received: Arc<std::sync::Mutex<Vec<Bytes>>>,
|
||||
pub sent: Arc<std::sync::Mutex<Vec<Bytes>>>,
|
||||
}
|
||||
|
||||
impl TunnelFilter for PacketRecorderTunnelFilter {
|
||||
fn before_send(&self, data: SinkItem) -> Option<Result<SinkItem, SinkError>> {
|
||||
self.received.lock().unwrap().push(data.clone());
|
||||
Some(Ok(data))
|
||||
}
|
||||
|
||||
fn after_received(&self, data: StreamItem) -> Option<Result<BytesMut, TunnelError>> {
|
||||
match data {
|
||||
Ok(v) => {
|
||||
self.sent.lock().unwrap().push(v.clone().into());
|
||||
Some(Ok(v))
|
||||
}
|
||||
Err(e) => Some(Err(e)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl PacketRecorderTunnelFilter {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
received: Arc::new(std::sync::Mutex::new(Vec::new())),
|
||||
sent: Arc::new(std::sync::Mutex::new(Vec::new())),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct StatsRecorderTunnelFilter {
|
||||
throughput: Arc<Throughput>,
|
||||
}
|
||||
|
||||
impl TunnelFilter for StatsRecorderTunnelFilter {
|
||||
fn before_send(&self, data: SinkItem) -> Option<Result<SinkItem, SinkError>> {
|
||||
self.throughput.record_tx_bytes(data.len() as u64);
|
||||
Some(Ok(data))
|
||||
}
|
||||
|
||||
fn after_received(&self, data: StreamItem) -> Option<Result<BytesMut, TunnelError>> {
|
||||
match data {
|
||||
Ok(v) => {
|
||||
self.throughput.record_rx_bytes(v.len() as u64);
|
||||
Some(Ok(v))
|
||||
}
|
||||
Err(e) => Some(Err(e)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl StatsRecorderTunnelFilter {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
throughput: Arc::new(Throughput::new()),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_throughput(&self) -> Arc<Throughput> {
|
||||
self.throughput.clone()
|
||||
}
|
||||
}
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! define_tunnel_filter_chain {
|
||||
($type_name:ident $(, $field_name:ident = $filter_type:ty)+) => (
|
||||
pub struct $type_name {
|
||||
$($field_name: std::sync::Arc<$filter_type>,)+
|
||||
}
|
||||
|
||||
impl $type_name {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
$($field_name: std::sync::Arc::new(<$filter_type>::new()),)+
|
||||
}
|
||||
}
|
||||
|
||||
pub fn wrap_tunnel(&self, tunnel: impl Tunnel + 'static) -> impl Tunnel {
|
||||
$(
|
||||
let tunnel = crate::tunnels::tunnel_filter::TunnelWithFilter::new(tunnel, self.$field_name.clone());
|
||||
)+
|
||||
tunnel
|
||||
}
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub mod tests {
|
||||
use std::sync::atomic::{AtomicU32, Ordering};
|
||||
|
||||
use super::*;
|
||||
use crate::tunnels::ring_tunnel::RingTunnel;
|
||||
|
||||
pub struct DropSendTunnelFilter {
|
||||
start: AtomicU32,
|
||||
end: AtomicU32,
|
||||
cur: AtomicU32,
|
||||
}
|
||||
|
||||
impl TunnelFilter for DropSendTunnelFilter {
|
||||
fn before_send(&self, data: SinkItem) -> Option<Result<SinkItem, SinkError>> {
|
||||
self.cur.fetch_add(1, Ordering::SeqCst);
|
||||
if self.cur.load(Ordering::SeqCst) >= self.start.load(Ordering::SeqCst)
|
||||
&& self.cur.load(std::sync::atomic::Ordering::SeqCst)
|
||||
< self.end.load(Ordering::SeqCst)
|
||||
{
|
||||
tracing::trace!("drop packet: {:?}", data);
|
||||
return None;
|
||||
}
|
||||
Some(Ok(data))
|
||||
}
|
||||
}
|
||||
|
||||
impl DropSendTunnelFilter {
|
||||
pub fn new(start: u32, end: u32) -> Self {
|
||||
Self {
|
||||
start: AtomicU32::new(start),
|
||||
end: AtomicU32::new(end),
|
||||
cur: AtomicU32::new(0),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_nested_filter() {
|
||||
define_tunnel_filter_chain!(
|
||||
Filter,
|
||||
a = PacketRecorderTunnelFilter,
|
||||
b = PacketRecorderTunnelFilter,
|
||||
c = PacketRecorderTunnelFilter
|
||||
);
|
||||
|
||||
let filter = Filter::new();
|
||||
let tunnel = filter.wrap_tunnel(RingTunnel::new(1));
|
||||
|
||||
let mut s = tunnel.pin_sink();
|
||||
s.send(Bytes::from("hello")).await.unwrap();
|
||||
|
||||
assert_eq!(1, filter.a.received.lock().unwrap().len());
|
||||
assert_eq!(1, filter.b.received.lock().unwrap().len());
|
||||
assert_eq!(1, filter.c.received.lock().unwrap().len());
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,772 @@
|
||||
use std::{fmt::Debug, pin::Pin, sync::Arc};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use dashmap::DashMap;
|
||||
use futures::{stream::FuturesUnordered, SinkExt, StreamExt};
|
||||
use rkyv::{Archive, Deserialize, Serialize};
|
||||
use std::net::SocketAddr;
|
||||
use tokio::{net::UdpSocket, sync::Mutex, task::JoinSet};
|
||||
use tokio_util::{
|
||||
bytes::{Buf, Bytes, BytesMut},
|
||||
udp::UdpFramed,
|
||||
};
|
||||
use tracing::Instrument;
|
||||
|
||||
use crate::{
|
||||
common::{
|
||||
join_joinset_background,
|
||||
rkyv_util::{self, encode_to_bytes, vec_to_string},
|
||||
},
|
||||
rpc::TunnelInfo,
|
||||
tunnels::{build_url_from_socket_addr, close_tunnel, TunnelConnCounter, TunnelConnector},
|
||||
};
|
||||
|
||||
use super::{
|
||||
codec::BytesCodec,
|
||||
common::{setup_sokcet2, setup_sokcet2_ext, FramedTunnel, TunnelWithCustomInfo},
|
||||
ring_tunnel::create_ring_tunnel_pair,
|
||||
DatagramSink, DatagramStream, Tunnel, TunnelListener, TunnelUrl,
|
||||
};
|
||||
|
||||
pub const UDP_DATA_MTU: usize = 65000;
|
||||
|
||||
#[derive(Archive, Deserialize, Serialize)]
|
||||
#[archive(compare(PartialEq), check_bytes)]
|
||||
// Derives can be passed through to the generated type:
|
||||
pub enum UdpPacketPayload {
|
||||
Syn,
|
||||
Sack,
|
||||
HolePunch(String),
|
||||
Data(String),
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for UdpPacketPayload {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
let mut tmp = f.debug_struct("ArchivedUdpPacketPayload");
|
||||
match self {
|
||||
UdpPacketPayload::Syn => tmp.field("Syn", &"").finish(),
|
||||
UdpPacketPayload::Sack => tmp.field("Sack", &"").finish(),
|
||||
UdpPacketPayload::HolePunch(s) => tmp.field("HolePunch", &s.as_bytes()).finish(),
|
||||
UdpPacketPayload::Data(s) => tmp.field("Data", &s.as_bytes()).finish(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Archive, Deserialize, Serialize, Debug)]
|
||||
#[archive(compare(PartialEq), check_bytes)]
|
||||
#[archive_attr(derive(Debug))]
|
||||
pub struct UdpPacket {
|
||||
pub conn_id: u32,
|
||||
pub magic: u32,
|
||||
pub payload: UdpPacketPayload,
|
||||
}
|
||||
|
||||
const UDP_PACKET_MAGIC: u32 = 0x19941126;
|
||||
|
||||
impl std::fmt::Debug for ArchivedUdpPacketPayload {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
let mut tmp = f.debug_struct("ArchivedUdpPacketPayload");
|
||||
match self {
|
||||
ArchivedUdpPacketPayload::Syn => tmp.field("Syn", &"").finish(),
|
||||
ArchivedUdpPacketPayload::Sack => tmp.field("Sack", &"").finish(),
|
||||
ArchivedUdpPacketPayload::HolePunch(s) => {
|
||||
tmp.field("HolePunch", &s.as_bytes()).finish()
|
||||
}
|
||||
ArchivedUdpPacketPayload::Data(s) => tmp.field("Data", &s.as_bytes()).finish(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl UdpPacket {
|
||||
pub fn new_data_packet(conn_id: u32, data: Vec<u8>) -> Self {
|
||||
Self {
|
||||
conn_id,
|
||||
magic: UDP_PACKET_MAGIC,
|
||||
payload: UdpPacketPayload::Data(vec_to_string(data)),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn new_hole_punch_packet(data: Vec<u8>) -> Self {
|
||||
Self {
|
||||
conn_id: 0,
|
||||
magic: UDP_PACKET_MAGIC,
|
||||
payload: UdpPacketPayload::HolePunch(vec_to_string(data)),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn new_syn_packet(conn_id: u32) -> Self {
|
||||
Self {
|
||||
conn_id,
|
||||
magic: UDP_PACKET_MAGIC,
|
||||
payload: UdpPacketPayload::Syn,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn new_sack_packet(conn_id: u32) -> Self {
|
||||
Self {
|
||||
conn_id,
|
||||
magic: UDP_PACKET_MAGIC,
|
||||
payload: UdpPacketPayload::Sack,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn try_get_data_payload(mut buf: BytesMut, conn_id: u32) -> Option<BytesMut> {
|
||||
let Ok(udp_packet) = rkyv_util::decode_from_bytes::<UdpPacket>(&buf) else {
|
||||
tracing::warn!(?buf, "udp decode error");
|
||||
return None;
|
||||
};
|
||||
|
||||
if udp_packet.conn_id != conn_id.clone() {
|
||||
tracing::warn!(?udp_packet, ?conn_id, "udp conn id not match");
|
||||
return None;
|
||||
}
|
||||
|
||||
if udp_packet.magic != UDP_PACKET_MAGIC {
|
||||
tracing::trace!(?udp_packet, "udp magic not match");
|
||||
return None;
|
||||
}
|
||||
|
||||
let ArchivedUdpPacketPayload::Data(payload) = &udp_packet.payload else {
|
||||
tracing::warn!(?udp_packet, "udp payload not data");
|
||||
return None;
|
||||
};
|
||||
|
||||
let offset = payload.as_ptr() as usize - buf.as_ptr() as usize;
|
||||
let len = payload.len();
|
||||
if offset + len > buf.len() {
|
||||
tracing::warn!(?offset, ?len, ?buf, "udp payload data out of range");
|
||||
return None;
|
||||
}
|
||||
|
||||
buf.advance(offset);
|
||||
buf.truncate(len);
|
||||
tracing::trace!(?offset, ?len, ?buf, "udp payload data");
|
||||
|
||||
Some(buf)
|
||||
}
|
||||
|
||||
fn get_tunnel_from_socket(
|
||||
socket: Arc<UdpSocket>,
|
||||
addr: SocketAddr,
|
||||
conn_id: u32,
|
||||
) -> Box<dyn super::Tunnel> {
|
||||
let udp = UdpFramed::new(socket.clone(), BytesCodec::new(UDP_DATA_MTU));
|
||||
let (sink, stream) = udp.split();
|
||||
|
||||
let recv_addr = addr;
|
||||
let stream = stream.filter_map(move |v| async move {
|
||||
tracing::trace!(?v, "udp stream recv something");
|
||||
if v.is_err() {
|
||||
tracing::warn!(?v, "udp stream error");
|
||||
return Some(Err(super::TunnelError::CommonError(
|
||||
"udp stream error".to_owned(),
|
||||
)));
|
||||
}
|
||||
|
||||
let (buf, addr) = v.unwrap();
|
||||
if recv_addr != addr {
|
||||
tracing::warn!(?addr, ?recv_addr, "udp recv addr not match");
|
||||
return None;
|
||||
}
|
||||
|
||||
Some(Ok(try_get_data_payload(buf, conn_id.clone())?))
|
||||
});
|
||||
let stream = Box::pin(stream);
|
||||
|
||||
let sender_addr = addr;
|
||||
let sink = Box::pin(sink.with(move |v: Bytes| async move {
|
||||
if false {
|
||||
return Err(super::TunnelError::CommonError("udp sink error".to_owned()));
|
||||
}
|
||||
|
||||
// TODO: two copy here, how to avoid?
|
||||
let udp_packet = UdpPacket::new_data_packet(conn_id, v.to_vec());
|
||||
let v = encode_to_bytes::<_, UDP_DATA_MTU>(&udp_packet);
|
||||
tracing::trace!(?udp_packet, ?v, "udp send packet");
|
||||
|
||||
Ok((v, sender_addr))
|
||||
}));
|
||||
|
||||
FramedTunnel::new_tunnel_with_info(
|
||||
stream,
|
||||
sink,
|
||||
// TODO: this remote addr is not a url
|
||||
super::TunnelInfo {
|
||||
tunnel_type: "udp".to_owned(),
|
||||
local_addr: super::build_url_from_socket_addr(
|
||||
&socket.local_addr().unwrap().to_string(),
|
||||
"udp",
|
||||
)
|
||||
.into(),
|
||||
remote_addr: super::build_url_from_socket_addr(&addr.to_string(), "udp").into(),
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
pub(crate) struct StreamSinkPair(
|
||||
pub Pin<Box<dyn DatagramStream>>,
|
||||
pub Pin<Box<dyn DatagramSink>>,
|
||||
pub u32,
|
||||
);
|
||||
pub(crate) type ArcStreamSinkPair = Arc<Mutex<StreamSinkPair>>;
|
||||
|
||||
pub struct UdpTunnelListener {
|
||||
addr: url::Url,
|
||||
socket: Option<Arc<UdpSocket>>,
|
||||
|
||||
sock_map: Arc<DashMap<SocketAddr, ArcStreamSinkPair>>,
|
||||
forward_tasks: Arc<std::sync::Mutex<JoinSet<()>>>,
|
||||
|
||||
conn_recv: tokio::sync::mpsc::Receiver<Box<dyn Tunnel>>,
|
||||
conn_send: Option<tokio::sync::mpsc::Sender<Box<dyn Tunnel>>>,
|
||||
}
|
||||
|
||||
impl UdpTunnelListener {
|
||||
pub fn new(addr: url::Url) -> Self {
|
||||
let (conn_send, conn_recv) = tokio::sync::mpsc::channel(100);
|
||||
Self {
|
||||
addr,
|
||||
socket: None,
|
||||
sock_map: Arc::new(DashMap::new()),
|
||||
forward_tasks: Arc::new(std::sync::Mutex::new(JoinSet::new())),
|
||||
conn_recv,
|
||||
conn_send: Some(conn_send),
|
||||
}
|
||||
}
|
||||
|
||||
async fn try_forward_packet(
|
||||
sock_map: &DashMap<SocketAddr, ArcStreamSinkPair>,
|
||||
buf: BytesMut,
|
||||
addr: SocketAddr,
|
||||
) -> Result<(), super::TunnelError> {
|
||||
let entry = sock_map.get_mut(&addr);
|
||||
if entry.is_none() {
|
||||
log::warn!("udp forward packet: {:?}, {:?}, no entry", addr, buf);
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
log::trace!("udp forward packet: {:?}, {:?}", addr, buf);
|
||||
let entry = entry.unwrap();
|
||||
let pair = entry.value().clone();
|
||||
drop(entry);
|
||||
|
||||
let Some(buf) = try_get_data_payload(buf, pair.lock().await.2) else {
|
||||
return Ok(());
|
||||
};
|
||||
pair.lock().await.1.send(buf.freeze()).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn handle_connect(
|
||||
socket: Arc<UdpSocket>,
|
||||
addr: SocketAddr,
|
||||
forward_tasks: Arc<std::sync::Mutex<JoinSet<()>>>,
|
||||
sock_map: Arc<DashMap<SocketAddr, ArcStreamSinkPair>>,
|
||||
local_url: url::Url,
|
||||
conn_id: u32,
|
||||
) -> Result<Box<dyn Tunnel>, super::TunnelError> {
|
||||
tracing::info!(?conn_id, ?addr, "udp connection accept handling",);
|
||||
|
||||
let udp_packet = UdpPacket::new_sack_packet(conn_id);
|
||||
let sack_buf = encode_to_bytes::<_, UDP_DATA_MTU>(&udp_packet);
|
||||
socket.send_to(&sack_buf, addr).await?;
|
||||
|
||||
let (ctunnel, stunnel) = create_ring_tunnel_pair();
|
||||
let udp_tunnel = get_tunnel_from_socket(socket.clone(), addr, conn_id);
|
||||
let ss_pair = StreamSinkPair(ctunnel.pin_stream(), ctunnel.pin_sink(), conn_id);
|
||||
let addr_copy = addr.clone();
|
||||
sock_map.insert(addr, Arc::new(Mutex::new(ss_pair)));
|
||||
let ctunnel_stream = ctunnel.pin_stream();
|
||||
forward_tasks.lock().unwrap().spawn(async move {
|
||||
let ret = ctunnel_stream
|
||||
.map(|v| {
|
||||
tracing::trace!(?v, "udp stream recv something in forward task");
|
||||
if v.is_err() {
|
||||
return Err(super::TunnelError::CommonError(
|
||||
"udp stream error".to_owned(),
|
||||
));
|
||||
}
|
||||
Ok(v.unwrap().freeze())
|
||||
})
|
||||
.forward(udp_tunnel.pin_sink())
|
||||
.await;
|
||||
if let None = sock_map.remove(&addr_copy) {
|
||||
log::warn!("udp forward packet: {:?}, no entry", addr_copy);
|
||||
}
|
||||
close_tunnel(&ctunnel).await.unwrap();
|
||||
log::warn!("udp connection forward done: {:?}, {:?}", addr_copy, ret);
|
||||
});
|
||||
|
||||
Ok(Box::new(TunnelWithCustomInfo::new(
|
||||
stunnel,
|
||||
TunnelInfo {
|
||||
tunnel_type: "udp".to_owned(),
|
||||
local_addr: local_url.into(),
|
||||
remote_addr: build_url_from_socket_addr(&addr.to_string(), "udp").into(),
|
||||
},
|
||||
)))
|
||||
}
|
||||
|
||||
pub fn get_socket(&self) -> Option<Arc<UdpSocket>> {
|
||||
self.socket.clone()
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl TunnelListener for UdpTunnelListener {
|
||||
async fn listen(&mut self) -> Result<(), super::TunnelError> {
|
||||
let addr = super::check_scheme_and_get_socket_addr::<SocketAddr>(&self.addr, "udp")?;
|
||||
|
||||
let socket2_socket = socket2::Socket::new(
|
||||
socket2::Domain::for_address(addr),
|
||||
socket2::Type::DGRAM,
|
||||
Some(socket2::Protocol::UDP),
|
||||
)?;
|
||||
|
||||
let tunnel_url: TunnelUrl = self.addr.clone().into();
|
||||
if let Some(bind_dev) = tunnel_url.bind_dev() {
|
||||
setup_sokcet2_ext(&socket2_socket, &addr, Some(bind_dev))?;
|
||||
} else {
|
||||
setup_sokcet2(&socket2_socket, &addr)?;
|
||||
}
|
||||
|
||||
self.socket = Some(Arc::new(UdpSocket::from_std(socket2_socket.into())?));
|
||||
|
||||
let socket = self.socket.as_ref().unwrap().clone();
|
||||
let forward_tasks = self.forward_tasks.clone();
|
||||
let sock_map = self.sock_map.clone();
|
||||
let conn_send = self.conn_send.take().unwrap();
|
||||
let local_url = self.local_url().clone();
|
||||
self.forward_tasks.lock().unwrap().spawn(
|
||||
async move {
|
||||
loop {
|
||||
let mut buf = BytesMut::new();
|
||||
buf.resize(UDP_DATA_MTU, 0);
|
||||
let (_size, addr) = socket.recv_from(&mut buf).await.unwrap();
|
||||
let _ = buf.split_off(_size);
|
||||
log::trace!(
|
||||
"udp recv packet: {:?}, buf: {:?}, size: {}",
|
||||
addr,
|
||||
buf,
|
||||
_size
|
||||
);
|
||||
|
||||
let Ok(udp_packet) = rkyv_util::decode_from_bytes::<UdpPacket>(&buf) else {
|
||||
tracing::warn!(?buf, "udp decode error in forward task");
|
||||
continue;
|
||||
};
|
||||
|
||||
if udp_packet.magic != UDP_PACKET_MAGIC {
|
||||
tracing::trace!(?udp_packet, "udp magic not match");
|
||||
continue;
|
||||
}
|
||||
|
||||
if matches!(udp_packet.payload, ArchivedUdpPacketPayload::Syn) {
|
||||
let Ok(conn) = Self::handle_connect(
|
||||
socket.clone(),
|
||||
addr,
|
||||
forward_tasks.clone(),
|
||||
sock_map.clone(),
|
||||
local_url.clone(),
|
||||
udp_packet.conn_id.into(),
|
||||
)
|
||||
.await
|
||||
else {
|
||||
tracing::error!(?addr, "udp handle connect error");
|
||||
continue;
|
||||
};
|
||||
if let Err(e) = conn_send.send(conn).await {
|
||||
tracing::warn!(?e, "udp send conn to accept channel error");
|
||||
}
|
||||
} else {
|
||||
Self::try_forward_packet(sock_map.as_ref(), buf, addr)
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
}
|
||||
}
|
||||
.instrument(tracing::info_span!("udp forward task", ?self.socket)),
|
||||
);
|
||||
|
||||
join_joinset_background(self.forward_tasks.clone(), "UdpTunnelListener".to_owned());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn accept(&mut self) -> Result<Box<dyn super::Tunnel>, super::TunnelError> {
|
||||
log::info!("start udp accept: {:?}", self.addr);
|
||||
while let Some(conn) = self.conn_recv.recv().await {
|
||||
return Ok(conn);
|
||||
}
|
||||
return Err(super::TunnelError::CommonError(
|
||||
"udp accept error".to_owned(),
|
||||
));
|
||||
}
|
||||
|
||||
fn local_url(&self) -> url::Url {
|
||||
self.addr.clone()
|
||||
}
|
||||
|
||||
fn get_conn_counter(&self) -> Arc<Box<dyn TunnelConnCounter>> {
|
||||
struct UdpTunnelConnCounter {
|
||||
sock_map: Arc<DashMap<SocketAddr, ArcStreamSinkPair>>,
|
||||
}
|
||||
|
||||
impl Debug for UdpTunnelConnCounter {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("UdpTunnelConnCounter")
|
||||
.field("sock_map_len", &self.sock_map.len())
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl TunnelConnCounter for UdpTunnelConnCounter {
|
||||
fn get(&self) -> u32 {
|
||||
self.sock_map.len() as u32
|
||||
}
|
||||
}
|
||||
|
||||
Arc::new(Box::new(UdpTunnelConnCounter {
|
||||
sock_map: self.sock_map.clone(),
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
pub struct UdpTunnelConnector {
|
||||
addr: url::Url,
|
||||
bind_addrs: Vec<SocketAddr>,
|
||||
}
|
||||
|
||||
impl UdpTunnelConnector {
|
||||
pub fn new(addr: url::Url) -> Self {
|
||||
Self {
|
||||
addr,
|
||||
bind_addrs: vec![],
|
||||
}
|
||||
}
|
||||
|
||||
async fn wait_sack(
|
||||
socket: &UdpSocket,
|
||||
addr: SocketAddr,
|
||||
conn_id: u32,
|
||||
) -> Result<(), super::TunnelError> {
|
||||
let mut buf = BytesMut::new();
|
||||
buf.resize(128, 0);
|
||||
|
||||
let (usize, recv_addr) = tokio::time::timeout(
|
||||
tokio::time::Duration::from_secs(3),
|
||||
socket.recv_from(&mut buf),
|
||||
)
|
||||
.await??;
|
||||
|
||||
if recv_addr != addr {
|
||||
return Err(super::TunnelError::ConnectError(format!(
|
||||
"udp connect error, unexpected sack addr: {:?}, {:?}",
|
||||
recv_addr, addr
|
||||
)));
|
||||
}
|
||||
|
||||
let _ = buf.split_off(usize);
|
||||
|
||||
let Ok(udp_packet) = rkyv_util::decode_from_bytes::<UdpPacket>(&buf) else {
|
||||
tracing::warn!(?buf, "udp decode error in wait sack");
|
||||
return Err(super::TunnelError::ConnectError(format!(
|
||||
"udp connect error, decode error. buf: {:?}",
|
||||
buf
|
||||
)));
|
||||
};
|
||||
|
||||
if udp_packet.magic != UDP_PACKET_MAGIC {
|
||||
tracing::trace!(?udp_packet, "udp magic not match");
|
||||
return Err(super::TunnelError::ConnectError(format!(
|
||||
"udp connect error, magic not match. magic: {:?}",
|
||||
udp_packet.magic
|
||||
)));
|
||||
}
|
||||
|
||||
if conn_id != udp_packet.conn_id {
|
||||
return Err(super::TunnelError::ConnectError(format!(
|
||||
"udp connect error, conn id not match. conn_id: {:?}, {:?}",
|
||||
conn_id, udp_packet.conn_id
|
||||
)));
|
||||
}
|
||||
|
||||
if !matches!(udp_packet.payload, ArchivedUdpPacketPayload::Sack) {
|
||||
return Err(super::TunnelError::ConnectError(format!(
|
||||
"udp connect error, unexpected payload. payload: {:?}",
|
||||
udp_packet.payload
|
||||
)));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn wait_sack_loop(
|
||||
socket: &UdpSocket,
|
||||
addr: SocketAddr,
|
||||
conn_id: u32,
|
||||
) -> Result<(), super::TunnelError> {
|
||||
while let Err(err) = Self::wait_sack(socket, addr, conn_id).await {
|
||||
tracing::warn!(?err, "udp wait sack error");
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn try_connect_with_socket(
|
||||
&self,
|
||||
socket: UdpSocket,
|
||||
) -> Result<Box<dyn super::Tunnel>, super::TunnelError> {
|
||||
let addr = super::check_scheme_and_get_socket_addr::<SocketAddr>(&self.addr, "udp")?;
|
||||
log::warn!("udp connect: {:?}", self.addr);
|
||||
|
||||
#[cfg(target_os = "windows")]
|
||||
crate::arch::windows::disable_connection_reset(&socket)?;
|
||||
|
||||
// send syn
|
||||
let conn_id = rand::random();
|
||||
let udp_packet = UdpPacket::new_syn_packet(conn_id);
|
||||
let b = encode_to_bytes::<_, UDP_DATA_MTU>(&udp_packet);
|
||||
let ret = socket.send_to(&b, &addr).await?;
|
||||
tracing::warn!(?udp_packet, ?ret, "udp send syn");
|
||||
|
||||
// wait sack
|
||||
tokio::time::timeout(
|
||||
tokio::time::Duration::from_secs(3),
|
||||
Self::wait_sack_loop(&socket, addr, conn_id),
|
||||
)
|
||||
.await??;
|
||||
|
||||
// sack done
|
||||
let local_addr = socket.local_addr().unwrap().to_string();
|
||||
Ok(Box::new(TunnelWithCustomInfo::new(
|
||||
get_tunnel_from_socket(Arc::new(socket), addr, conn_id),
|
||||
TunnelInfo {
|
||||
tunnel_type: "udp".to_owned(),
|
||||
local_addr: super::build_url_from_socket_addr(&local_addr, "udp").into(),
|
||||
remote_addr: self.remote_url().into(),
|
||||
},
|
||||
)))
|
||||
}
|
||||
|
||||
async fn connect_with_default_bind(&mut self) -> Result<Box<dyn Tunnel>, super::TunnelError> {
|
||||
let socket = UdpSocket::bind("0.0.0.0:0").await?;
|
||||
return self.try_connect_with_socket(socket).await;
|
||||
}
|
||||
|
||||
async fn connect_with_custom_bind(&mut self) -> Result<Box<dyn Tunnel>, super::TunnelError> {
|
||||
let mut futures = FuturesUnordered::new();
|
||||
|
||||
for bind_addr in self.bind_addrs.iter() {
|
||||
let socket2_socket = socket2::Socket::new(
|
||||
socket2::Domain::for_address(*bind_addr),
|
||||
socket2::Type::DGRAM,
|
||||
Some(socket2::Protocol::UDP),
|
||||
)?;
|
||||
setup_sokcet2(&socket2_socket, &bind_addr)?;
|
||||
let socket = UdpSocket::from_std(socket2_socket.into())?;
|
||||
futures.push(self.try_connect_with_socket(socket));
|
||||
}
|
||||
|
||||
let Some(ret) = futures.next().await else {
|
||||
return Err(super::TunnelError::CommonError(
|
||||
"join connect futures failed".to_owned(),
|
||||
));
|
||||
};
|
||||
|
||||
return ret;
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl super::TunnelConnector for UdpTunnelConnector {
|
||||
async fn connect(&mut self) -> Result<Box<dyn super::Tunnel>, super::TunnelError> {
|
||||
if self.bind_addrs.is_empty() {
|
||||
self.connect_with_default_bind().await
|
||||
} else {
|
||||
self.connect_with_custom_bind().await
|
||||
}
|
||||
}
|
||||
|
||||
fn remote_url(&self) -> url::Url {
|
||||
self.addr.clone()
|
||||
}
|
||||
|
||||
fn set_bind_addrs(&mut self, addrs: Vec<SocketAddr>) {
|
||||
self.bind_addrs = addrs;
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::time::Duration;
|
||||
|
||||
use rand::Rng;
|
||||
use tokio::time::timeout;
|
||||
|
||||
use crate::{
|
||||
common::global_ctx::tests::get_mock_global_ctx,
|
||||
tunnels::{
|
||||
check_scheme_and_get_socket_addr,
|
||||
common::{
|
||||
get_interface_name_by_ip, setup_sokcet2_ext,
|
||||
tests::{_tunnel_bench, _tunnel_echo_server, _tunnel_pingpong},
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn udp_pingpong() {
|
||||
let listener = UdpTunnelListener::new("udp://0.0.0.0:5556".parse().unwrap());
|
||||
let connector = UdpTunnelConnector::new("udp://127.0.0.1:5556".parse().unwrap());
|
||||
_tunnel_pingpong(listener, connector).await
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn udp_bench() {
|
||||
let listener = UdpTunnelListener::new("udp://0.0.0.0:5555".parse().unwrap());
|
||||
let connector = UdpTunnelConnector::new("udp://127.0.0.1:5555".parse().unwrap());
|
||||
_tunnel_bench(listener, connector).await
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn udp_bench_with_bind() {
|
||||
let listener = UdpTunnelListener::new("udp://127.0.0.1:5554".parse().unwrap());
|
||||
let mut connector = UdpTunnelConnector::new("udp://127.0.0.1:5554".parse().unwrap());
|
||||
connector.set_bind_addrs(vec!["127.0.0.1:0".parse().unwrap()]);
|
||||
_tunnel_pingpong(listener, connector).await
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[should_panic]
|
||||
async fn udp_bench_with_bind_fail() {
|
||||
let listener = UdpTunnelListener::new("udp://127.0.0.1:5553".parse().unwrap());
|
||||
let mut connector = UdpTunnelConnector::new("udp://127.0.0.1:5553".parse().unwrap());
|
||||
connector.set_bind_addrs(vec!["10.0.0.1:0".parse().unwrap()]);
|
||||
_tunnel_pingpong(listener, connector).await
|
||||
}
|
||||
|
||||
async fn send_random_data_to_socket(remote_url: url::Url) {
|
||||
let socket = UdpSocket::bind("0.0.0.0:0").await.unwrap();
|
||||
socket
|
||||
.connect(format!(
|
||||
"{}:{}",
|
||||
remote_url.host().unwrap(),
|
||||
remote_url.port().unwrap()
|
||||
))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// get a random 100-len buf
|
||||
loop {
|
||||
let mut buf = vec![0u8; 100];
|
||||
rand::thread_rng().fill(&mut buf[..]);
|
||||
socket.send(&buf).await.unwrap();
|
||||
tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn udp_multiple_conns() {
|
||||
let mut listener = UdpTunnelListener::new("udp://0.0.0.0:5557".parse().unwrap());
|
||||
listener.listen().await.unwrap();
|
||||
|
||||
let _lis = tokio::spawn(async move {
|
||||
loop {
|
||||
let ret = listener.accept().await.unwrap();
|
||||
assert_eq!(
|
||||
ret.info().unwrap().local_addr,
|
||||
listener.local_url().to_string()
|
||||
);
|
||||
tokio::spawn(async move { _tunnel_echo_server(ret, false).await });
|
||||
}
|
||||
});
|
||||
|
||||
let mut connector1 = UdpTunnelConnector::new("udp://127.0.0.1:5557".parse().unwrap());
|
||||
let mut connector2 = UdpTunnelConnector::new("udp://127.0.0.1:5557".parse().unwrap());
|
||||
|
||||
let t1 = connector1.connect().await.unwrap();
|
||||
let t2 = connector2.connect().await.unwrap();
|
||||
|
||||
tokio::spawn(timeout(
|
||||
Duration::from_secs(2),
|
||||
send_random_data_to_socket(t1.info().unwrap().local_addr.parse().unwrap()),
|
||||
));
|
||||
tokio::spawn(timeout(
|
||||
Duration::from_secs(2),
|
||||
send_random_data_to_socket(t1.info().unwrap().remote_addr.parse().unwrap()),
|
||||
));
|
||||
tokio::spawn(timeout(
|
||||
Duration::from_secs(2),
|
||||
send_random_data_to_socket(t2.info().unwrap().remote_addr.parse().unwrap()),
|
||||
));
|
||||
|
||||
let sender1 = tokio::spawn(async move {
|
||||
let mut sink = t1.pin_sink();
|
||||
let mut stream = t1.pin_stream();
|
||||
|
||||
for i in 0..10 {
|
||||
sink.send(Bytes::from("hello1")).await.unwrap();
|
||||
let recv = stream.next().await.unwrap().unwrap();
|
||||
println!("t1 recv: {:?}, {:?}", recv, i);
|
||||
assert_eq!(recv, Bytes::from("hello1"));
|
||||
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
|
||||
}
|
||||
});
|
||||
|
||||
let sender2 = tokio::spawn(async move {
|
||||
let mut sink = t2.pin_sink();
|
||||
let mut stream = t2.pin_stream();
|
||||
|
||||
for i in 0..10 {
|
||||
sink.send(Bytes::from("hello2")).await.unwrap();
|
||||
let recv = stream.next().await.unwrap().unwrap();
|
||||
println!("t2 recv: {:?}, {:?}", recv, i);
|
||||
assert_eq!(recv, Bytes::from("hello2"));
|
||||
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
|
||||
}
|
||||
});
|
||||
|
||||
let _ = tokio::join!(sender1, sender2);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn udp_packet_print() {
|
||||
let udp_packet = UdpPacket::new_data_packet(1, vec![1, 2, 3, 4, 5]);
|
||||
let b = encode_to_bytes::<_, UDP_DATA_MTU>(&udp_packet);
|
||||
let a_udp_packet = rkyv_util::decode_from_bytes::<UdpPacket>(&b).unwrap();
|
||||
println!("{:?}, {:?}", udp_packet, a_udp_packet);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn bind_multi_ip_to_same_dev() {
|
||||
let global_ctx = get_mock_global_ctx();
|
||||
let ips = global_ctx
|
||||
.get_ip_collector()
|
||||
.collect_ip_addrs()
|
||||
.await
|
||||
.interface_ipv4s;
|
||||
if ips.is_empty() {
|
||||
return;
|
||||
}
|
||||
let bind_dev = get_interface_name_by_ip(&ips[0].parse().unwrap());
|
||||
|
||||
for ip in ips {
|
||||
println!("bind to ip: {:?}, {:?}", ip, bind_dev);
|
||||
let addr = check_scheme_and_get_socket_addr::<SocketAddr>(
|
||||
&format!("udp://{}:11111", ip).parse().unwrap(),
|
||||
"udp",
|
||||
)
|
||||
.unwrap();
|
||||
let socket2_socket = socket2::Socket::new(
|
||||
socket2::Domain::for_address(addr),
|
||||
socket2::Type::DGRAM,
|
||||
Some(socket2::Protocol::UDP),
|
||||
)
|
||||
.unwrap();
|
||||
setup_sokcet2_ext(&socket2_socket, &addr, bind_dev.clone()).unwrap();
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,797 @@
|
||||
use std::{
|
||||
collections::hash_map::DefaultHasher,
|
||||
fmt::{Debug, Formatter},
|
||||
hash::Hasher,
|
||||
net::SocketAddr,
|
||||
pin::Pin,
|
||||
sync::Arc,
|
||||
time::Duration,
|
||||
};
|
||||
|
||||
use anyhow::Context;
|
||||
use async_recursion::async_recursion;
|
||||
use async_trait::async_trait;
|
||||
use boringtun::{
|
||||
noise::{errors::WireGuardError, Tunn, TunnResult},
|
||||
x25519::{PublicKey, StaticSecret},
|
||||
};
|
||||
use dashmap::DashMap;
|
||||
use futures::{stream::FuturesUnordered, SinkExt, StreamExt};
|
||||
use rand::RngCore;
|
||||
use tokio::{net::UdpSocket, sync::Mutex, task::JoinSet};
|
||||
|
||||
use crate::{
|
||||
rpc::TunnelInfo,
|
||||
tunnels::{build_url_from_socket_addr, common::TunnelWithCustomInfo},
|
||||
};
|
||||
|
||||
use super::{
|
||||
check_scheme_and_get_socket_addr,
|
||||
common::{setup_sokcet2, setup_sokcet2_ext},
|
||||
ring_tunnel::create_ring_tunnel_pair,
|
||||
DatagramSink, DatagramStream, Tunnel, TunnelError, TunnelListener, TunnelUrl,
|
||||
};
|
||||
|
||||
const MAX_PACKET: usize = 65500;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
enum WgType {
|
||||
// used by easytier peer, need remove/add ip header for in/out wg msg
|
||||
InternalUse,
|
||||
// used by wireguard peer, keep original ip header
|
||||
ExternalUse,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct WgConfig {
|
||||
my_secret_key: StaticSecret,
|
||||
my_public_key: PublicKey,
|
||||
|
||||
peer_secret_key: StaticSecret,
|
||||
peer_public_key: PublicKey,
|
||||
|
||||
wg_type: WgType,
|
||||
}
|
||||
|
||||
impl WgConfig {
|
||||
pub fn new_from_network_identity(network_name: &str, network_secret: &str) -> Self {
|
||||
let mut my_sec = [0u8; 32];
|
||||
let mut hasher = DefaultHasher::new();
|
||||
hasher.write(network_name.as_bytes());
|
||||
hasher.write(network_secret.as_bytes());
|
||||
my_sec[0..8].copy_from_slice(&hasher.finish().to_be_bytes());
|
||||
hasher.write(&my_sec[0..8]);
|
||||
my_sec[8..16].copy_from_slice(&hasher.finish().to_be_bytes());
|
||||
hasher.write(&my_sec[0..16]);
|
||||
my_sec[16..24].copy_from_slice(&hasher.finish().to_be_bytes());
|
||||
hasher.write(&my_sec[0..24]);
|
||||
my_sec[24..32].copy_from_slice(&hasher.finish().to_be_bytes());
|
||||
|
||||
let my_secret_key = StaticSecret::from(my_sec);
|
||||
let my_public_key = PublicKey::from(&my_secret_key);
|
||||
let peer_secret_key = StaticSecret::from(my_sec);
|
||||
let peer_public_key = my_public_key.clone();
|
||||
|
||||
WgConfig {
|
||||
my_secret_key,
|
||||
my_public_key,
|
||||
peer_secret_key,
|
||||
peer_public_key,
|
||||
|
||||
wg_type: WgType::InternalUse,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn new_for_portal(server_key_seed: &str, client_key_seed: &str) -> Self {
|
||||
let server_cfg = Self::new_from_network_identity("server", server_key_seed);
|
||||
let client_cfg = Self::new_from_network_identity("client", client_key_seed);
|
||||
Self {
|
||||
my_secret_key: server_cfg.my_secret_key,
|
||||
my_public_key: server_cfg.my_public_key,
|
||||
peer_secret_key: client_cfg.my_secret_key,
|
||||
peer_public_key: client_cfg.my_public_key,
|
||||
|
||||
wg_type: WgType::ExternalUse,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn my_secret_key(&self) -> &[u8] {
|
||||
self.my_secret_key.as_bytes()
|
||||
}
|
||||
|
||||
pub fn peer_secret_key(&self) -> &[u8] {
|
||||
self.peer_secret_key.as_bytes()
|
||||
}
|
||||
|
||||
pub fn my_public_key(&self) -> &[u8] {
|
||||
self.my_public_key.as_bytes()
|
||||
}
|
||||
|
||||
pub fn peer_public_key(&self) -> &[u8] {
|
||||
self.peer_public_key.as_bytes()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct WgPeerData {
|
||||
udp: Arc<UdpSocket>, // only for send
|
||||
endpoint: SocketAddr,
|
||||
tunn: Arc<Mutex<Tunn>>,
|
||||
sink: Arc<Mutex<Pin<Box<dyn DatagramSink>>>>,
|
||||
stream: Arc<Mutex<Pin<Box<dyn DatagramStream>>>>,
|
||||
wg_type: WgType,
|
||||
}
|
||||
|
||||
impl Debug for WgPeerData {
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("WgPeerData")
|
||||
.field("endpoint", &self.endpoint)
|
||||
.field("local", &self.udp.local_addr())
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl WgPeerData {
|
||||
#[tracing::instrument]
|
||||
async fn handle_one_packet_from_me(&self, packet: &[u8]) -> Result<(), anyhow::Error> {
|
||||
let mut send_buf = vec![0u8; MAX_PACKET];
|
||||
|
||||
let encapsulate_result = {
|
||||
let mut peer = self.tunn.lock().await;
|
||||
if matches!(self.wg_type, WgType::InternalUse) {
|
||||
peer.encapsulate(&self.add_ip_header(&packet), &mut send_buf)
|
||||
} else {
|
||||
peer.encapsulate(&packet, &mut send_buf)
|
||||
}
|
||||
};
|
||||
|
||||
tracing::trace!(
|
||||
?encapsulate_result,
|
||||
"Received {} bytes from me",
|
||||
packet.len()
|
||||
);
|
||||
|
||||
match encapsulate_result {
|
||||
TunnResult::WriteToNetwork(packet) => {
|
||||
self.udp
|
||||
.send_to(packet, self.endpoint)
|
||||
.await
|
||||
.context("Failed to send encrypted IP packet to WireGuard endpoint.")?;
|
||||
tracing::debug!(
|
||||
"Sent {} bytes to WireGuard endpoint (encrypted IP packet)",
|
||||
packet.len()
|
||||
);
|
||||
}
|
||||
TunnResult::Err(e) => {
|
||||
tracing::error!("Failed to encapsulate IP packet: {:?}", e);
|
||||
}
|
||||
TunnResult::Done => {
|
||||
// Ignored
|
||||
}
|
||||
other => {
|
||||
tracing::error!(
|
||||
"Unexpected WireGuard state during encapsulation: {:?}",
|
||||
other
|
||||
);
|
||||
}
|
||||
};
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// WireGuard consumption task. Receives encrypted packets from the WireGuard endpoint,
|
||||
/// decapsulates them, and dispatches newly received IP packets.
|
||||
#[tracing::instrument]
|
||||
pub async fn handle_one_packet_from_peer(&self, recv_buf: &[u8]) {
|
||||
let mut send_buf = vec![0u8; MAX_PACKET];
|
||||
let data = &recv_buf[..];
|
||||
let decapsulate_result = {
|
||||
let mut peer = self.tunn.lock().await;
|
||||
peer.decapsulate(None, data, &mut send_buf)
|
||||
};
|
||||
|
||||
tracing::debug!("Decapsulation result: {:?}", decapsulate_result);
|
||||
|
||||
match decapsulate_result {
|
||||
TunnResult::WriteToNetwork(packet) => {
|
||||
match self.udp.send_to(packet, self.endpoint).await {
|
||||
Ok(_) => {}
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to send decapsulation-instructed packet to WireGuard endpoint: {:?}", e);
|
||||
return;
|
||||
}
|
||||
};
|
||||
let mut peer = self.tunn.lock().await;
|
||||
loop {
|
||||
let mut send_buf = vec![0u8; MAX_PACKET];
|
||||
match peer.decapsulate(None, &[], &mut send_buf) {
|
||||
TunnResult::WriteToNetwork(packet) => {
|
||||
match self.udp.send_to(packet, self.endpoint).await {
|
||||
Ok(_) => {}
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to send decapsulation-instructed packet to WireGuard endpoint: {:?}", e);
|
||||
break;
|
||||
}
|
||||
};
|
||||
}
|
||||
_ => {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
TunnResult::WriteToTunnelV4(packet, _) | TunnResult::WriteToTunnelV6(packet, _) => {
|
||||
tracing::debug!(
|
||||
"WireGuard endpoint sent IP packet of {} bytes",
|
||||
packet.len()
|
||||
);
|
||||
let ret = self
|
||||
.sink
|
||||
.lock()
|
||||
.await
|
||||
.send(
|
||||
if matches!(self.wg_type, WgType::InternalUse) {
|
||||
self.remove_ip_header(packet, packet[0] >> 4 == 4)
|
||||
} else {
|
||||
packet
|
||||
}
|
||||
.to_vec()
|
||||
.into(),
|
||||
)
|
||||
.await;
|
||||
if ret.is_err() {
|
||||
tracing::error!("Failed to send packet to tunnel: {:?}", ret);
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
tracing::warn!(
|
||||
"Unexpected WireGuard state during decapsulation: {:?}",
|
||||
decapsulate_result
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[tracing::instrument]
|
||||
#[async_recursion]
|
||||
async fn handle_routine_tun_result<'a: 'async_recursion>(&self, result: TunnResult<'a>) -> () {
|
||||
match result {
|
||||
TunnResult::WriteToNetwork(packet) => {
|
||||
tracing::debug!(
|
||||
"Sending routine packet of {} bytes to WireGuard endpoint",
|
||||
packet.len()
|
||||
);
|
||||
match self.udp.send_to(packet, self.endpoint).await {
|
||||
Ok(_) => {}
|
||||
Err(e) => {
|
||||
tracing::error!(
|
||||
"Failed to send routine packet to WireGuard endpoint: {:?}",
|
||||
e
|
||||
);
|
||||
}
|
||||
};
|
||||
}
|
||||
TunnResult::Err(WireGuardError::ConnectionExpired) => {
|
||||
tracing::warn!("Wireguard handshake has expired!");
|
||||
|
||||
let mut buf = vec![0u8; MAX_PACKET];
|
||||
let result = self
|
||||
.tunn
|
||||
.lock()
|
||||
.await
|
||||
.format_handshake_initiation(&mut buf[..], false);
|
||||
|
||||
self.handle_routine_tun_result(result).await
|
||||
}
|
||||
TunnResult::Err(e) => {
|
||||
tracing::error!(
|
||||
"Failed to prepare routine packet for WireGuard endpoint: {:?}",
|
||||
e
|
||||
);
|
||||
}
|
||||
TunnResult::Done => {
|
||||
// Sleep for a bit
|
||||
tokio::time::sleep(Duration::from_millis(250)).await;
|
||||
}
|
||||
other => {
|
||||
tracing::warn!("Unexpected WireGuard routine task state: {:?}", other);
|
||||
tokio::time::sleep(Duration::from_millis(250)).await;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
/// WireGuard Routine task. Handles Handshake, keep-alive, etc.
|
||||
pub async fn routine_task(self) {
|
||||
loop {
|
||||
let mut send_buf = vec![0u8; MAX_PACKET];
|
||||
let tun_result = { self.tunn.lock().await.update_timers(&mut send_buf) };
|
||||
self.handle_routine_tun_result(tun_result).await;
|
||||
}
|
||||
}
|
||||
|
||||
fn add_ip_header(&self, packet: &[u8]) -> Vec<u8> {
|
||||
let mut ret = vec![0u8; packet.len() + 20];
|
||||
let ip_header = ret.as_mut_slice();
|
||||
ip_header[0] = 0x45;
|
||||
ip_header[1] = 0;
|
||||
ip_header[2..4].copy_from_slice(&((packet.len() + 20) as u16).to_be_bytes());
|
||||
ip_header[4..6].copy_from_slice(&0u16.to_be_bytes());
|
||||
ip_header[6..8].copy_from_slice(&0u16.to_be_bytes());
|
||||
ip_header[8] = 64;
|
||||
ip_header[9] = 0;
|
||||
ip_header[10..12].copy_from_slice(&0u16.to_be_bytes());
|
||||
ip_header[12..16].copy_from_slice(&0u32.to_be_bytes());
|
||||
ip_header[16..20].copy_from_slice(&0u32.to_be_bytes());
|
||||
ip_header[20..].copy_from_slice(packet);
|
||||
ret
|
||||
}
|
||||
|
||||
fn remove_ip_header<'a>(&self, packet: &'a [u8], is_v4: bool) -> &'a [u8] {
|
||||
if is_v4 {
|
||||
return &packet[20..];
|
||||
} else {
|
||||
return &packet[40..];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct WgPeer {
|
||||
udp: Arc<UdpSocket>, // only for send
|
||||
config: WgConfig,
|
||||
endpoint: SocketAddr,
|
||||
|
||||
data: Option<WgPeerData>,
|
||||
tasks: JoinSet<()>,
|
||||
|
||||
access_time: std::time::Instant,
|
||||
}
|
||||
|
||||
impl WgPeer {
|
||||
fn new(udp: Arc<UdpSocket>, config: WgConfig, endpoint: SocketAddr) -> Self {
|
||||
WgPeer {
|
||||
udp,
|
||||
config,
|
||||
endpoint,
|
||||
|
||||
data: None,
|
||||
tasks: JoinSet::new(),
|
||||
|
||||
access_time: std::time::Instant::now(),
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_packet_from_me(data: WgPeerData) {
|
||||
while let Some(Ok(packet)) = data.stream.lock().await.next().await {
|
||||
let ret = data.handle_one_packet_from_me(&packet).await;
|
||||
if let Err(e) = ret {
|
||||
tracing::error!("Failed to handle packet from me: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_packet_from_peer(&mut self, packet: &[u8]) {
|
||||
self.access_time = std::time::Instant::now();
|
||||
tracing::trace!("Received {} bytes from peer", packet.len());
|
||||
let data = self.data.as_ref().unwrap();
|
||||
data.handle_one_packet_from_peer(packet).await;
|
||||
}
|
||||
|
||||
fn start_and_get_tunnel(&mut self) -> Box<dyn Tunnel> {
|
||||
let (stunnel, ctunnel) = create_ring_tunnel_pair();
|
||||
|
||||
let data = WgPeerData {
|
||||
udp: self.udp.clone(),
|
||||
endpoint: self.endpoint,
|
||||
tunn: Arc::new(Mutex::new(
|
||||
Tunn::new(
|
||||
self.config.my_secret_key.clone(),
|
||||
self.config.peer_public_key.clone(),
|
||||
None,
|
||||
None,
|
||||
rand::thread_rng().next_u32(),
|
||||
None,
|
||||
)
|
||||
.unwrap(),
|
||||
)),
|
||||
sink: Arc::new(Mutex::new(stunnel.pin_sink())),
|
||||
stream: Arc::new(Mutex::new(stunnel.pin_stream())),
|
||||
wg_type: self.config.wg_type.clone(),
|
||||
};
|
||||
|
||||
self.data = Some(data.clone());
|
||||
self.tasks.spawn(Self::handle_packet_from_me(data.clone()));
|
||||
self.tasks.spawn(data.routine_task());
|
||||
|
||||
ctunnel
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for WgPeer {
|
||||
fn drop(&mut self) {
|
||||
self.tasks.abort_all();
|
||||
if let Some(data) = self.data.clone() {
|
||||
tokio::spawn(async move {
|
||||
let _ = data.sink.lock().await.close().await;
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type ConnSender = tokio::sync::mpsc::UnboundedSender<Box<dyn Tunnel>>;
|
||||
type ConnReceiver = tokio::sync::mpsc::UnboundedReceiver<Box<dyn Tunnel>>;
|
||||
|
||||
pub struct WgTunnelListener {
|
||||
addr: url::Url,
|
||||
config: WgConfig,
|
||||
|
||||
udp: Option<Arc<UdpSocket>>,
|
||||
conn_recv: ConnReceiver,
|
||||
conn_send: Option<ConnSender>,
|
||||
|
||||
tasks: JoinSet<()>,
|
||||
}
|
||||
|
||||
impl WgTunnelListener {
|
||||
pub fn new(addr: url::Url, config: WgConfig) -> Self {
|
||||
let (conn_send, conn_recv) = tokio::sync::mpsc::unbounded_channel();
|
||||
WgTunnelListener {
|
||||
addr,
|
||||
config,
|
||||
|
||||
udp: None,
|
||||
conn_recv,
|
||||
conn_send: Some(conn_send),
|
||||
|
||||
tasks: JoinSet::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn get_udp_socket(&self) -> Arc<UdpSocket> {
|
||||
self.udp.as_ref().unwrap().clone()
|
||||
}
|
||||
|
||||
async fn handle_udp_incoming(
|
||||
socket: Arc<UdpSocket>,
|
||||
config: WgConfig,
|
||||
conn_sender: ConnSender,
|
||||
) {
|
||||
let mut tasks = JoinSet::new();
|
||||
let peer_map: Arc<DashMap<SocketAddr, WgPeer>> = Arc::new(DashMap::new());
|
||||
|
||||
let peer_map_clone = peer_map.clone();
|
||||
tasks.spawn(async move {
|
||||
loop {
|
||||
peer_map_clone.retain(|_, peer| peer.access_time.elapsed().as_secs() < 600);
|
||||
tokio::time::sleep(Duration::from_secs(60)).await;
|
||||
}
|
||||
});
|
||||
|
||||
let mut buf = vec![0u8; MAX_PACKET];
|
||||
loop {
|
||||
let Ok((n, addr)) = socket.recv_from(&mut buf).await else {
|
||||
tracing::error!("Failed to receive from UDP socket");
|
||||
break;
|
||||
};
|
||||
|
||||
let data = &buf[..n];
|
||||
tracing::trace!("Received {} bytes from {}", n, addr);
|
||||
|
||||
if !peer_map.contains_key(&addr) {
|
||||
tracing::info!("New peer: {}", addr);
|
||||
let mut wg = WgPeer::new(socket.clone(), config.clone(), addr.clone());
|
||||
let tunnel = Box::new(TunnelWithCustomInfo::new(
|
||||
wg.start_and_get_tunnel(),
|
||||
TunnelInfo {
|
||||
tunnel_type: "wg".to_owned(),
|
||||
local_addr: build_url_from_socket_addr(
|
||||
&socket.local_addr().unwrap().to_string(),
|
||||
"wg",
|
||||
)
|
||||
.into(),
|
||||
remote_addr: build_url_from_socket_addr(&addr.to_string(), "wg").into(),
|
||||
},
|
||||
));
|
||||
if let Err(e) = conn_sender.send(tunnel) {
|
||||
tracing::error!("Failed to send tunnel to conn_sender: {}", e);
|
||||
}
|
||||
peer_map.insert(addr, wg);
|
||||
}
|
||||
|
||||
let mut peer = peer_map.get_mut(&addr).unwrap();
|
||||
peer.handle_packet_from_peer(data).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl TunnelListener for WgTunnelListener {
|
||||
async fn listen(&mut self) -> Result<(), super::TunnelError> {
|
||||
let addr = check_scheme_and_get_socket_addr::<SocketAddr>(&self.addr, "wg")?;
|
||||
let socket2_socket = socket2::Socket::new(
|
||||
socket2::Domain::for_address(addr),
|
||||
socket2::Type::DGRAM,
|
||||
Some(socket2::Protocol::UDP),
|
||||
)?;
|
||||
|
||||
let tunnel_url: TunnelUrl = self.addr.clone().into();
|
||||
if let Some(bind_dev) = tunnel_url.bind_dev() {
|
||||
setup_sokcet2_ext(&socket2_socket, &addr, Some(bind_dev))?;
|
||||
} else {
|
||||
setup_sokcet2(&socket2_socket, &addr)?;
|
||||
}
|
||||
|
||||
self.udp = Some(Arc::new(UdpSocket::from_std(socket2_socket.into())?));
|
||||
self.tasks.spawn(Self::handle_udp_incoming(
|
||||
self.get_udp_socket(),
|
||||
self.config.clone(),
|
||||
self.conn_send.take().unwrap(),
|
||||
));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn accept(&mut self) -> Result<Box<dyn Tunnel>, super::TunnelError> {
|
||||
while let Some(tunnel) = self.conn_recv.recv().await {
|
||||
tracing::info!(?tunnel, "Accepted tunnel");
|
||||
return Ok(tunnel);
|
||||
}
|
||||
Err(TunnelError::CommonError(
|
||||
"Failed to accept tunnel".to_string(),
|
||||
))
|
||||
}
|
||||
|
||||
fn local_url(&self) -> url::Url {
|
||||
self.addr.clone()
|
||||
}
|
||||
}
|
||||
|
||||
pub struct WgClientTunnel {
|
||||
wg_peer: WgPeer,
|
||||
tunnel: Box<dyn Tunnel>,
|
||||
info: TunnelInfo,
|
||||
}
|
||||
|
||||
impl Tunnel for WgClientTunnel {
|
||||
fn stream(&self) -> Box<dyn DatagramStream> {
|
||||
self.tunnel.stream()
|
||||
}
|
||||
|
||||
fn sink(&self) -> Box<dyn DatagramSink> {
|
||||
self.tunnel.sink()
|
||||
}
|
||||
|
||||
fn info(&self) -> Option<TunnelInfo> {
|
||||
Some(self.info.clone())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct WgTunnelConnector {
|
||||
addr: url::Url,
|
||||
config: WgConfig,
|
||||
udp: Option<Arc<UdpSocket>>,
|
||||
|
||||
bind_addrs: Vec<SocketAddr>,
|
||||
}
|
||||
|
||||
impl Debug for WgTunnelConnector {
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("WgTunnelConnector")
|
||||
.field("addr", &self.addr)
|
||||
.field("udp", &self.udp)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl WgTunnelConnector {
|
||||
pub fn new(addr: url::Url, config: WgConfig) -> Self {
|
||||
WgTunnelConnector {
|
||||
addr,
|
||||
config,
|
||||
udp: None,
|
||||
bind_addrs: vec![],
|
||||
}
|
||||
}
|
||||
|
||||
fn create_handshake_init(tun: &mut Tunn) -> Vec<u8> {
|
||||
let mut dst = vec![0u8; 2048];
|
||||
let handshake_init = tun.format_handshake_initiation(&mut dst, false);
|
||||
assert!(matches!(handshake_init, TunnResult::WriteToNetwork(_)));
|
||||
let handshake_init = if let TunnResult::WriteToNetwork(sent) = handshake_init {
|
||||
sent
|
||||
} else {
|
||||
unreachable!();
|
||||
};
|
||||
|
||||
handshake_init.into()
|
||||
}
|
||||
|
||||
fn parse_handshake_resp(tun: &mut Tunn, handshake_resp: &[u8]) -> Vec<u8> {
|
||||
let mut dst = vec![0u8; 2048];
|
||||
let keepalive = tun.decapsulate(None, handshake_resp, &mut dst);
|
||||
assert!(
|
||||
matches!(keepalive, TunnResult::WriteToNetwork(_)),
|
||||
"Failed to parse handshake response, {:?}",
|
||||
keepalive
|
||||
);
|
||||
|
||||
let keepalive = if let TunnResult::WriteToNetwork(sent) = keepalive {
|
||||
sent
|
||||
} else {
|
||||
unreachable!();
|
||||
};
|
||||
|
||||
keepalive.into()
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(config))]
|
||||
async fn connect_with_socket(
|
||||
addr_url: url::Url,
|
||||
config: WgConfig,
|
||||
udp: UdpSocket,
|
||||
) -> Result<Box<dyn super::Tunnel>, super::TunnelError> {
|
||||
let addr = super::check_scheme_and_get_socket_addr::<SocketAddr>(&addr_url, "wg")?;
|
||||
tracing::warn!("wg connect: {:?}", addr);
|
||||
let local_addr = udp.local_addr().unwrap().to_string();
|
||||
|
||||
let mut my_tun = Tunn::new(
|
||||
config.my_secret_key.clone(),
|
||||
config.peer_public_key.clone(),
|
||||
None,
|
||||
None,
|
||||
rand::thread_rng().next_u32(),
|
||||
None,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let init = Self::create_handshake_init(&mut my_tun);
|
||||
udp.send_to(&init, addr).await?;
|
||||
|
||||
let mut buf = vec![0u8; MAX_PACKET];
|
||||
let (n, _) = udp.recv_from(&mut buf).await.unwrap();
|
||||
let keepalive = Self::parse_handshake_resp(&mut my_tun, &buf[..n]);
|
||||
udp.send_to(&keepalive, addr).await?;
|
||||
|
||||
let mut wg_peer = WgPeer::new(Arc::new(udp), config.clone(), addr);
|
||||
let tunnel = wg_peer.start_and_get_tunnel();
|
||||
|
||||
let data = wg_peer.data.as_ref().unwrap().clone();
|
||||
wg_peer.tasks.spawn(async move {
|
||||
loop {
|
||||
let mut buf = vec![0u8; MAX_PACKET];
|
||||
let (n, recv_addr) = data.udp.recv_from(&mut buf).await.unwrap();
|
||||
if recv_addr != addr {
|
||||
continue;
|
||||
}
|
||||
data.handle_one_packet_from_peer(&buf[..n]).await;
|
||||
}
|
||||
});
|
||||
|
||||
let ret = Box::new(WgClientTunnel {
|
||||
wg_peer,
|
||||
tunnel,
|
||||
info: TunnelInfo {
|
||||
tunnel_type: "wg".to_owned(),
|
||||
local_addr: super::build_url_from_socket_addr(&local_addr, "wg").into(),
|
||||
remote_addr: addr_url.to_string(),
|
||||
},
|
||||
});
|
||||
|
||||
Ok(ret)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl super::TunnelConnector for WgTunnelConnector {
|
||||
#[tracing::instrument]
|
||||
async fn connect(&mut self) -> Result<Box<dyn super::Tunnel>, super::TunnelError> {
|
||||
let bind_addrs = if self.bind_addrs.is_empty() {
|
||||
vec!["0.0.0.0:0".parse().unwrap()]
|
||||
} else {
|
||||
self.bind_addrs.clone()
|
||||
};
|
||||
let mut futures = FuturesUnordered::new();
|
||||
|
||||
for bind_addr in bind_addrs.into_iter() {
|
||||
let socket2_socket = socket2::Socket::new(
|
||||
socket2::Domain::for_address(bind_addr),
|
||||
socket2::Type::DGRAM,
|
||||
Some(socket2::Protocol::UDP),
|
||||
)?;
|
||||
setup_sokcet2(&socket2_socket, &bind_addr)?;
|
||||
let socket = UdpSocket::from_std(socket2_socket.into())?;
|
||||
tracing::info!(?bind_addr, ?self.addr, "prepare wg connect task");
|
||||
futures.push(Self::connect_with_socket(
|
||||
self.addr.clone(),
|
||||
self.config.clone(),
|
||||
socket,
|
||||
));
|
||||
}
|
||||
|
||||
let Some(ret) = futures.next().await else {
|
||||
return Err(super::TunnelError::CommonError(
|
||||
"join connect futures failed".to_owned(),
|
||||
));
|
||||
};
|
||||
|
||||
return ret;
|
||||
}
|
||||
|
||||
fn remote_url(&self) -> url::Url {
|
||||
self.addr.clone()
|
||||
}
|
||||
|
||||
fn set_bind_addrs(&mut self, addrs: Vec<SocketAddr>) {
|
||||
self.bind_addrs = addrs;
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub mod tests {
|
||||
use boringtun::*;
|
||||
|
||||
use crate::tunnels::common::tests::{_tunnel_bench, _tunnel_pingpong};
|
||||
use crate::tunnels::{wireguard::*, TunnelConnector};
|
||||
|
||||
pub fn create_wg_config() -> (WgConfig, WgConfig) {
|
||||
let my_secret_key = x25519::StaticSecret::random_from_rng(rand::thread_rng());
|
||||
let my_public_key = x25519::PublicKey::from(&my_secret_key);
|
||||
|
||||
let their_secret_key = x25519::StaticSecret::random_from_rng(rand::thread_rng());
|
||||
let their_public_key = x25519::PublicKey::from(&their_secret_key);
|
||||
|
||||
let server_cfg = WgConfig {
|
||||
my_secret_key: my_secret_key.clone(),
|
||||
my_public_key,
|
||||
peer_secret_key: their_secret_key.clone(),
|
||||
peer_public_key: their_public_key.clone(),
|
||||
wg_type: WgType::InternalUse,
|
||||
};
|
||||
|
||||
let client_cfg = WgConfig {
|
||||
my_secret_key: their_secret_key,
|
||||
my_public_key: their_public_key,
|
||||
peer_secret_key: my_secret_key,
|
||||
peer_public_key: my_public_key,
|
||||
wg_type: WgType::InternalUse,
|
||||
};
|
||||
|
||||
(server_cfg, client_cfg)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn wg_pingpong() {
|
||||
let (server_cfg, client_cfg) = create_wg_config();
|
||||
let listener = WgTunnelListener::new("wg://0.0.0.0:5599".parse().unwrap(), server_cfg);
|
||||
let connector = WgTunnelConnector::new("wg://127.0.0.1:5599".parse().unwrap(), client_cfg);
|
||||
_tunnel_pingpong(listener, connector).await
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn wg_bench() {
|
||||
let (server_cfg, client_cfg) = create_wg_config();
|
||||
let listener = WgTunnelListener::new("wg://0.0.0.0:5598".parse().unwrap(), server_cfg);
|
||||
let connector = WgTunnelConnector::new("wg://127.0.0.1:5598".parse().unwrap(), client_cfg);
|
||||
_tunnel_bench(listener, connector).await
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn wg_bench_with_bind() {
|
||||
let (server_cfg, client_cfg) = create_wg_config();
|
||||
let listener = WgTunnelListener::new("wg://127.0.0.1:5597".parse().unwrap(), server_cfg);
|
||||
let mut connector =
|
||||
WgTunnelConnector::new("wg://127.0.0.1:5597".parse().unwrap(), client_cfg);
|
||||
connector.set_bind_addrs(vec!["127.0.0.1:0".parse().unwrap()]);
|
||||
_tunnel_pingpong(listener, connector).await
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[should_panic]
|
||||
async fn wg_bench_with_bind_fail() {
|
||||
let (server_cfg, client_cfg) = create_wg_config();
|
||||
let listener = WgTunnelListener::new("wg://127.0.0.1:5596".parse().unwrap(), server_cfg);
|
||||
let mut connector =
|
||||
WgTunnelConnector::new("wg://127.0.0.1:5596".parse().unwrap(), client_cfg);
|
||||
connector.set_bind_addrs(vec!["10.0.0.1:0".parse().unwrap()]);
|
||||
_tunnel_pingpong(listener, connector).await
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,24 @@
|
||||
// with vpn portal, user can use other vpn client to connect to easytier servers
|
||||
// without installing easytier.
|
||||
// these vpn client include:
|
||||
// 1. wireguard
|
||||
// 2. openvpn (TODO)
|
||||
// 3. shadowsocks (TODO)
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::{common::global_ctx::ArcGlobalCtx, peers::peer_manager::PeerManager};
|
||||
|
||||
pub mod wireguard;
|
||||
|
||||
#[async_trait::async_trait]
|
||||
pub trait VpnPortal: Send + Sync {
|
||||
async fn start(
|
||||
&mut self,
|
||||
global_ctx: ArcGlobalCtx,
|
||||
peer_mgr: Arc<PeerManager>,
|
||||
) -> anyhow::Result<()>;
|
||||
async fn dump_client_config(&self, peer_mgr: Arc<PeerManager>) -> String;
|
||||
fn name(&self) -> String;
|
||||
async fn list_clients(&self) -> Vec<String>;
|
||||
}
|
||||
@@ -0,0 +1,357 @@
|
||||
use std::{
|
||||
net::{Ipv4Addr, SocketAddr},
|
||||
pin::Pin,
|
||||
sync::Arc,
|
||||
};
|
||||
|
||||
use anyhow::Context;
|
||||
use base64::{prelude::BASE64_STANDARD, Engine};
|
||||
use cidr::Ipv4Inet;
|
||||
use dashmap::DashMap;
|
||||
use futures::{SinkExt, StreamExt};
|
||||
use pnet::packet::ipv4::Ipv4Packet;
|
||||
use tokio::{sync::Mutex, task::JoinSet};
|
||||
use tokio_util::bytes::Bytes;
|
||||
|
||||
use crate::{
|
||||
common::{
|
||||
global_ctx::{ArcGlobalCtx, GlobalCtxEvent},
|
||||
join_joinset_background,
|
||||
},
|
||||
peers::{
|
||||
packet::{self, ArchivedPacket},
|
||||
peer_manager::PeerManager,
|
||||
PeerPacketFilter,
|
||||
},
|
||||
tunnels::{
|
||||
wireguard::{WgConfig, WgTunnelListener},
|
||||
DatagramSink, Tunnel, TunnelListener,
|
||||
},
|
||||
};
|
||||
|
||||
use super::VpnPortal;
|
||||
|
||||
type WgPeerIpTable = Arc<DashMap<Ipv4Addr, Arc<ClientEntry>>>;
|
||||
|
||||
struct ClientEntry {
|
||||
endpoint_addr: Option<url::Url>,
|
||||
sink: Mutex<Pin<Box<dyn DatagramSink + 'static>>>,
|
||||
}
|
||||
|
||||
struct WireGuardImpl {
|
||||
global_ctx: ArcGlobalCtx,
|
||||
peer_mgr: Arc<PeerManager>,
|
||||
wg_config: WgConfig,
|
||||
listenr_addr: SocketAddr,
|
||||
|
||||
wg_peer_ip_table: WgPeerIpTable,
|
||||
|
||||
tasks: Arc<std::sync::Mutex<JoinSet<()>>>,
|
||||
}
|
||||
|
||||
impl WireGuardImpl {
|
||||
fn new(global_ctx: ArcGlobalCtx, peer_mgr: Arc<PeerManager>) -> Self {
|
||||
let nid = global_ctx.get_network_identity();
|
||||
let key_seed = format!("{}{}", nid.network_name, nid.network_secret);
|
||||
let wg_config = WgConfig::new_for_portal(&key_seed, &key_seed);
|
||||
|
||||
let vpn_cfg = global_ctx.config.get_vpn_portal_config().unwrap();
|
||||
let listenr_addr = vpn_cfg.wireguard_listen;
|
||||
|
||||
Self {
|
||||
global_ctx,
|
||||
peer_mgr,
|
||||
wg_config,
|
||||
listenr_addr,
|
||||
wg_peer_ip_table: Arc::new(DashMap::new()),
|
||||
tasks: Arc::new(std::sync::Mutex::new(JoinSet::new())),
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_incoming_conn(
|
||||
t: Box<dyn Tunnel>,
|
||||
peer_mgr: Arc<PeerManager>,
|
||||
wg_peer_ip_table: WgPeerIpTable,
|
||||
) {
|
||||
let mut s = t.pin_stream();
|
||||
let mut ip_registered = false;
|
||||
|
||||
let info = t.info().unwrap_or_default();
|
||||
let remote_addr = info.remote_addr.clone();
|
||||
peer_mgr
|
||||
.get_global_ctx()
|
||||
.issue_event(GlobalCtxEvent::VpnPortalClientConnected(
|
||||
info.local_addr,
|
||||
info.remote_addr,
|
||||
));
|
||||
|
||||
while let Some(Ok(msg)) = s.next().await {
|
||||
let Some(i) = Ipv4Packet::new(&msg) else {
|
||||
tracing::error!(?msg, "Failed to parse ipv4 packet");
|
||||
continue;
|
||||
};
|
||||
if !ip_registered {
|
||||
let client_entry = Arc::new(ClientEntry {
|
||||
endpoint_addr: remote_addr.parse().ok(),
|
||||
sink: Mutex::new(t.pin_sink()),
|
||||
});
|
||||
wg_peer_ip_table.insert(i.get_source(), client_entry.clone());
|
||||
ip_registered = true;
|
||||
}
|
||||
tracing::trace!(?i, "Received from wg client");
|
||||
let _ = peer_mgr
|
||||
.send_msg_ipv4(msg.clone(), i.get_destination())
|
||||
.await;
|
||||
}
|
||||
|
||||
let info = t.info().unwrap_or_default();
|
||||
peer_mgr
|
||||
.get_global_ctx()
|
||||
.issue_event(GlobalCtxEvent::VpnPortalClientDisconnected(
|
||||
info.local_addr,
|
||||
info.remote_addr,
|
||||
));
|
||||
}
|
||||
|
||||
async fn start_pipeline_processor(&self) {
|
||||
struct PeerPacketFilterForVpnPortal {
|
||||
wg_peer_ip_table: WgPeerIpTable,
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl PeerPacketFilter for PeerPacketFilterForVpnPortal {
|
||||
async fn try_process_packet_from_peer(
|
||||
&self,
|
||||
packet: &ArchivedPacket,
|
||||
_: &Bytes,
|
||||
) -> Option<()> {
|
||||
if packet.packet_type != packet::PacketType::Data {
|
||||
return None;
|
||||
};
|
||||
|
||||
let payload_bytes = packet.payload.as_bytes();
|
||||
|
||||
let ipv4 = Ipv4Packet::new(payload_bytes)?;
|
||||
if ipv4.get_version() != 4 {
|
||||
return None;
|
||||
}
|
||||
|
||||
let entry = self.wg_peer_ip_table.get(&ipv4.get_destination())?.clone();
|
||||
|
||||
tracing::trace!(?ipv4, "Packet filter for vpn portal");
|
||||
|
||||
let ret = entry
|
||||
.sink
|
||||
.lock()
|
||||
.await
|
||||
.send(Bytes::copy_from_slice(payload_bytes))
|
||||
.await;
|
||||
|
||||
ret.ok()
|
||||
}
|
||||
}
|
||||
|
||||
self.peer_mgr
|
||||
.add_packet_process_pipeline(Box::new(PeerPacketFilterForVpnPortal {
|
||||
wg_peer_ip_table: self.wg_peer_ip_table.clone(),
|
||||
}))
|
||||
.await;
|
||||
}
|
||||
|
||||
async fn start(&self) -> anyhow::Result<()> {
|
||||
let mut l = WgTunnelListener::new(
|
||||
format!("wg://{}", self.listenr_addr).parse().unwrap(),
|
||||
self.wg_config.clone(),
|
||||
);
|
||||
|
||||
l.listen()
|
||||
.await
|
||||
.with_context(|| "Failed to start wireguard listener for vpn portal")?;
|
||||
|
||||
join_joinset_background(self.tasks.clone(), "wireguard".to_string());
|
||||
|
||||
let tasks = Arc::downgrade(&self.tasks.clone());
|
||||
let peer_mgr = self.peer_mgr.clone();
|
||||
let wg_peer_ip_table = self.wg_peer_ip_table.clone();
|
||||
self.tasks.lock().unwrap().spawn(async move {
|
||||
while let Ok(t) = l.accept().await {
|
||||
let Some(tasks) = tasks.upgrade() else {
|
||||
break;
|
||||
};
|
||||
tasks.lock().unwrap().spawn(Self::handle_incoming_conn(
|
||||
t,
|
||||
peer_mgr.clone(),
|
||||
wg_peer_ip_table.clone(),
|
||||
));
|
||||
}
|
||||
});
|
||||
|
||||
self.start_pipeline_processor().await;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct WireGuard {
|
||||
inner: Option<WireGuardImpl>,
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl VpnPortal for WireGuard {
|
||||
async fn start(
|
||||
&mut self,
|
||||
global_ctx: ArcGlobalCtx,
|
||||
peer_mgr: Arc<PeerManager>,
|
||||
) -> anyhow::Result<()> {
|
||||
assert!(self.inner.is_none());
|
||||
|
||||
let vpn_cfg = global_ctx.config.get_vpn_portal_config();
|
||||
if vpn_cfg.is_none() {
|
||||
anyhow::bail!("vpn cfg is not set for wireguard vpn portal");
|
||||
}
|
||||
|
||||
let inner = WireGuardImpl::new(global_ctx, peer_mgr);
|
||||
inner.start().await?;
|
||||
self.inner = Some(inner);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn dump_client_config(&self, peer_mgr: Arc<PeerManager>) -> String {
|
||||
if self.inner.is_none() {
|
||||
return "ERROR: Wireguard VPN Portal Not Started".to_string();
|
||||
}
|
||||
let global_ctx = self.inner.as_ref().unwrap().global_ctx.clone();
|
||||
if global_ctx.config.get_vpn_portal_config().is_none() {
|
||||
return "ERROR: VPN Portal Config Not Set".to_string();
|
||||
}
|
||||
|
||||
let routes = peer_mgr.list_routes().await;
|
||||
let mut allow_ips = routes
|
||||
.iter()
|
||||
.map(|x| x.proxy_cidrs.iter().map(String::to_string))
|
||||
.flatten()
|
||||
.collect::<Vec<_>>();
|
||||
for ipv4 in routes
|
||||
.iter()
|
||||
.map(|x| x.ipv4_addr.clone())
|
||||
.chain(global_ctx.get_ipv4().iter().map(|x| x.to_string()))
|
||||
{
|
||||
let Ok(ipv4) = ipv4.parse() else {
|
||||
continue;
|
||||
};
|
||||
let inet = Ipv4Inet::new(ipv4, 24).unwrap();
|
||||
allow_ips.push(inet.network().to_string());
|
||||
break;
|
||||
}
|
||||
|
||||
let allow_ips = allow_ips
|
||||
.into_iter()
|
||||
.map(|x| x.to_string())
|
||||
.collect::<Vec<_>>()
|
||||
.join(",");
|
||||
|
||||
let vpn_cfg = global_ctx.config.get_vpn_portal_config().unwrap();
|
||||
let client_cidr = vpn_cfg.client_cidr;
|
||||
|
||||
let cfg = self.inner.as_ref().unwrap().wg_config.clone();
|
||||
let cfg_str = format!(
|
||||
r#"
|
||||
[Interface]
|
||||
PrivateKey = {peer_secret_key}
|
||||
Address = {client_cidr} # should assign an ip from this cidr manually
|
||||
|
||||
[Peer]
|
||||
PublicKey = {my_public_key}
|
||||
AllowedIPs = {allow_ips}
|
||||
Endpoint = {listenr_addr} # should be the public ip of the vpn server
|
||||
"#,
|
||||
peer_secret_key = BASE64_STANDARD.encode(cfg.peer_secret_key()),
|
||||
my_public_key = BASE64_STANDARD.encode(cfg.my_public_key()),
|
||||
listenr_addr = self.inner.as_ref().unwrap().listenr_addr,
|
||||
allow_ips = allow_ips,
|
||||
client_cidr = client_cidr,
|
||||
);
|
||||
|
||||
cfg_str
|
||||
}
|
||||
|
||||
fn name(&self) -> String {
|
||||
"wireguard".to_string()
|
||||
}
|
||||
|
||||
async fn list_clients(&self) -> Vec<String> {
|
||||
self.inner
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.wg_peer_ip_table
|
||||
.iter()
|
||||
.map(|x| {
|
||||
x.value()
|
||||
.endpoint_addr
|
||||
.as_ref()
|
||||
.map(|x| x.to_string())
|
||||
.unwrap_or_default()
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::sync::Arc;
|
||||
|
||||
use super::*;
|
||||
|
||||
use crate::{
|
||||
common::{
|
||||
config::{NetworkIdentity, VpnPortalConfig},
|
||||
global_ctx::tests::get_mock_global_ctx_with_network,
|
||||
},
|
||||
connector::udp_hole_punch::tests::replace_stun_info_collector,
|
||||
peers::{
|
||||
peer_manager::{PeerManager, RouteAlgoType},
|
||||
tests::wait_for_condition,
|
||||
},
|
||||
rpc::NatType,
|
||||
tunnels::{tcp_tunnel::TcpTunnelConnector, TunnelConnector},
|
||||
};
|
||||
|
||||
async fn portal_test() {
|
||||
let (s, _r) = tokio::sync::mpsc::channel(1000);
|
||||
let peer_mgr = Arc::new(PeerManager::new(
|
||||
RouteAlgoType::Ospf,
|
||||
get_mock_global_ctx_with_network(Some(NetworkIdentity {
|
||||
network_name: "sijie".to_string(),
|
||||
network_secret: "1919119".to_string(),
|
||||
})),
|
||||
s,
|
||||
));
|
||||
replace_stun_info_collector(peer_mgr.clone(), NatType::Unknown);
|
||||
peer_mgr
|
||||
.get_global_ctx()
|
||||
.config
|
||||
.set_vpn_portal_config(VpnPortalConfig {
|
||||
wireguard_listen: "0.0.0.0:11021".parse().unwrap(),
|
||||
client_cidr: "10.14.14.0/24".parse().unwrap(),
|
||||
});
|
||||
peer_mgr.run().await.unwrap();
|
||||
let mut pmgr_conn = TcpTunnelConnector::new("tcp://127.0.0.1:11010".parse().unwrap());
|
||||
let tunnel = pmgr_conn.connect().await;
|
||||
peer_mgr.add_client_tunnel(tunnel.unwrap()).await.unwrap();
|
||||
wait_for_condition(
|
||||
|| async {
|
||||
let routes = peer_mgr.list_routes().await;
|
||||
println!("Routes: {:?}", routes);
|
||||
routes.len() != 0
|
||||
},
|
||||
std::time::Duration::from_secs(10),
|
||||
)
|
||||
.await;
|
||||
|
||||
let mut wg = WireGuard::default();
|
||||
wg.start(peer_mgr.get_global_ctx(), peer_mgr.clone())
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user