Initial Version

This commit is contained in:
sijie.sun
2023-09-23 01:53:45 +00:00
commit 9779923b87
63 changed files with 10840 additions and 0 deletions
+161
View File
@@ -0,0 +1,161 @@
// use filesystem as a config store
use std::{
ffi::OsStr,
io::Write,
path::{Path, PathBuf},
};
static DEFAULT_BASE_DIR: &str = "/var/lib/easytier";
static DIR_ROOT_CONFIG_FILE_NAME: &str = "__root__";
pub struct ConfigFs {
_db_name: String,
db_path: PathBuf,
}
impl ConfigFs {
pub fn new(db_name: &str) -> Self {
Self::new_with_dir(db_name, DEFAULT_BASE_DIR)
}
pub fn new_with_dir(db_name: &str, dir: &str) -> Self {
let p = Path::new(OsStr::new(dir)).join(OsStr::new(db_name));
std::fs::create_dir_all(&p).unwrap();
ConfigFs {
_db_name: db_name.to_string(),
db_path: p,
}
}
pub fn get(&self, key: &str) -> Result<String, std::io::Error> {
let path = self.db_path.join(OsStr::new(key));
// if path is dir, read the DIR_ROOT_CONFIG_FILE_NAME in it
if path.is_dir() {
let path = path.join(OsStr::new(DIR_ROOT_CONFIG_FILE_NAME));
std::fs::read_to_string(path)
} else if path.is_file() {
return std::fs::read_to_string(path);
} else {
return Err(std::io::Error::new(
std::io::ErrorKind::NotFound,
"key not found",
));
}
}
pub fn list_keys(&self, key: &str) -> Result<Vec<String>, std::io::Error> {
let path = self.db_path.join(OsStr::new(key));
let mut keys = Vec::new();
for entry in std::fs::read_dir(path)? {
let entry = entry?;
let path = entry.path();
let key = path.file_name().unwrap().to_str().unwrap().to_string();
if key != DIR_ROOT_CONFIG_FILE_NAME {
keys.push(key);
}
}
Ok(keys)
}
#[allow(dead_code)]
pub fn remove(&self, key: &str) -> Result<(), std::io::Error> {
let path = self.db_path.join(OsStr::new(key));
// if path is dir, remove the DIR_ROOT_CONFIG_FILE_NAME in it
if path.is_dir() {
std::fs::remove_dir_all(path)
} else if path.is_file() {
return std::fs::remove_file(path);
} else {
return Err(std::io::Error::new(
std::io::ErrorKind::NotFound,
"key not found",
));
}
}
pub fn add_dir(&self, key: &str) -> Result<std::fs::File, std::io::Error> {
let path = self.db_path.join(OsStr::new(key));
// if path is dir, write the DIR_ROOT_CONFIG_FILE_NAME in it
if path.is_file() {
Err(std::io::Error::new(
std::io::ErrorKind::AlreadyExists,
"key already exists",
))
} else {
std::fs::create_dir_all(&path)?;
return std::fs::File::create(path.join(OsStr::new(DIR_ROOT_CONFIG_FILE_NAME)));
}
}
pub fn add_file(&self, key: &str) -> Result<std::fs::File, std::io::Error> {
let path = self.db_path.join(OsStr::new(key));
let base_dir = path.parent().unwrap();
if !path.is_file() {
std::fs::create_dir_all(base_dir)?;
}
std::fs::File::create(path)
}
pub fn get_or_add<F>(
&self,
key: &str,
val_fn: F,
add_dir: bool,
) -> Result<String, std::io::Error>
where
F: FnOnce() -> String,
{
let get_ret = self.get(key);
match get_ret {
Ok(v) => Ok(v),
Err(e) => {
if e.kind() == std::io::ErrorKind::NotFound {
let val = val_fn();
if add_dir {
let mut f = self.add_dir(key)?;
f.write_all(val.as_bytes())?;
} else {
let mut f = self.add_file(key)?;
f.write_all(val.as_bytes())?;
}
Ok(val)
} else {
Err(e)
}
}
}
}
#[allow(dead_code)]
pub fn get_or_add_dir<F>(&self, key: &str, val_fn: F) -> Result<String, std::io::Error>
where
F: FnOnce() -> String,
{
self.get_or_add(key, val_fn, true)
}
pub fn get_or_add_file<F>(&self, key: &str, val_fn: F) -> Result<String, std::io::Error>
where
F: FnOnce() -> String,
{
self.get_or_add(key, val_fn, false)
}
pub fn get_or_default<F>(&self, key: &str, default: F) -> Result<String, std::io::Error>
where
F: FnOnce() -> String,
{
let get_ret = self.get(key);
match get_ret {
Ok(v) => Ok(v),
Err(e) => {
if e.kind() == std::io::ErrorKind::NotFound {
Ok(default())
} else {
Err(e)
}
}
}
}
}
+28
View File
@@ -0,0 +1,28 @@
pub const DIRECT_CONNECTOR_SERVICE_ID: u32 = 1;
pub const DIRECT_CONNECTOR_BLACKLIST_TIMEOUT_SEC: u64 = 60;
pub const DIRECT_CONNECTOR_IP_LIST_TIMEOUT_SEC: u64 = 60;
macro_rules! define_global_var {
($name:ident, $type:ty, $init:expr) => {
pub static $name: once_cell::sync::Lazy<tokio::sync::Mutex<$type>> =
once_cell::sync::Lazy::new(|| tokio::sync::Mutex::new($init));
};
}
#[macro_export]
macro_rules! use_global_var {
($name:ident) => {
crate::common::constants::$name.lock().await.to_owned()
};
}
#[macro_export]
macro_rules! set_global_var {
($name:ident, $val:expr) => {
*crate::common::constants::$name.lock().await = $val
};
}
define_global_var!(MANUAL_CONNECTOR_RECONNECT_INTERVAL_MS, u64, 1000);
pub const UDP_HOLE_PUNCH_CONNECTOR_SERVICE_ID: u32 = 2;
+43
View File
@@ -0,0 +1,43 @@
use std::{io, result};
use thiserror::Error;
use crate::tunnels;
#[derive(Error, Debug)]
pub enum Error {
#[error("io error")]
IOError(#[from] io::Error),
#[error("rust tun error {0}")]
TunError(#[from] tun::Error),
#[error("tunnel error {0}")]
TunnelError(#[from] tunnels::TunnelError),
#[error("Peer has no conn, PeerId: {0}")]
PeerNoConnectionError(uuid::Uuid),
#[error("RouteError: {0}")]
RouteError(String),
#[error("Not found")]
NotFound,
#[error("Invalid Url: {0}")]
InvalidUrl(String),
#[error("Shell Command error: {0}")]
ShellCommandError(String),
// #[error("Rpc listen error: {0}")]
// RpcListenError(String),
#[error("Rpc connect error: {0}")]
RpcConnectError(String),
#[error("Rpc error: {0}")]
RpcClientError(#[from] tarpc::client::RpcError),
#[error("Timeout error: {0}")]
Timeout(#[from] tokio::time::error::Elapsed),
#[error("url in blacklist")]
UrlInBlacklist,
#[error("unknown data store error")]
Unknown,
#[error("anyhow error: {0}")]
AnyhowError(#[from] anyhow::Error),
}
pub type Result<T> = result::Result<T, Error>;
// impl From for std::
+259
View File
@@ -0,0 +1,259 @@
use std::{io::Write, sync::Arc};
use crossbeam::atomic::AtomicCell;
use easytier_rpc::PeerConnInfo;
use super::{
config_fs::ConfigFs,
netns::NetNS,
network::IPCollector,
stun::{StunInfoCollector, StunInfoCollectorTrait},
};
#[derive(Debug, Clone, PartialEq)]
pub enum GlobalCtxEvent {
PeerAdded,
PeerRemoved,
PeerConnAdded(PeerConnInfo),
PeerConnRemoved(PeerConnInfo),
}
type EventBus = tokio::sync::broadcast::Sender<GlobalCtxEvent>;
type EventBusSubscriber = tokio::sync::broadcast::Receiver<GlobalCtxEvent>;
pub struct GlobalCtx {
pub inst_name: String,
pub id: uuid::Uuid,
pub config_fs: ConfigFs,
pub net_ns: NetNS,
event_bus: EventBus,
cached_ipv4: AtomicCell<Option<std::net::Ipv4Addr>>,
cached_proxy_cidrs: AtomicCell<Option<Vec<cidr::IpCidr>>>,
ip_collector: Arc<IPCollector>,
hotname: AtomicCell<Option<String>>,
stun_info_collection: Box<dyn StunInfoCollectorTrait>,
}
impl std::fmt::Debug for GlobalCtx {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("GlobalCtx")
.field("inst_name", &self.inst_name)
.field("id", &self.id)
.field("net_ns", &self.net_ns.name())
.field("event_bus", &"EventBus")
.field("ipv4", &self.cached_ipv4)
.finish()
}
}
pub type ArcGlobalCtx = std::sync::Arc<GlobalCtx>;
impl GlobalCtx {
pub fn new(inst_name: &str, config_fs: ConfigFs, net_ns: NetNS) -> Self {
let id = config_fs
.get_or_add_file("inst_id", || uuid::Uuid::new_v4().to_string())
.unwrap();
let id = uuid::Uuid::parse_str(&id).unwrap();
let (event_bus, _) = tokio::sync::broadcast::channel(100);
// NOTICE: we may need to choose stun stun server based on geo location
// stun server cross nation may return a external ip address with high latency and loss rate
let default_stun_servers = vec![
"stun.miwifi.com:3478".to_string(),
"stun.qq.com:3478".to_string(),
"stun.chat.bilibili.com:3478".to_string(),
"fwa.lifesizecloud.com:3478".to_string(),
"stun.isp.net.au:3478".to_string(),
"stun.nextcloud.com:3478".to_string(),
"stun.freeswitch.org:3478".to_string(),
"stun.voip.blackberry.com:3478".to_string(),
"stunserver.stunprotocol.org:3478".to_string(),
"stun.sipnet.com:3478".to_string(),
"stun.radiojar.com:3478".to_string(),
"stun.sonetel.com:3478".to_string(),
"stun.voipgate.com:3478".to_string(),
"stun.counterpath.com:3478".to_string(),
"180.235.108.91:3478".to_string(),
"193.22.2.248:3478".to_string(),
];
GlobalCtx {
inst_name: inst_name.to_string(),
id,
config_fs,
net_ns: net_ns.clone(),
event_bus,
cached_ipv4: AtomicCell::new(None),
cached_proxy_cidrs: AtomicCell::new(None),
ip_collector: Arc::new(IPCollector::new(net_ns)),
hotname: AtomicCell::new(None),
stun_info_collection: Box::new(StunInfoCollector::new(default_stun_servers)),
}
}
pub fn subscribe(&self) -> EventBusSubscriber {
self.event_bus.subscribe()
}
pub fn issue_event(&self, event: GlobalCtxEvent) {
if self.event_bus.receiver_count() != 0 {
self.event_bus.send(event).unwrap();
} else {
log::warn!("No subscriber for event: {:?}", event);
}
}
pub fn get_ipv4(&self) -> Option<std::net::Ipv4Addr> {
if let Some(ret) = self.cached_ipv4.load() {
return Some(ret);
}
let Ok(addr) = self.config_fs.get("ipv4") else {
return None;
};
let Ok(addr) = addr.parse() else {
tracing::error!("invalid ipv4 addr: {}", addr);
return None;
};
self.cached_ipv4.store(Some(addr));
return Some(addr);
}
pub fn set_ipv4(&mut self, addr: std::net::Ipv4Addr) {
self.config_fs
.add_file("ipv4")
.unwrap()
.write_all(addr.to_string().as_bytes())
.unwrap();
self.cached_ipv4.store(None);
}
pub fn add_proxy_cidr(&self, cidr: cidr::IpCidr) -> Result<(), std::io::Error> {
let escaped_cidr = cidr.to_string().replace("/", "_");
self.config_fs
.add_file(&format!("proxy_cidrs/{}", escaped_cidr))?;
self.cached_proxy_cidrs.store(None);
Ok(())
}
pub fn remove_proxy_cidr(&self, cidr: cidr::IpCidr) -> Result<(), std::io::Error> {
let escaped_cidr = cidr.to_string().replace("/", "_");
self.config_fs
.remove(&format!("proxy_cidrs/{}", escaped_cidr))?;
self.cached_proxy_cidrs.store(None);
Ok(())
}
pub fn get_proxy_cidrs(&self) -> Vec<cidr::IpCidr> {
if let Some(proxy_cidrs) = self.cached_proxy_cidrs.take() {
self.cached_proxy_cidrs.store(Some(proxy_cidrs.clone()));
return proxy_cidrs;
}
let Ok(keys) = self.config_fs.list_keys("proxy_cidrs") else {
return vec![];
};
let mut ret = Vec::new();
for key in keys.iter() {
let key = key.replace("_", "/");
let Ok(cidr) = key.parse() else {
tracing::error!("invalid proxy cidr: {}", key);
continue;
};
ret.push(cidr);
}
self.cached_proxy_cidrs.store(Some(ret.clone()));
ret
}
pub fn get_ip_collector(&self) -> Arc<IPCollector> {
self.ip_collector.clone()
}
pub fn get_hostname(&self) -> Option<String> {
if let Some(hostname) = self.hotname.take() {
self.hotname.store(Some(hostname.clone()));
return Some(hostname);
}
let hostname = gethostname::gethostname().to_string_lossy().to_string();
self.hotname.store(Some(hostname.clone()));
return Some(hostname);
}
pub fn get_stun_info_collector(&self) -> impl StunInfoCollectorTrait + '_ {
self.stun_info_collection.as_ref()
}
#[cfg(test)]
pub fn replace_stun_info_collector(&self, collector: Box<dyn StunInfoCollectorTrait>) {
// force replace the stun_info_collection without mut and drop the old one
let ptr = &self.stun_info_collection as *const Box<dyn StunInfoCollectorTrait>;
let ptr = ptr as *mut Box<dyn StunInfoCollectorTrait>;
unsafe {
std::ptr::drop_in_place(ptr);
std::ptr::write(ptr, collector);
}
}
pub fn get_id(&self) -> uuid::Uuid {
self.id
}
}
#[cfg(test)]
pub mod tests {
use super::*;
#[tokio::test]
async fn test_global_ctx() {
let config_fs = ConfigFs::new("/tmp/easytier");
let net_ns = NetNS::new(None);
let global_ctx = GlobalCtx::new("test", config_fs, net_ns);
let mut subscriber = global_ctx.subscribe();
global_ctx.issue_event(GlobalCtxEvent::PeerAdded);
global_ctx.issue_event(GlobalCtxEvent::PeerRemoved);
global_ctx.issue_event(GlobalCtxEvent::PeerConnAdded(PeerConnInfo::default()));
global_ctx.issue_event(GlobalCtxEvent::PeerConnRemoved(PeerConnInfo::default()));
assert_eq!(subscriber.recv().await.unwrap(), GlobalCtxEvent::PeerAdded);
assert_eq!(
subscriber.recv().await.unwrap(),
GlobalCtxEvent::PeerRemoved
);
assert_eq!(
subscriber.recv().await.unwrap(),
GlobalCtxEvent::PeerConnAdded(PeerConnInfo::default())
);
assert_eq!(
subscriber.recv().await.unwrap(),
GlobalCtxEvent::PeerConnRemoved(PeerConnInfo::default())
);
}
pub fn get_mock_global_ctx() -> ArcGlobalCtx {
let node_id = uuid::Uuid::new_v4();
let config_fs = ConfigFs::new_with_dir(node_id.to_string().as_str(), "/tmp/easytier");
let net_ns = NetNS::new(None);
std::sync::Arc::new(GlobalCtx::new(
format!("test_{}", node_id).as_str(),
config_fs,
net_ns,
))
}
}
+312
View File
@@ -0,0 +1,312 @@
use std::net::Ipv4Addr;
use async_trait::async_trait;
use tokio::process::Command;
use super::error::Error;
#[async_trait]
pub trait IfConfiguerTrait {
async fn add_ipv4_route(
&self,
name: &str,
address: Ipv4Addr,
cidr_prefix: u8,
) -> Result<(), Error>;
async fn remove_ipv4_route(
&self,
name: &str,
address: Ipv4Addr,
cidr_prefix: u8,
) -> Result<(), Error>;
async fn add_ipv4_ip(
&self,
name: &str,
address: Ipv4Addr,
cidr_prefix: u8,
) -> Result<(), Error>;
async fn set_link_status(&self, name: &str, up: bool) -> Result<(), Error>;
async fn remove_ip(&self, name: &str, ip: Option<Ipv4Addr>) -> Result<(), Error>;
async fn wait_interface_show(&self, _name: &str) -> Result<(), Error> {
return Ok(());
}
}
fn cidr_to_subnet_mask(prefix_length: u8) -> Ipv4Addr {
if prefix_length > 32 {
panic!("Invalid CIDR prefix length");
}
let subnet_mask: u32 = (!0u32)
.checked_shl(32 - u32::from(prefix_length))
.unwrap_or(0);
Ipv4Addr::new(
((subnet_mask >> 24) & 0xFF) as u8,
((subnet_mask >> 16) & 0xFF) as u8,
((subnet_mask >> 8) & 0xFF) as u8,
(subnet_mask & 0xFF) as u8,
)
}
async fn run_shell_cmd(cmd: &str) -> Result<(), Error> {
let cmd_out = if cfg!(target_os = "windows") {
Command::new("cmd").arg("/C").arg(cmd).output().await?
} else {
Command::new("sh").arg("-c").arg(cmd).output().await?
};
let stdout = String::from_utf8_lossy(cmd_out.stdout.as_slice());
let stderr = String::from_utf8_lossy(cmd_out.stderr.as_slice());
let ec = cmd_out.status.code();
let succ = cmd_out.status.success();
tracing::info!(?cmd, ?ec, ?succ, ?stdout, ?stderr, "run shell cmd");
if !cmd_out.status.success() {
return Err(Error::ShellCommandError(
stdout.to_string() + &stderr.to_string(),
));
}
Ok(())
}
pub struct MacIfConfiger {}
#[async_trait]
impl IfConfiguerTrait for MacIfConfiger {
async fn add_ipv4_route(
&self,
name: &str,
address: Ipv4Addr,
cidr_prefix: u8,
) -> Result<(), Error> {
run_shell_cmd(
format!(
"route -n add {} -netmask {} -interface {} -hopcount 7",
address,
cidr_to_subnet_mask(cidr_prefix),
name
)
.as_str(),
)
.await
}
async fn remove_ipv4_route(
&self,
name: &str,
address: Ipv4Addr,
cidr_prefix: u8,
) -> Result<(), Error> {
run_shell_cmd(
format!(
"route -n delete {} -netmask {} -interface {}",
address,
cidr_to_subnet_mask(cidr_prefix),
name
)
.as_str(),
)
.await
}
async fn add_ipv4_ip(
&self,
name: &str,
address: Ipv4Addr,
cidr_prefix: u8,
) -> Result<(), Error> {
run_shell_cmd(
format!(
"ifconfig {} {:?}/{:?} 10.8.8.8 up",
name, address, cidr_prefix,
)
.as_str(),
)
.await
}
async fn set_link_status(&self, name: &str, up: bool) -> Result<(), Error> {
run_shell_cmd(format!("ifconfig {} {}", name, if up { "up" } else { "down" }).as_str())
.await
}
async fn remove_ip(&self, name: &str, ip: Option<Ipv4Addr>) -> Result<(), Error> {
if ip.is_none() {
run_shell_cmd(format!("ifconfig {} inet delete", name).as_str()).await
} else {
run_shell_cmd(
format!("ifconfig {} inet {} delete", name, ip.unwrap().to_string()).as_str(),
)
.await
}
}
}
pub struct LinuxIfConfiger {}
#[async_trait]
impl IfConfiguerTrait for LinuxIfConfiger {
async fn add_ipv4_route(
&self,
name: &str,
address: Ipv4Addr,
cidr_prefix: u8,
) -> Result<(), Error> {
run_shell_cmd(
format!(
"ip route add {}/{} dev {} metric 65535",
address, cidr_prefix, name
)
.as_str(),
)
.await
}
async fn remove_ipv4_route(
&self,
name: &str,
address: Ipv4Addr,
cidr_prefix: u8,
) -> Result<(), Error> {
run_shell_cmd(format!("ip route del {}/{} dev {}", address, cidr_prefix, name).as_str())
.await
}
async fn add_ipv4_ip(
&self,
name: &str,
address: Ipv4Addr,
cidr_prefix: u8,
) -> Result<(), Error> {
run_shell_cmd(format!("ip addr add {:?}/{:?} dev {}", address, cidr_prefix, name).as_str())
.await
}
async fn set_link_status(&self, name: &str, up: bool) -> Result<(), Error> {
run_shell_cmd(format!("ip link set {} {}", name, if up { "up" } else { "down" }).as_str())
.await
}
async fn remove_ip(&self, name: &str, ip: Option<Ipv4Addr>) -> Result<(), Error> {
if ip.is_none() {
run_shell_cmd(format!("ip addr flush dev {}", name).as_str()).await
} else {
run_shell_cmd(
format!("ip addr del {:?} dev {}", ip.unwrap().to_string(), name).as_str(),
)
.await
}
}
}
pub struct WindowsIfConfiger {}
#[async_trait]
impl IfConfiguerTrait for WindowsIfConfiger {
async fn add_ipv4_route(
&self,
name: &str,
address: Ipv4Addr,
cidr_prefix: u8,
) -> Result<(), Error> {
run_shell_cmd(
format!(
"route add {} mask {} {}",
address,
cidr_to_subnet_mask(cidr_prefix),
name
)
.as_str(),
)
.await
}
async fn remove_ipv4_route(
&self,
name: &str,
address: Ipv4Addr,
cidr_prefix: u8,
) -> Result<(), Error> {
run_shell_cmd(
format!(
"route delete {} mask {} {}",
address,
cidr_to_subnet_mask(cidr_prefix),
name
)
.as_str(),
)
.await
}
async fn add_ipv4_ip(
&self,
name: &str,
address: Ipv4Addr,
cidr_prefix: u8,
) -> Result<(), Error> {
run_shell_cmd(
format!(
"netsh interface ipv4 add address {} address={} mask={}",
name,
address,
cidr_to_subnet_mask(cidr_prefix)
)
.as_str(),
)
.await
}
async fn set_link_status(&self, name: &str, up: bool) -> Result<(), Error> {
run_shell_cmd(
format!(
"netsh interface set interface {} {}",
name,
if up { "enable" } else { "disable" }
)
.as_str(),
)
.await
}
async fn remove_ip(&self, name: &str, ip: Option<Ipv4Addr>) -> Result<(), Error> {
if ip.is_none() {
run_shell_cmd(format!("netsh interface ipv4 delete address {}", name).as_str()).await
} else {
run_shell_cmd(
format!(
"netsh interface ipv4 delete address {} address={}",
name,
ip.unwrap().to_string()
)
.as_str(),
)
.await
}
}
async fn wait_interface_show(&self, name: &str) -> Result<(), Error> {
Ok(
tokio::time::timeout(std::time::Duration::from_secs(10), async move {
loop {
let Ok(_) = run_shell_cmd(
format!("netsh interface ipv4 show interfaces {}", name).as_str(),
)
.await
else {
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
continue;
};
break;
}
})
.await?,
)
}
}
#[cfg(target_os = "macos")]
pub type IfConfiger = MacIfConfiger;
#[cfg(target_os = "linux")]
pub type IfConfiger = LinuxIfConfiger;
#[cfg(target_os = "windows")]
pub type IfConfiger = WindowsIfConfiger;
+9
View File
@@ -0,0 +1,9 @@
pub mod config_fs;
pub mod constants;
pub mod error;
pub mod global_ctx;
pub mod ifcfg;
pub mod netns;
pub mod network;
pub mod rkyv_util;
pub mod stun;
+118
View File
@@ -0,0 +1,118 @@
use futures::Future;
use once_cell::sync::Lazy;
use tokio::sync::Mutex;
#[cfg(target_os = "linux")]
use nix::sched::{setns, CloneFlags};
#[cfg(target_os = "linux")]
use std::os::fd::AsFd;
pub struct NetNSGuard {
#[cfg(target_os = "linux")]
old_ns: Option<std::fs::File>,
}
type NetNSLock = Mutex<()>;
static LOCK: Lazy<NetNSLock> = Lazy::new(|| Mutex::new(()));
pub static ROOT_NETNS_NAME: &str = "_root_ns";
#[cfg(target_os = "linux")]
impl NetNSGuard {
pub fn new(ns: Option<String>) -> Box<Self> {
let old_ns = if ns.is_some() {
let old_ns = if cfg!(target_os = "linux") {
Some(std::fs::File::open("/proc/self/ns/net").unwrap())
} else {
None
};
Self::switch_ns(ns);
old_ns
} else {
None
};
Box::new(NetNSGuard { old_ns })
}
fn switch_ns(name: Option<String>) {
if name.is_none() {
return;
}
let ns_path: String;
let name = name.unwrap();
if name == ROOT_NETNS_NAME {
ns_path = "/proc/1/ns/net".to_string();
} else {
ns_path = format!("/var/run/netns/{}", name);
}
let ns = std::fs::File::open(ns_path).unwrap();
log::info!(
"[INIT NS] switching to new ns_name: {:?}, ns_file: {:?}",
name,
ns
);
setns(ns.as_fd(), CloneFlags::CLONE_NEWNET).unwrap();
}
}
#[cfg(target_os = "linux")]
impl Drop for NetNSGuard {
fn drop(&mut self) {
if self.old_ns.is_none() {
return;
}
log::info!("[INIT NS] switching back to old ns, ns: {:?}", self.old_ns);
setns(
self.old_ns.as_ref().unwrap().as_fd(),
CloneFlags::CLONE_NEWNET,
)
.unwrap();
}
}
#[cfg(not(target_os = "linux"))]
impl NetNSGuard {
pub fn new(_ns: Option<String>) -> Box<Self> {
Box::new(NetNSGuard {})
}
}
#[derive(Clone)]
pub struct NetNS {
name: Option<String>,
}
impl NetNS {
pub fn new(name: Option<String>) -> Self {
NetNS { name }
}
pub async fn run_async<F, Fut, Ret>(&self, f: F) -> Ret
where
F: FnOnce() -> Fut,
Fut: Future<Output = Ret>,
{
// TODO: do we really need this lock
// let _lock = LOCK.lock().await;
let _guard = NetNSGuard::new(self.name.clone());
f().await
}
pub fn run<F, Ret>(&self, f: F) -> Ret
where
F: FnOnce() -> Ret,
{
let _guard = NetNSGuard::new(self.name.clone());
f()
}
pub fn guard(&self) -> Box<NetNSGuard> {
NetNSGuard::new(self.name.clone())
}
pub fn name(&self) -> Option<String> {
self.name.clone()
}
}
+218
View File
@@ -0,0 +1,218 @@
use std::{ops::Deref, sync::Arc};
use easytier_rpc::peer::GetIpListResponse;
use pnet::datalink::NetworkInterface;
use tokio::{
sync::{Mutex, RwLock},
task::JoinSet,
};
use super::{constants::DIRECT_CONNECTOR_IP_LIST_TIMEOUT_SEC, netns::NetNS};
struct InterfaceFilter {
iface: NetworkInterface,
}
#[cfg(target_os = "linux")]
impl InterfaceFilter {
async fn is_iface_bridge(&self) -> bool {
let path = format!("/sys/class/net/{}/bridge", self.iface.name);
tokio::fs::metadata(&path).await.is_ok()
}
async fn is_iface_phsical(&self) -> bool {
let path = format!("/sys/class/net/{}/device", self.iface.name);
tokio::fs::metadata(&path).await.is_ok()
}
async fn filter_iface(&self) -> bool {
tracing::trace!(
"filter linux iface: {:?}, is_point_to_point: {}, is_loopback: {}, is_up: {}, is_lower_up: {}, is_bridge: {}, is_physical: {}",
self.iface,
self.iface.is_point_to_point(),
self.iface.is_loopback(),
self.iface.is_up(),
self.iface.is_lower_up(),
self.is_iface_bridge().await,
self.is_iface_phsical().await,
);
!self.iface.is_point_to_point()
&& !self.iface.is_loopback()
&& self.iface.is_up()
&& self.iface.is_lower_up()
&& (self.is_iface_bridge().await || self.is_iface_phsical().await)
}
}
#[cfg(target_os = "macos")]
impl InterfaceFilter {
async fn is_interface_physical(interface_name: &str) -> bool {
let output = tokio::process::Command::new("networksetup")
.args(&["-listallhardwareports"])
.output()
.await
.expect("Failed to execute command");
let stdout = std::str::from_utf8(&output.stdout).expect("Invalid UTF-8");
let lines: Vec<&str> = stdout.lines().collect();
for i in 0..lines.len() {
let line = lines[i];
if line.contains("Device:") && line.contains(interface_name) {
let next_line = lines[i + 1];
if next_line.contains("Virtual Interface") {
return false;
} else {
return true;
}
}
}
false
}
async fn filter_iface(&self) -> bool {
!self.iface.is_point_to_point()
&& !self.iface.is_loopback()
&& self.iface.is_up()
&& Self::is_interface_physical(&self.iface.name).await
}
}
#[cfg(target_os = "windows")]
impl InterfaceFilter {
async fn filter_iface(&self) -> bool {
!self.iface.is_point_to_point() && !self.iface.is_loopback() && self.iface.is_up()
}
}
pub async fn local_ipv4() -> std::io::Result<std::net::Ipv4Addr> {
let socket = tokio::net::UdpSocket::bind("0.0.0.0:0").await?;
socket.connect("8.8.8.8:80").await?;
let addr = socket.local_addr()?;
match addr.ip() {
std::net::IpAddr::V4(ip) => Ok(ip),
std::net::IpAddr::V6(_) => Err(std::io::Error::new(
std::io::ErrorKind::AddrNotAvailable,
"no ipv4 address",
)),
}
}
pub async fn local_ipv6() -> std::io::Result<std::net::Ipv6Addr> {
let socket = tokio::net::UdpSocket::bind("[::]:0").await?;
socket
.connect("[2001:4860:4860:0000:0000:0000:0000:8888]:80")
.await?;
let addr = socket.local_addr()?;
match addr.ip() {
std::net::IpAddr::V6(ip) => Ok(ip),
std::net::IpAddr::V4(_) => Err(std::io::Error::new(
std::io::ErrorKind::AddrNotAvailable,
"no ipv4 address",
)),
}
}
pub struct IPCollector {
cached_ip_list: Arc<RwLock<GetIpListResponse>>,
collect_ip_task: Mutex<JoinSet<()>>,
net_ns: NetNS,
}
impl IPCollector {
pub fn new(net_ns: NetNS) -> Self {
Self {
cached_ip_list: Arc::new(RwLock::new(GetIpListResponse::new())),
collect_ip_task: Mutex::new(JoinSet::new()),
net_ns,
}
}
pub async fn collect_ip_addrs(&self) -> GetIpListResponse {
let mut task = self.collect_ip_task.lock().await;
if task.is_empty() {
let cached_ip_list = self.cached_ip_list.clone();
*cached_ip_list.write().await =
Self::do_collect_ip_addrs(false, self.net_ns.clone()).await;
let net_ns = self.net_ns.clone();
task.spawn(async move {
loop {
let ip_addrs = Self::do_collect_ip_addrs(true, net_ns.clone()).await;
*cached_ip_list.write().await = ip_addrs;
tokio::time::sleep(std::time::Duration::from_secs(
DIRECT_CONNECTOR_IP_LIST_TIMEOUT_SEC,
))
.await;
}
});
}
return self.cached_ip_list.read().await.deref().clone();
}
#[tracing::instrument(skip(net_ns))]
async fn do_collect_ip_addrs(with_public: bool, net_ns: NetNS) -> GetIpListResponse {
let mut ret = easytier_rpc::peer::GetIpListResponse {
public_ipv4: "".to_string(),
interface_ipv4s: vec![],
public_ipv6: "".to_string(),
interface_ipv6s: vec![],
};
if with_public {
if let Some(v4_addr) =
public_ip::addr_with(public_ip::http::ALL, public_ip::Version::V4).await
{
ret.public_ipv4 = v4_addr.to_string();
}
if let Some(v6_addr) = public_ip::addr_v6().await {
ret.public_ipv6 = v6_addr.to_string();
}
}
let _g = net_ns.guard();
let ifaces = pnet::datalink::interfaces();
for iface in ifaces {
let f = InterfaceFilter {
iface: iface.clone(),
};
if !f.filter_iface().await {
continue;
}
for ip in iface.ips {
let ip: std::net::IpAddr = ip.ip();
if ip.is_loopback() || ip.is_multicast() {
continue;
}
if ip.is_ipv4() {
ret.interface_ipv4s.push(ip.to_string());
} else if ip.is_ipv6() {
ret.interface_ipv6s.push(ip.to_string());
}
}
}
if let Ok(v4_addr) = local_ipv4().await {
tracing::trace!("got local ipv4: {}", v4_addr);
if !ret.interface_ipv4s.contains(&v4_addr.to_string()) {
ret.interface_ipv4s.push(v4_addr.to_string());
}
}
if let Ok(v6_addr) = local_ipv6().await {
tracing::trace!("got local ipv6: {}", v6_addr);
if !ret.interface_ipv6s.contains(&v6_addr.to_string()) {
ret.interface_ipv6s.push(v6_addr.to_string());
}
}
ret
}
}
+54
View File
@@ -0,0 +1,54 @@
use rkyv::{
validation::{validators::DefaultValidator, CheckTypeError},
vec::ArchivedVec,
Archive, CheckBytes, Serialize,
};
use tokio_util::bytes::{Bytes, BytesMut};
pub fn decode_from_bytes_checked<'a, T: Archive>(
bytes: &'a [u8],
) -> Result<&'a T::Archived, CheckTypeError<T::Archived, DefaultValidator<'a>>>
where
T::Archived: CheckBytes<DefaultValidator<'a>>,
{
rkyv::check_archived_root::<T>(bytes)
}
pub fn decode_from_bytes<'a, T: Archive>(
bytes: &'a [u8],
) -> Result<&'a T::Archived, CheckTypeError<T::Archived, DefaultValidator<'a>>>
where
T::Archived: CheckBytes<DefaultValidator<'a>>,
{
// rkyv::check_archived_root::<T>(bytes)
unsafe { Ok(rkyv::archived_root::<T>(bytes)) }
}
// allow deseraial T to Bytes
pub fn encode_to_bytes<T, const N: usize>(val: &T) -> Bytes
where
T: Serialize<rkyv::ser::serializers::AllocSerializer<N>>,
{
let ret = rkyv::to_bytes::<_, N>(val).unwrap();
// let mut r = BytesMut::new();
// r.extend_from_slice(&ret);
// r.freeze()
ret.into_boxed_slice().into()
}
pub fn extract_bytes_from_archived_vec(raw_data: &Bytes, archived_data: &ArchivedVec<u8>) -> Bytes {
let ptr_range = archived_data.as_ptr_range();
let offset = ptr_range.start as usize - raw_data.as_ptr() as usize;
let len = ptr_range.end as usize - ptr_range.start as usize;
return raw_data.slice(offset..offset + len);
}
pub fn extract_bytes_mut_from_archived_vec(
raw_data: &mut BytesMut,
archived_data: &ArchivedVec<u8>,
) -> BytesMut {
let ptr_range = archived_data.as_ptr_range();
let offset = ptr_range.start as usize - raw_data.as_ptr() as usize;
let len = ptr_range.end as usize - ptr_range.start as usize;
raw_data.split_off(offset).split_to(len)
}
+433
View File
@@ -0,0 +1,433 @@
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
use std::sync::Arc;
use std::time::Duration;
use crossbeam::atomic::AtomicCell;
use easytier_rpc::{NatType, StunInfo};
use stun_format::Attr;
use tokio::net::{lookup_host, UdpSocket};
use tokio::sync::RwLock;
use tokio::task::JoinSet;
use crate::common::error::Error;
struct Stun {
stun_server: String,
req_repeat: u8,
resp_timeout: Duration,
}
#[derive(Debug, Clone, Copy)]
struct BindRequestResponse {
source_addr: SocketAddr,
mapped_socket_addr: Option<SocketAddr>,
changed_socket_addr: Option<SocketAddr>,
ip_changed: bool,
port_changed: bool,
}
impl BindRequestResponse {
pub fn get_mapped_addr_no_check(&self) -> &SocketAddr {
self.mapped_socket_addr.as_ref().unwrap()
}
}
impl Stun {
pub fn new(stun_server: String) -> Self {
Self {
stun_server,
req_repeat: 3,
resp_timeout: Duration::from_millis(3000),
}
}
async fn wait_stun_response<'a, const N: usize>(
&self,
buf: &'a mut [u8; N],
udp: &UdpSocket,
tids: &Vec<u128>,
) -> Result<(stun_format::Msg<'a>, SocketAddr), Error> {
let mut now = tokio::time::Instant::now();
let deadline = now + self.resp_timeout;
while now < deadline {
let mut udp_buf = [0u8; 1500];
let (len, remote_addr) =
tokio::time::timeout(deadline - now, udp.recv_from(udp_buf.as_mut_slice()))
.await??;
now = tokio::time::Instant::now();
if len < 20 {
continue;
}
// TODO:: we cannot borrow `buf` directly in udp recv_from, so we copy it here
unsafe { std::ptr::copy(udp_buf.as_ptr(), buf.as_ptr() as *mut u8, len) };
let msg = stun_format::Msg::<'a>::from(&buf[..]);
tracing::trace!(b = ?&udp_buf[..len], ?msg, ?tids, "recv stun response");
if msg.typ().is_none() || msg.tid().is_none() {
continue;
}
if matches!(
msg.typ().as_ref().unwrap(),
stun_format::MsgType::BindingResponse
) && tids.contains(msg.tid().as_ref().unwrap())
{
return Ok((msg, remote_addr));
}
}
Err(Error::Unknown)
}
fn stun_addr(addr: stun_format::SocketAddr) -> SocketAddr {
match addr {
stun_format::SocketAddr::V4(ip, port) => {
SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::from(ip), port))
}
stun_format::SocketAddr::V6(ip, port) => {
SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::from(ip), port, 0, 0))
}
}
}
fn extrace_mapped_addr(msg: &stun_format::Msg) -> Option<SocketAddr> {
let mut mapped_addr = None;
for x in msg.attrs_iter() {
match x {
Attr::MappedAddress(addr) => {
if mapped_addr.is_none() {
let _ = mapped_addr.insert(Self::stun_addr(addr));
}
}
Attr::XorMappedAddress(addr) => {
if mapped_addr.is_none() {
let _ = mapped_addr.insert(Self::stun_addr(addr));
}
}
_ => {}
}
}
mapped_addr
}
fn extract_changed_addr(msg: &stun_format::Msg) -> Option<SocketAddr> {
let mut changed_addr = None;
for x in msg.attrs_iter() {
match x {
Attr::ChangedAddress(addr) => {
if changed_addr.is_none() {
let _ = changed_addr.insert(Self::stun_addr(addr));
}
}
_ => {}
}
}
changed_addr
}
pub async fn bind_request(
&self,
source_port: u16,
change_ip: bool,
change_port: bool,
) -> Result<BindRequestResponse, Error> {
let stun_host = lookup_host(&self.stun_server)
.await?
.next()
.ok_or(Error::NotFound)?;
// let udp_socket = socket2::Socket::new(
// match stun_host {
// SocketAddr::V4(..) => socket2::Domain::IPV4,
// SocketAddr::V6(..) => socket2::Domain::IPV6,
// },
// socket2::Type::DGRAM,
// Some(socket2::Protocol::UDP),
// )?;
// udp_socket.set_reuse_port(true)?;
// udp_socket.set_reuse_address(true)?;
let udp = UdpSocket::bind(format!("0.0.0.0:{}", source_port)).await?;
// repeat req in case of packet loss
let mut tids = vec![];
for _ in 0..self.req_repeat {
let mut buf = [0u8; 28];
// memset buf
unsafe { std::ptr::write_bytes(buf.as_mut_ptr(), 0, buf.len()) };
let mut msg = stun_format::MsgBuilder::from(buf.as_mut_slice());
msg.typ(stun_format::MsgType::BindingRequest).unwrap();
let tid = rand::random::<u32>();
msg.tid(tid as u128).unwrap();
if change_ip || change_port {
msg.add_attr(Attr::ChangeRequest {
change_ip,
change_port,
})
.unwrap();
}
tids.push(tid as u128);
tracing::trace!(b = ?msg.as_bytes(), tid, "send stun request");
udp.send_to(msg.as_bytes(), &stun_host).await?;
}
tracing::trace!("waiting stun response");
let mut buf = [0; 1620];
let (msg, recv_addr) = self.wait_stun_response(&mut buf, &udp, &tids).await?;
let changed_socket_addr = Self::extract_changed_addr(&msg);
let ip_changed = stun_host.ip() != recv_addr.ip();
let port_changed = stun_host.port() != recv_addr.port();
let resp = BindRequestResponse {
source_addr: udp.local_addr()?,
mapped_socket_addr: Self::extrace_mapped_addr(&msg),
changed_socket_addr,
ip_changed,
port_changed,
};
tracing::info!(
?stun_host,
?recv_addr,
?changed_socket_addr,
"finish stun bind request"
);
Ok(resp)
}
}
struct UdpNatTypeDetector {
stun_servers: Vec<String>,
}
impl UdpNatTypeDetector {
pub fn new(stun_servers: Vec<String>) -> Self {
Self { stun_servers }
}
async fn get_udp_nat_type(&self, mut source_port: u16) -> NatType {
// Like classic STUN (rfc3489). Detect NAT behavior for UDP.
// Modified from rfc3489. Requires at least two STUN servers.
let mut ret_test1_1 = None;
let mut ret_test1_2 = None;
let mut ret_test2 = None;
let mut ret_test3 = None;
if source_port == 0 {
let udp = UdpSocket::bind("0.0.0.0:0").await.unwrap();
source_port = udp.local_addr().unwrap().port();
}
let mut succ = false;
for server_ip in &self.stun_servers {
let stun = Stun::new(server_ip.clone());
let ret = stun.bind_request(source_port, false, false).await;
if ret.is_err() {
// Try another STUN server
continue;
}
if ret_test1_1.is_none() {
ret_test1_1 = ret.ok();
continue;
}
ret_test1_2 = ret.ok();
let ret = stun.bind_request(source_port, true, true).await;
if let Ok(resp) = ret {
if !resp.ip_changed || !resp.port_changed {
// Try another STUN server
continue;
}
}
ret_test2 = ret.ok();
ret_test3 = stun.bind_request(source_port, false, true).await.ok();
succ = true;
break;
}
if !succ {
return NatType::Unknown;
}
tracing::info!(
?ret_test1_1,
?ret_test1_2,
?ret_test2,
?ret_test3,
"finish stun test, try to detect nat type"
);
let ret_test1_1 = ret_test1_1.unwrap();
let ret_test1_2 = ret_test1_2.unwrap();
if ret_test1_1.mapped_socket_addr != ret_test1_2.mapped_socket_addr {
return NatType::Symmetric;
}
if ret_test1_1.mapped_socket_addr.is_some()
&& ret_test1_1.source_addr == ret_test1_1.mapped_socket_addr.unwrap()
{
if !ret_test2.is_none() {
return NatType::OpenInternet;
} else {
return NatType::SymUdpFirewall;
}
} else {
if let Some(ret_test2) = ret_test2 {
if source_port == ret_test2.get_mapped_addr_no_check().port()
&& source_port == ret_test1_1.get_mapped_addr_no_check().port()
{
return NatType::NoPat;
} else {
return NatType::FullCone;
}
} else {
if !ret_test3.is_none() {
return NatType::Restricted;
} else {
return NatType::PortRestricted;
}
}
}
}
}
#[async_trait::async_trait]
#[auto_impl::auto_impl(&, Arc, Box)]
pub trait StunInfoCollectorTrait: Send + Sync {
fn get_stun_info(&self) -> StunInfo;
async fn get_udp_port_mapping(&self, local_port: u16) -> Result<SocketAddr, Error>;
}
pub struct StunInfoCollector {
stun_servers: Arc<RwLock<Vec<String>>>,
udp_nat_type: Arc<AtomicCell<(NatType, std::time::Instant)>>,
redetect_notify: Arc<tokio::sync::Notify>,
tasks: JoinSet<()>,
}
#[async_trait::async_trait]
impl StunInfoCollectorTrait for StunInfoCollector {
fn get_stun_info(&self) -> StunInfo {
let (typ, time) = self.udp_nat_type.load();
StunInfo {
udp_nat_type: typ as i32,
tcp_nat_type: 0,
last_update_time: time.elapsed().as_secs() as i64,
}
}
async fn get_udp_port_mapping(&self, local_port: u16) -> Result<SocketAddr, Error> {
let stun_servers = self.stun_servers.read().await.clone();
for server in stun_servers.iter() {
let stun = Stun::new(server.clone());
let Ok(ret) = stun.bind_request(local_port, false, false).await else {
tracing::warn!(?server, "stun bind request failed");
continue;
};
if let Some(mapped_addr) = ret.mapped_socket_addr {
return Ok(mapped_addr);
}
}
Err(Error::NotFound)
}
}
impl StunInfoCollector {
pub fn new(stun_servers: Vec<String>) -> Self {
let mut ret = Self {
stun_servers: Arc::new(RwLock::new(stun_servers)),
udp_nat_type: Arc::new(AtomicCell::new((
NatType::Unknown,
std::time::Instant::now(),
))),
redetect_notify: Arc::new(tokio::sync::Notify::new()),
tasks: JoinSet::new(),
};
ret.start_stun_routine();
ret
}
fn start_stun_routine(&mut self) {
let stun_servers = self.stun_servers.clone();
let udp_nat_type = self.udp_nat_type.clone();
let redetect_notify = self.redetect_notify.clone();
self.tasks.spawn(async move {
loop {
let detector = UdpNatTypeDetector::new(stun_servers.read().await.clone());
let ret = detector.get_udp_nat_type(0).await;
udp_nat_type.store((ret, std::time::Instant::now()));
let sleep_sec = match ret {
NatType::Unknown => 15,
_ => 60,
};
tracing::info!(?ret, ?sleep_sec, "finish udp nat type detect");
tokio::select! {
_ = redetect_notify.notified() => {}
_ = tokio::time::sleep(Duration::from_secs(sleep_sec)) => {}
}
}
});
}
pub fn update_stun_info(&self) {
self.redetect_notify.notify_one();
}
pub async fn set_stun_servers(&self, stun_servers: Vec<String>) {
*self.stun_servers.write().await = stun_servers;
self.update_stun_info();
}
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn test_stun_bind_request() {
// miwifi / qq seems not correctly responde to change_ip and change_port, they always try to change the src ip and port.
// let stun = Stun::new("stun.counterpath.com:3478".to_string());
let stun = Stun::new("180.235.108.91:3478".to_string());
// let stun = Stun::new("193.22.2.248:3478".to_string());
// let stun = Stun::new("stun.chat.bilibili.com:3478".to_string());
// let stun = Stun::new("stun.miwifi.com:3478".to_string());
let rs = stun.bind_request(12345, true, true).await.unwrap();
assert!(rs.ip_changed);
assert!(rs.port_changed);
let rs = stun.bind_request(12345, true, false).await.unwrap();
assert!(rs.ip_changed);
assert!(!rs.port_changed);
let rs = stun.bind_request(12345, false, true).await.unwrap();
assert!(!rs.ip_changed);
assert!(rs.port_changed);
let rs = stun.bind_request(12345, false, false).await.unwrap();
assert!(!rs.ip_changed);
assert!(!rs.port_changed);
}
#[tokio::test]
async fn test_udp_nat_type_detect() {
let detector = UdpNatTypeDetector::new(vec![
"stun.counterpath.com:3478".to_string(),
"180.235.108.91:3478".to_string(),
]);
let ret = detector.get_udp_nat_type(0).await;
assert_eq!(ret, NatType::FullCone);
}
}