From c142db301a6e81526b55c90ec46706236f4aec83 Mon Sep 17 00:00:00 2001 From: "Sijie.Sun" Date: Tue, 1 Apr 2025 09:59:53 +0800 Subject: [PATCH] port forward (#736) * support tcp port forward * support udp port forward * command line option for port forward --- easytier-gui/locales/cn.yml | 1 + easytier-gui/locales/en.yml | 1 + easytier-web/frontend-lib/src/locales/cn.yaml | 2 +- easytier-web/frontend-lib/src/locales/en.yaml | 1 + .../frontend-lib/src/types/network.ts | 2 + easytier/Cargo.toml | 25 +- easytier/locales/app.yml | 3 + easytier/src/common/config.rs | 103 ++- easytier/src/common/global_ctx.rs | 4 +- easytier/src/easytier-core.rs | 62 +- easytier/src/gateway/ip_reassembler.rs | 9 + easytier/src/gateway/socks5.rs | 622 +++++++++++++++--- easytier/src/gateway/tokio_smoltcp/mod.rs | 9 +- easytier/src/gateway/tokio_smoltcp/socket.rs | 84 +++ .../gateway/tokio_smoltcp/socket_allocator.rs | 32 +- easytier/src/proto/common.proto | 11 + easytier/src/proto/common.rs | 6 + easytier/src/tests/three_node.rs | 120 +++- easytier/src/tunnel/common.rs | 1 + 19 files changed, 955 insertions(+), 143 deletions(-) diff --git a/easytier-gui/locales/cn.yml b/easytier-gui/locales/cn.yml index da44a6c6..66f5de3e 100644 --- a/easytier-gui/locales/cn.yml +++ b/easytier-gui/locales/cn.yml @@ -113,3 +113,4 @@ event: VpnPortalClientDisconnected: VPN门户客户端已断开连接 DhcpIpv4Changed: DHCP IPv4地址更改 DhcpIpv4Conflicted: DHCP IPv4地址冲突 + PortForwardAdded: 端口转发添加 diff --git a/easytier-gui/locales/en.yml b/easytier-gui/locales/en.yml index 23bfdf7d..94d8178b 100644 --- a/easytier-gui/locales/en.yml +++ b/easytier-gui/locales/en.yml @@ -112,3 +112,4 @@ event: VpnPortalClientDisconnected: VpnPortalClientDisconnected DhcpIpv4Changed: DhcpIpv4Changed DhcpIpv4Conflicted: DhcpIpv4Conflicted + PortForwardAdded: PortForwardAdded diff --git a/easytier-web/frontend-lib/src/locales/cn.yaml b/easytier-web/frontend-lib/src/locales/cn.yaml index 4255cace..754c1c54 100644 --- 
a/easytier-web/frontend-lib/src/locales/cn.yaml +++ b/easytier-web/frontend-lib/src/locales/cn.yaml @@ -182,4 +182,4 @@ event: VpnPortalClientDisconnected: VPN门户客户端已断开连接 DhcpIpv4Changed: DHCP IPv4地址更改 DhcpIpv4Conflicted: DHCP IPv4地址冲突 - + PortForwardAdded: 端口转发添加 diff --git a/easytier-web/frontend-lib/src/locales/en.yaml b/easytier-web/frontend-lib/src/locales/en.yaml index 26e11d93..237df2c9 100644 --- a/easytier-web/frontend-lib/src/locales/en.yaml +++ b/easytier-web/frontend-lib/src/locales/en.yaml @@ -182,3 +182,4 @@ event: VpnPortalClientDisconnected: VpnPortalClientDisconnected DhcpIpv4Changed: DhcpIpv4Changed DhcpIpv4Conflicted: DhcpIpv4Conflicted + PortForwardAdded: PortForwardAdded diff --git a/easytier-web/frontend-lib/src/types/network.ts b/easytier-web/frontend-lib/src/types/network.ts index 3de4890d..c6a0bf9e 100644 --- a/easytier-web/frontend-lib/src/types/network.ts +++ b/easytier-web/frontend-lib/src/types/network.ts @@ -264,4 +264,6 @@ export enum EventType { DhcpIpv4Changed = 'DhcpIpv4Changed', // ipv4 | null, ipv4 | null DhcpIpv4Conflicted = 'DhcpIpv4Conflicted', // ipv4 | null + + PortForwardAdded = 'PortForwardAdded', // PortForwardConfigPb } diff --git a/easytier/Cargo.toml b/easytier/Cargo.toml index ab41b931..c340ca42 100644 --- a/easytier/Cargo.toml +++ b/easytier/Cargo.toml @@ -162,8 +162,14 @@ smoltcp = { version = "0.12.0", optional = true, default-features = false, featu "medium-ip", "proto-ipv4", "proto-ipv6", + "proto-ipv4-fragmentation", + "fragmentation-buffer-size-8192", + "assembler-max-segment-count-16", + "reassembly-buffer-size-8192", + "reassembly-buffer-count-16", "socket-tcp", - "socket-tcp-cubic", + "socket-udp", + # "socket-tcp-cubic", "async", ] } parking_lot = { version = "0.12.0", optional = true } @@ -176,9 +182,12 @@ sys-locale = "0.3" ringbuf = "0.4.5" async-ringbuf = "0.3.1" -service-manager = {git = "https://github.com/chipsenkbeil/service-manager-rs.git", branch = "main"} +service-manager = { git = 
"https://github.com/chipsenkbeil/service-manager-rs.git", branch = "main" } -async-compression = { version = "0.4.17", default-features = false, features = ["zstd", "tokio"] } +async-compression = { version = "0.4.17", default-features = false, features = [ + "zstd", + "tokio", +] } kcp-sys = { git = "https://github.com/EasyTier/kcp-sys" } @@ -187,7 +196,9 @@ prost-reflect = { version = "0.14.5", default-features = false, features = [ ] } # for http connector -http_req = { git = "https://github.com/EasyTier/http_req.git", default-features = false, features = ["rust-tls"] } +http_req = { git = "https://github.com/EasyTier/http_req.git", default-features = false, features = [ + "rust-tls", +] } # for dns connector hickory-resolver = "0.24.4" @@ -212,7 +223,7 @@ windows = { version = "0.52.0", features = [ "Win32_System_Ole", "Win32_Networking_WinSock", "Win32_System_IO", -]} +] } encoding = "0.2" winreg = "0.52" windows-service = "0.7.0" @@ -222,7 +233,9 @@ tonic-build = "0.12" globwalk = "0.8.1" regex = "1" prost-build = "0.13.2" -rpc_build = { package = "easytier-rpc-build", version = "0.1.0", features = ["internal-namespace"] } +rpc_build = { package = "easytier-rpc-build", version = "0.1.0", features = [ + "internal-namespace", +] } prost-reflect-build = { version = "0.14.0" } [target.'cfg(windows)'.build-dependencies] diff --git a/easytier/locales/app.yml b/easytier/locales/app.yml index d6812042..dd32282a 100644 --- a/easytier/locales/app.yml +++ b/easytier/locales/app.yml @@ -149,6 +149,9 @@ core_clap: disable_kcp_input: en: "do not allow other nodes to use kcp to proxy tcp streams to this node. when a node with kcp proxy enabled accesses this node, the original tcp connection is preserved." zh-CN: "不允许其他节点使用 KCP 代理 TCP 流到此节点。开启 KCP 代理的节点访问此节点时,依然使用原始 TCP 连接。" + port_forward: + en: "forward local port to remote port in virtual network. e.g.: udp://0.0.0.0:12345/10.126.126.1:23456, means forward local udp port 12345 to 10.126.126.1:23456 in the virtual network. 
can specify multiple." + zh-CN: "将本地端口转发到虚拟网络中的远程端口。例如:udp://0.0.0.0:12345/10.126.126.1:23456,表示将本地UDP端口12345转发到虚拟网络中的10.126.126.1:23456。可以指定多个。" core_app: panic_backtrace_save: diff --git a/easytier/src/common/config.rs b/easytier/src/common/config.rs index d6a727a8..c782fecf 100644 --- a/easytier/src/common/config.rs +++ b/easytier/src/common/config.rs @@ -7,7 +7,10 @@ use std::{ use anyhow::Context; use serde::{Deserialize, Serialize}; -use crate::{proto::common::CompressionAlgoPb, tunnel::generate_digest_from_str}; +use crate::{ + proto::common::{CompressionAlgoPb, PortForwardConfigPb, SocketType}, + tunnel::generate_digest_from_str, +}; pub type Flags = crate::proto::common::FlagsInConfig; @@ -97,6 +100,9 @@ pub trait ConfigLoader: Send + Sync { fn get_socks5_portal(&self) -> Option; fn set_socks5_portal(&self, addr: Option); + fn get_port_forwards(&self) -> Vec; + fn set_port_forwards(&self, forwards: Vec); + fn dump(&self) -> String; } @@ -180,6 +186,41 @@ pub struct VpnPortalConfig { pub wireguard_listen: SocketAddr, } +#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)] +pub struct PortForwardConfig { + pub bind_addr: SocketAddr, + pub dst_addr: SocketAddr, + pub proto: String, +} + +impl From for PortForwardConfig { + fn from(config: PortForwardConfigPb) -> Self { + PortForwardConfig { + bind_addr: config.bind_addr.unwrap_or_default().into(), + dst_addr: config.dst_addr.unwrap_or_default().into(), + proto: match SocketType::try_from(config.socket_type) { + Ok(SocketType::Tcp) => "tcp".to_string(), + Ok(SocketType::Udp) => "udp".to_string(), + _ => "tcp".to_string(), + }, + } + } +} + +impl Into for PortForwardConfig { + fn into(self) -> PortForwardConfigPb { + PortForwardConfigPb { + bind_addr: Some(self.bind_addr.into()), + dst_addr: Some(self.dst_addr.into()), + socket_type: match self.proto.to_lowercase().as_str() { + "tcp" => SocketType::Tcp as i32, + "udp" => SocketType::Udp as i32, + _ => SocketType::Tcp as i32, + }, + } + } +} + 
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)] struct Config { netns: Option, @@ -207,6 +248,8 @@ struct Config { socks5_proxy: Option, + port_forward: Option>, + flags: Option>, #[serde(skip)] @@ -534,6 +577,35 @@ impl ConfigLoader for TomlConfigLoader { self.config.lock().unwrap().exit_nodes = Some(nodes); } + fn get_routes(&self) -> Option> { + self.config.lock().unwrap().routes.clone() + } + + fn set_routes(&self, routes: Option>) { + self.config.lock().unwrap().routes = routes; + } + + fn get_socks5_portal(&self) -> Option { + self.config.lock().unwrap().socks5_proxy.clone() + } + + fn set_socks5_portal(&self, addr: Option) { + self.config.lock().unwrap().socks5_proxy = addr; + } + + fn get_port_forwards(&self) -> Vec { + self.config + .lock() + .unwrap() + .port_forward + .clone() + .unwrap_or_default() + } + + fn set_port_forwards(&self, forwards: Vec) { + self.config.lock().unwrap().port_forward = Some(forwards); + } + fn dump(&self) -> String { let default_flags_json = serde_json::to_string(&gen_default_flags()).unwrap(); let default_flags_hashmap = @@ -558,22 +630,6 @@ impl ConfigLoader for TomlConfigLoader { config.flags = Some(flag_map); toml::to_string_pretty(&config).unwrap() } - - fn get_routes(&self) -> Option> { - self.config.lock().unwrap().routes.clone() - } - - fn set_routes(&self, routes: Option>) { - self.config.lock().unwrap().routes = routes; - } - - fn get_socks5_portal(&self) -> Option { - self.config.lock().unwrap().socks5_proxy.clone() - } - - fn set_socks5_portal(&self, addr: Option) { - self.config.lock().unwrap().socks5_proxy = addr; - } } #[cfg(test)] @@ -614,6 +670,11 @@ dir = "/tmp/easytier" [console_logger] level = "warn" + +[[port_forward]] +bind_addr = "0.0.0.0:11011" +dst_addr = "192.168.94.33:11011" +proto = "tcp" "#; let ret = TomlConfigLoader::new_from_str(config_str); if let Err(e) = &ret { @@ -634,6 +695,14 @@ level = "warn" .collect::>() ); + assert_eq!( + vec![PortForwardConfig { + bind_addr: 
"0.0.0.0:11011".parse().unwrap(), + dst_addr: "192.168.94.33:11011".parse().unwrap(), + proto: "tcp".to_string(), + }], + ret.get_port_forwards() + ); println!("{}", ret.dump()); } } diff --git a/easytier/src/common/global_ctx.rs b/easytier/src/common/global_ctx.rs index caa723be..3774f9cf 100644 --- a/easytier/src/common/global_ctx.rs +++ b/easytier/src/common/global_ctx.rs @@ -5,7 +5,7 @@ use std::{ }; use crate::proto::cli::PeerConnInfo; -use crate::proto::common::PeerFeatureFlag; +use crate::proto::common::{PeerFeatureFlag, PortForwardConfigPb}; use crossbeam::atomic::AtomicCell; use super::{ @@ -42,6 +42,8 @@ pub enum GlobalCtxEvent { DhcpIpv4Changed(Option, Option), // (old, new) DhcpIpv4Conflicted(Option), + + PortForwardAdded(PortForwardConfigPb), } pub type EventBus = tokio::sync::broadcast::Sender; diff --git a/easytier/src/easytier-core.rs b/easytier/src/easytier-core.rs index a2882d3f..7529735d 100644 --- a/easytier/src/easytier-core.rs +++ b/easytier/src/easytier-core.rs @@ -17,7 +17,7 @@ use easytier::{ common::{ config::{ ConfigLoader, ConsoleLoggerConfig, FileLoggerConfig, NetworkIdentity, PeerConfig, - TomlConfigLoader, VpnPortalConfig, + PortForwardConfig, TomlConfigLoader, VpnPortalConfig, }, constants::EASYTIER_VERSION, global_ctx::{EventBusSubscriber, GlobalCtx, GlobalCtxEvent}, @@ -334,6 +334,13 @@ struct Cli { default_value = "false" )] disable_kcp_input: bool, + + #[arg( + long, + help = t!("core_clap.port_forward").to_string(), + num_args = 1.. 
+ )] + port_forward: Vec, } rust_i18n::i18n!("locales", fallback = "en"); @@ -549,6 +556,40 @@ impl TryFrom<&Cli> for TomlConfigLoader { )); } + #[cfg(feature = "socks5")] + for port_forward in cli.port_forward.iter() { + let example_str = ", example: udp://0.0.0.0:12345/10.126.126.1:12345"; + + let bind_addr = format!( + "{}:{}", + port_forward.host_str().expect("local bind host is missing"), + port_forward.port().expect("local bind port is missing") + ) + .parse() + .expect(format!("failed to parse local bind addr {}", example_str).as_str()); + + let dst_addr = format!( + "{}", + port_forward + .path_segments() + .expect(format!("remote destination addr is missing {}", example_str).as_str()) + .next() + .expect(format!("remote destination addr is missing {}", example_str).as_str()) + ) + .parse() + .expect(format!("failed to parse remote destination addr {}", example_str).as_str()); + + let port_forward_item = PortForwardConfig { + bind_addr, + dst_addr, + proto: port_forward.scheme().to_string(), + }; + + let mut old = cfg.get_port_forwards(); + old.push(port_forward_item); + cfg.set_port_forwards(old); + } + let mut f = cfg.get_flags(); if cli.default_protocol.is_some() { f.default_protocol = cli.default_protocol.as_ref().unwrap().clone(); @@ -710,6 +751,15 @@ pub fn handle_event(mut events: EventBusSubscriber) -> tokio::task::JoinHandle<( GlobalCtxEvent::DhcpIpv4Conflicted(ip) => { print_event(format!("dhcp ip conflict. ip: {:?}", ip)); } + + GlobalCtxEvent::PortForwardAdded(cfg) => { + print_event(format!( + "port forward added. 
local: {}, remote: {}, proto: {}", + cfg.bind_addr.unwrap().to_string(), + cfg.dst_addr.unwrap().to_string(), + cfg.socket_type().as_str_name() + )); + } } } }) @@ -870,17 +920,13 @@ async fn run_main(cli: Cli) -> anyhow::Result<()> { flags.bind_device = false; global_ctx.set_flags(flags); let hostname = match cli.hostname { - None => { - gethostname::gethostname().to_string_lossy().to_string() - } - Some(hostname) => { - hostname.to_string() - } + None => gethostname::gethostname().to_string_lossy().to_string(), + Some(hostname) => hostname.to_string(), }; let _wc = web_client::WebClient::new( create_connector_by_url(c_url.as_str(), &global_ctx, IpVersion::Both).await?, token.to_string(), - hostname + hostname, ); tokio::signal::ctrl_c().await.unwrap(); DNSTunnelConnector::new("".parse().unwrap(), global_ctx); diff --git a/easytier/src/gateway/ip_reassembler.rs b/easytier/src/gateway/ip_reassembler.rs index 7f20c9e7..e05dc4e4 100644 --- a/easytier/src/gateway/ip_reassembler.rs +++ b/easytier/src/gateway/ip_reassembler.rs @@ -45,11 +45,13 @@ impl IpPacket { // make sure the fragment doesn't overlap with existing fragments for f in &self.fragments { if f.offset <= fragment.offset && fragment.offset < f.offset + f.data.len() as u16 { + tracing::trace!("fragment overlap 1, f.offset = {}, fragment.offset = {}, f.data.len() = {}, fragment.data.len() = {}", f.offset, fragment.offset, f.data.len(), fragment.data.len()); return; } if fragment.offset <= f.offset && f.offset < fragment.offset + fragment.data.len() as u16 { + tracing::trace!("fragment overlap 2, f.offset = {}, fragment.offset = {}, f.data.len() = {}, fragment.data.len() = {}", f.offset, fragment.offset, f.data.len(), fragment.data.len()); return; } } @@ -151,6 +153,13 @@ impl IpReassembler { id, }; + tracing::trace!( + ?key, + "add fragment, offset = {}, total_length = {}", + fragment.offset, + total_length + ); + let mut entry = self.packets.entry(key.clone()).or_insert_with(|| { let packet = 
IpPacket::new(source, destination); let timestamp = Instant::now(); diff --git a/easytier/src/gateway/socks5.rs b/easytier/src/gateway/socks5.rs index c3635f44..c979a53c 100644 --- a/easytier/src/gateway/socks5.rs +++ b/easytier/src/gateway/socks5.rs @@ -1,10 +1,16 @@ use std::{ net::{IpAddr, Ipv4Addr, SocketAddr}, sync::Arc, - time::Duration, + time::{Duration, Instant}, }; +use crossbeam::atomic::AtomicCell; + use crate::{ + common::{ + config::PortForwardConfig, global_ctx::GlobalCtxEvent, join_joinset_background, + scoped_task::ScopedTask, + }, gateway::{ fast_socks5::{ server::{ @@ -12,19 +18,21 @@ use crate::{ }, util::stream::tcp_connect_with_timeout, }, - tokio_smoltcp::TcpStream, + ip_reassembler::IpReassembler, + tokio_smoltcp::{channel_device, Net, NetConfig}, }, - tunnel::packet_def::PacketType, + tunnel::packet_def::{PacketType, ZCPacket}, }; use anyhow::Context; -use dashmap::DashSet; -use pnet::packet::{ip::IpNextHeaderProtocols, ipv4::Ipv4Packet, tcp::TcpPacket, Packet}; -use tokio::{ - io::{AsyncRead, AsyncWrite}, - select, +use dashmap::DashMap; +use pnet::packet::{ + ip::IpNextHeaderProtocols, ipv4::Ipv4Packet, tcp::TcpPacket, udp::UdpPacket, Packet, }; use tokio::{ + io::{AsyncRead, AsyncWrite}, net::TcpListener, + net::UdpSocket, + select, sync::{mpsc, Mutex}, task::JoinSet, time::timeout, @@ -32,14 +40,33 @@ use tokio::{ use crate::{ common::{error::Error, global_ctx::GlobalCtx}, - gateway::tokio_smoltcp::{channel_device, Net, NetConfig}, peers::{peer_manager::PeerManager, PeerPacketFilter}, - tunnel::packet_def::ZCPacket, }; +enum SocksUdpSocket { + UdpSocket(Arc), + SmolUdpSocket(super::tokio_smoltcp::UdpSocket), +} + +impl SocksUdpSocket { + pub async fn send_to(&self, buf: &[u8], addr: SocketAddr) -> Result { + match self { + SocksUdpSocket::UdpSocket(socket) => socket.send_to(buf, addr).await, + SocksUdpSocket::SmolUdpSocket(socket) => socket.send_to(buf, addr).await, + } + } + + pub async fn recv_from(&self, buf: &mut [u8]) -> 
Result<(usize, SocketAddr), std::io::Error> { + match self { + SocksUdpSocket::UdpSocket(socket) => socket.recv_from(buf).await, + SocksUdpSocket::SmolUdpSocket(socket) => socket.recv_from(buf).await, + } + } +} + enum SocksTcpStream { TcpStream(tokio::net::TcpStream), - SmolTcpStream(TcpStream), + SmolTcpStream(super::tokio_smoltcp::TcpStream), } impl AsyncRead for SocksTcpStream { @@ -102,13 +129,80 @@ impl AsyncWrite for SocksTcpStream { } } +enum Socks5EntryData { + Tcp(TcpListener), // hold a binded socket to hold the tcp port + Udp((Arc, UdpClientKey)), // hold the socket to send data to dst +} + +const UDP_ENTRY: u8 = 1; +const TCP_ENTRY: u8 = 2; + #[derive(Debug, Eq, PartialEq, Hash, Clone)] struct Socks5Entry { src: SocketAddr, dst: SocketAddr, + entry_type: u8, } -type Socks5EntrySet = Arc>; +type Socks5EntrySet = Arc>; + +struct SmolTcpConnector { + net: Arc, + entries: Socks5EntrySet, + current_entry: std::sync::Mutex>, +} + +#[async_trait::async_trait] +impl AsyncTcpConnector for SmolTcpConnector { + type S = SocksTcpStream; + + async fn tcp_connect( + &self, + addr: SocketAddr, + timeout_s: u64, + ) -> crate::gateway::fast_socks5::Result { + let tmp_listener = TcpListener::bind("0.0.0.0:0").await?; + let local_addr = self.net.get_address(); + let port = tmp_listener.local_addr()?.port(); + + let entry = Socks5Entry { + src: SocketAddr::new(local_addr, port), + dst: addr, + entry_type: TCP_ENTRY, + }; + *self.current_entry.lock().unwrap() = Some(entry.clone()); + self.entries + .insert(entry, Socks5EntryData::Tcp(tmp_listener)); + + if addr.ip() == local_addr { + let modified_addr = + SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), addr.port()); + + Ok(SocksTcpStream::TcpStream( + tcp_connect_with_timeout(modified_addr, timeout_s).await?, + )) + } else { + let remote_socket = timeout( + Duration::from_secs(timeout_s), + self.net.tcp_connect(addr, port), + ) + .await + .with_context(|| "connect to remote timeout")?; + + 
Ok(SocksTcpStream::SmolTcpStream(remote_socket.map_err( + |e| super::fast_socks5::SocksError::Other(e.into()), + )?)) + } + } +} + +impl Drop for SmolTcpConnector { + fn drop(&mut self) { + if let Some(entry) = self.current_entry.lock().unwrap().take() { + self.entries.remove(&entry); + } + } +} struct Socks5ServerNet { ipv4_addr: cidr::Ipv4Inet, @@ -130,7 +224,7 @@ impl Socks5ServerNet { ) -> Self { let mut forward_tasks = JoinSet::new(); let mut cap = smoltcp::phy::DeviceCapabilities::default(); - cap.max_transmission_unit = 1280; + cap.max_transmission_unit = 1284; // 1284 - 20 can be divided by 8 (fragment offset unit) cap.medium = smoltcp::phy::Medium::Ip; let (dev, stack_sink, mut stack_stream) = channel_device::ChannelDevice::new(cap); @@ -151,7 +245,8 @@ impl Socks5ServerNet { while let Some(data) = stack_stream.recv().await { tracing::trace!( ?data, - "receive from smoltcp stack and send to peer mgr packet" + "receive from smoltcp stack and send to peer mgr packet, len = {}", + data.len() ); let Some(ipv4) = Ipv4Packet::new(&data) else { tracing::error!(?data, "smoltcp stack stream get non ipv4 packet"); @@ -197,69 +292,14 @@ impl Socks5ServerNet { config.set_skip_auth(false); config.set_allow_no_auth(true); - struct SmolTcpConnector( - Arc, - Socks5EntrySet, - std::sync::Mutex>, - ); - - #[async_trait::async_trait] - impl AsyncTcpConnector for SmolTcpConnector { - type S = SocksTcpStream; - - async fn tcp_connect( - &self, - addr: SocketAddr, - timeout_s: u64, - ) -> crate::gateway::fast_socks5::Result { - let local_addr = self.0.get_address(); - let port = self.0.get_port(); - - let entry = Socks5Entry { - src: SocketAddr::new(local_addr, port), - dst: addr, - }; - *self.2.lock().unwrap() = Some(entry.clone()); - self.1.insert(entry); - - if addr.ip() == local_addr { - let modified_addr = - SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), addr.port()); - - Ok(SocksTcpStream::TcpStream( - tcp_connect_with_timeout(modified_addr, timeout_s).await?, 
- )) - } else { - let remote_socket = timeout( - Duration::from_secs(timeout_s), - self.0.tcp_connect(addr, port), - ) - .await - .with_context(|| "connect to remote timeout")?; - - Ok(SocksTcpStream::SmolTcpStream(remote_socket.map_err( - |e| super::fast_socks5::SocksError::Other(e.into()), - )?)) - } - } - } - - impl Drop for SmolTcpConnector { - fn drop(&mut self) { - if let Some(entry) = self.2.lock().unwrap().take() { - self.1.remove(&entry); - } - } - } - let socket = Socks5Socket::new( stream, Arc::new(config), - SmolTcpConnector( - self.smoltcp_net.clone(), - self.entries.clone(), - std::sync::Mutex::new(None), - ), + SmolTcpConnector { + net: self.smoltcp_net.clone(), + entries: self.entries.clone(), + current_entry: std::sync::Mutex::new(None), + }, ); self.forward_tasks.lock().unwrap().spawn(async move { @@ -275,17 +315,36 @@ impl Socks5ServerNet { } } +struct UdpClientInfo { + client_addr: SocketAddr, + port_holder_socket: Arc, + local_addr: SocketAddr, + last_active: AtomicCell, + entries: Socks5EntrySet, + entry_key: Socks5Entry, +} + +#[derive(Debug, Eq, PartialEq, Hash, Clone)] +struct UdpClientKey { + client_addr: SocketAddr, + dst_addr: SocketAddr, +} + pub struct Socks5Server { global_ctx: Arc, peer_manager: Arc, auth: Option, - tasks: Arc>>, + tasks: Arc>>, packet_sender: mpsc::Sender, packet_recv: Arc>>, net: Arc>>, entries: Socks5EntrySet, + + tcp_forward_task: Arc>>, + udp_client_map: Arc>>, + udp_forward_task: Arc>>, } #[async_trait::async_trait] @@ -299,22 +358,65 @@ impl PeerPacketFilter for Socks5Server { let payload_bytes = packet.payload(); let ipv4 = Ipv4Packet::new(payload_bytes).unwrap(); - if ipv4.get_version() != 4 || ipv4.get_next_level_protocol() != IpNextHeaderProtocols::Tcp { + if ipv4.get_version() != 4 { return Some(packet); } - let tcp_packet = TcpPacket::new(ipv4.payload()).unwrap(); - let entry = Socks5Entry { - dst: SocketAddr::new(ipv4.get_source().into(), tcp_packet.get_source()), - src: 
SocketAddr::new(ipv4.get_destination().into(), tcp_packet.get_destination()), + let entry_key = match ipv4.get_next_level_protocol() { + IpNextHeaderProtocols::Tcp => { + let tcp_packet = TcpPacket::new(ipv4.payload()).unwrap(); + Socks5Entry { + dst: SocketAddr::new(ipv4.get_source().into(), tcp_packet.get_source()), + src: SocketAddr::new( + ipv4.get_destination().into(), + tcp_packet.get_destination(), + ), + entry_type: TCP_ENTRY, + } + } + + IpNextHeaderProtocols::Udp => { + if IpReassembler::is_packet_fragmented(&ipv4) && !self.entries.is_empty() { + let ipv4_src: IpAddr = ipv4.get_source().into(); + // only send to smoltcp if the ipv4 src is in the entries + let is_in_entries = self.entries.iter().any(|x| x.key().dst.ip() == ipv4_src); + tracing::trace!( + ?is_in_entries, + "ipv4 src = {:?}, check need send both smoltcp and kernel tun", + ipv4_src + ); + if is_in_entries { + // if the packet is fragmented, no matther what the payload is, need send it to both smoltcp and kernel tun. because + // we cannot determine the udp port of the packet. 
+ let _ = self.packet_sender.try_send(packet.clone()).ok(); + } + return Some(packet); + } + + let udp_packet = UdpPacket::new(ipv4.payload()).unwrap(); + Socks5Entry { + dst: SocketAddr::new(ipv4.get_source().into(), udp_packet.get_source()), + src: SocketAddr::new( + ipv4.get_destination().into(), + udp_packet.get_destination(), + ), + entry_type: UDP_ENTRY, + } + } + _ => { + return Some(packet); + } }; - if !self.entries.contains(&entry) { + if !self.entries.contains_key(&entry_key) { return Some(packet); } + tracing::trace!(?entry_key, ?ipv4, "socks5 found entry for packet from peer"); + let _ = self.packet_sender.try_send(packet).ok(); - return None; + + None } } @@ -330,12 +432,16 @@ impl Socks5Server { peer_manager, auth, - tasks: Arc::new(Mutex::new(JoinSet::new())), + tasks: Arc::new(std::sync::Mutex::new(JoinSet::new())), packet_recv: Arc::new(Mutex::new(packet_recv)), packet_sender, net: Arc::new(Mutex::new(None)), - entries: Arc::new(DashSet::new()), + entries: Arc::new(DashMap::new()), + + tcp_forward_task: Arc::new(std::sync::Mutex::new(JoinSet::new())), + udp_client_map: Arc::new(DashMap::new()), + udp_forward_task: Arc::new(DashMap::new()), }) } @@ -345,7 +451,9 @@ impl Socks5Server { let peer_manager = self.peer_manager.clone(); let packet_recv = self.packet_recv.clone(); let entries = self.entries.clone(); - self.tasks.lock().await.spawn(async move { + let tcp_forward_task = self.tcp_forward_task.clone(); + let udp_client_map = self.udp_client_map.clone(); + self.tasks.lock().unwrap().spawn(async move { let mut prev_ipv4 = None; loop { let mut event_recv = global_ctx.subscribe(); @@ -353,7 +461,10 @@ impl Socks5Server { let cur_ipv4 = global_ctx.get_ipv4(); if prev_ipv4 != cur_ipv4 { prev_ipv4 = cur_ipv4; + entries.clear(); + tcp_forward_task.lock().unwrap().abort_all(); + udp_client_map.clear(); if cur_ipv4.is_none() { let _ = net.lock().await.take(); @@ -377,42 +488,339 @@ impl Socks5Server { } pub async fn run(self: &Arc) -> Result<(), Error> 
{ - let Some(proxy_url) = self.global_ctx.config.get_socks5_portal() else { - return Ok(()); + let mut need_start = false; + if let Some(proxy_url) = self.global_ctx.config.get_socks5_portal() { + let bind_addr = format!( + "{}:{}", + proxy_url.host_str().unwrap(), + proxy_url.port().unwrap() + ); + + let listener = { + let _g = self.global_ctx.net_ns.guard(); + TcpListener::bind(bind_addr.parse::().unwrap()).await? + }; + + let net = self.net.clone(); + self.tasks.lock().unwrap().spawn(async move { + loop { + match listener.accept().await { + Ok((socket, _addr)) => { + tracing::info!("accept a new connection, {:?}", socket); + if let Some(net) = net.lock().await.as_ref() { + net.handle_tcp_stream(socket); + } + } + Err(err) => tracing::error!("accept error = {:?}", err), + } + } + }); + + join_joinset_background(self.tasks.clone(), "socks5 server".to_string()); + + need_start = true; }; - let bind_addr = format!( - "{}:{}", - proxy_url.host_str().unwrap(), - proxy_url.port().unwrap() - ); + for port_forward in self.global_ctx.config.get_port_forwards() { + self.add_port_forward(port_forward).await?; + need_start = true; + } + if need_start { + self.peer_manager + .add_packet_process_pipeline(Box::new(self.clone())) + .await; + + self.run_net_update_task().await; + } + + Ok(()) + } + + async fn handle_port_forward_connection( + mut incoming_socket: tokio::net::TcpStream, + connector: SmolTcpConnector, + dst_addr: SocketAddr, + ) { + let outgoing_socket = match connector.tcp_connect(dst_addr, 10).await { + Ok(socket) => socket, + Err(e) => { + tracing::error!("port forward: failed to connect to destination: {:?}", e); + return; + } + }; + + let mut outgoing_socket = outgoing_socket; + match tokio::io::copy_bidirectional(&mut incoming_socket, &mut outgoing_socket).await { + Ok((from_client, from_server)) => { + tracing::info!( + "port forward connection finished: client->server: {} bytes, server->client: {} bytes", + from_client, from_server + ); + } + Err(e) => { + 
tracing::error!("port forward connection error: {:?}", e); + } + } + } + + pub async fn add_port_forward(&self, cfg: PortForwardConfig) -> Result<(), Error> { + match cfg.proto.to_lowercase().as_str() { + "tcp" => { + self.add_tcp_port_forward(cfg.bind_addr, cfg.dst_addr) + .await?; + } + "udp" => { + self.add_udp_port_forward(cfg.bind_addr, cfg.dst_addr) + .await?; + } + _ => { + return Err(anyhow::anyhow!( + "unsupported protocol: {}, only support udp / tcp", + cfg.proto + ) + .into()); + } + } + self.global_ctx + .issue_event(GlobalCtxEvent::PortForwardAdded(cfg.clone().into())); + Ok(()) + } + + pub async fn add_tcp_port_forward( + &self, + bind_addr: SocketAddr, + dst_addr: SocketAddr, + ) -> Result<(), Error> { let listener = { let _g = self.global_ctx.net_ns.guard(); - TcpListener::bind(bind_addr.parse::().unwrap()).await? + TcpListener::bind(bind_addr).await? }; - self.peer_manager - .add_packet_process_pipeline(Box::new(self.clone())) - .await; - - self.run_net_update_task().await; - let net = self.net.clone(); - self.tasks.lock().await.spawn(async move { + let entries = self.entries.clone(); + let tasks = Arc::new(std::sync::Mutex::new(JoinSet::new())); + let forward_tasks = tasks.clone(); + + self.tasks.lock().unwrap().spawn(async move { loop { - match listener.accept().await { - Ok((socket, _addr)) => { - tracing::info!("accept a new connection, {:?}", socket); - if let Some(net) = net.lock().await.as_ref() { - net.handle_tcp_stream(socket); - } + let (incoming_socket, _addr) = match listener.accept().await { + Ok(result) => result, + Err(err) => { + tracing::error!("port forward accept error = {:?}", err); + continue; } - Err(err) => tracing::error!("accept error = {:?}", err), + }; + + tracing::info!( + "port forward: accept new connection from {:?} to {:?}", + bind_addr, + dst_addr + ); + + let net_guard = net.lock().await; + let Some(net) = net_guard.as_ref() else { + tracing::error!("net is not ready"); + continue; + }; + + let connector = 
SmolTcpConnector { + net: net.smoltcp_net.clone(), + entries: entries.clone(), + current_entry: std::sync::Mutex::new(None), + }; + + forward_tasks + .lock() + .unwrap() + .spawn(Self::handle_port_forward_connection( + incoming_socket, + connector, + dst_addr, + )); + } + }); + + Ok(()) + } + + #[tracing::instrument(name = "add_udp_port_forward", skip(self))] + pub async fn add_udp_port_forward( + &self, + bind_addr: SocketAddr, + dst_addr: SocketAddr, + ) -> Result<(), Error> { + let socket = { + let _g = self.global_ctx.net_ns.guard(); + Arc::new(UdpSocket::bind(bind_addr).await?) + }; + + let entries = self.entries.clone(); + let net_ns = self.global_ctx.net_ns.clone(); + let net = self.net.clone(); + let udp_client_map = self.udp_client_map.clone(); + let udp_forward_task = self.udp_forward_task.clone(); + + self.tasks.lock().unwrap().spawn(async move { + loop { + // we set the max buffer size of smoltcp to 8192, so we need to use a buffer size that is less than 8192 here. + let mut buf = vec![0u8; 8192]; + let (len, addr) = match socket.recv_from(&mut buf).await { + Ok(result) => result, + Err(err) => { + tracing::error!("udp port forward recv error = {:?}", err); + continue; + } + }; + + tracing::trace!( + "udp port forward recv packet from {:?}, len = {}", + addr, + len + ); + + let udp_client_key = UdpClientKey { + client_addr: addr, + dst_addr, + }; + + let binded_socket = udp_client_map.get(&udp_client_key); + let client_info = match binded_socket { + Some(s) => s.clone(), + None => { + let _g = net_ns.guard(); + // reserve a port so os will not use it to connect to the virtual network + let binded_socket = tokio::net::UdpSocket::bind("0.0.0.0:0").await; + if binded_socket.is_err() { + tracing::error!("udp port forward bind error = {:?}", binded_socket); + continue; + } + let binded_socket = binded_socket.unwrap(); + let mut local_addr = binded_socket.local_addr().unwrap(); + let Some(cur_ipv4) = net.lock().await.as_ref().map(|net| net.ipv4_addr) else { + 
continue; + }; + local_addr.set_ip(cur_ipv4.address().into()); + + let entry_key = Socks5Entry { + src: local_addr, + dst: dst_addr, + entry_type: UDP_ENTRY, + }; + + tracing::debug!("udp port forward binded socket = {:?}, entry_key = {:?}", local_addr, entry_key); + + let client_info = Arc::new(UdpClientInfo { + client_addr: addr, + port_holder_socket: Arc::new(binded_socket), + local_addr, + last_active: AtomicCell::new(Instant::now()), + entries: entries.clone(), + entry_key, + }); + udp_client_map.insert(udp_client_key.clone(), client_info.clone()); + client_info + } + }; + + client_info.last_active.store(Instant::now()); + + let entry_data = match entries.get(&client_info.entry_key) { + Some(data) => data, + None => { + let guard = net.lock().await; + let Some(net) = guard.as_ref() else { + continue; + }; + let local_addr = net.ipv4_addr; + let sokcs_udp = if dst_addr.ip() == local_addr.address() { + SocksUdpSocket::UdpSocket(client_info.port_holder_socket.clone()) + } else { + tracing::debug!("udp port forward bind new smol udp socket, {:?}", local_addr); + SocksUdpSocket::SmolUdpSocket( + net.smoltcp_net + .udp_bind(SocketAddr::new( + IpAddr::V4(local_addr.address()), + client_info.local_addr.port(), + )) + .await + .unwrap(), + ) + }; + let socks_udp = Arc::new(sokcs_udp); + entries.insert( + client_info.entry_key.clone(), + Socks5EntryData::Udp((socks_udp.clone(), udp_client_key.clone())), + ); + + let socks = socket.clone(); + let client_addr = addr; + udp_forward_task.insert( + udp_client_key.clone(), + ScopedTask::from(tokio::spawn(async move { + loop { + let mut buf = vec![0u8; 8192]; + match socks_udp.recv_from(&mut buf).await { + Ok((len, dst_addr)) => { + tracing::trace!( + "udp port forward recv response packet from {:?}, len = {}, client_addr = {:?}", + dst_addr, + len, + client_addr + ); + if let Err(e) = socks.send_to(&buf[..len], client_addr).await { + tracing::error!("udp forward send error = {:?}", e); + } + } + Err(e) => { + 
tracing::error!("udp forward recv error = {:?}", e); + } + } + } + })), + ); + + entries.get(&client_info.entry_key).unwrap() + } + }; + + let s = match entry_data.value() { + Socks5EntryData::Udp((s, _)) => s.clone(), + _ => { + panic!("udp entry data is not udp entry data"); + } + }; + drop(entry_data); + + if let Err(e) = s.send_to(&buf[..len], dst_addr).await { + tracing::error!(?dst_addr, ?len, "udp port forward send error = {:?}", e); + } else { + tracing::trace!(?dst_addr, ?len, "udp port forward send packet success"); } } }); + // clean up task + let udp_client_map = self.udp_client_map.clone(); + let udp_forward_task = self.udp_forward_task.clone(); + let entries = self.entries.clone(); + self.tasks.lock().unwrap().spawn(async move { + loop { + tokio::time::sleep(Duration::from_secs(30)).await; + let now = Instant::now(); + udp_client_map.retain(|_, client_info| { + now.duration_since(client_info.last_active.load()).as_secs() < 600 + }); + udp_forward_task.retain(|k, _| udp_client_map.contains_key(&k)); + entries.retain(|_, data| match data { + Socks5EntryData::Udp((_, udp_client_key)) => { + udp_client_map.contains_key(&udp_client_key) + } + _ => true, + }); + } + }); + Ok(()) } } diff --git a/easytier/src/gateway/tokio_smoltcp/mod.rs b/easytier/src/gateway/tokio_smoltcp/mod.rs index 58805d0b..9f2f0350 100644 --- a/easytier/src/gateway/tokio_smoltcp/mod.rs +++ b/easytier/src/gateway/tokio_smoltcp/mod.rs @@ -20,7 +20,7 @@ use smoltcp::{ time::{Duration, Instant}, wire::{HardwareAddress, IpAddress, IpCidr}, }; -pub use socket::{TcpListener, TcpStream}; +pub use socket::{TcpListener, TcpStream, UdpSocket}; pub use socket_allocator::BufferSize; use tokio::sync::Notify; @@ -158,6 +158,13 @@ impl Net { ) .await } + + /// This function will create a new UDP socket and attempt to bind it to the `addr` provided. 
+ pub async fn udp_bind(&self, addr: SocketAddr) -> io::Result<UdpSocket> { + let addr = self.set_address(addr); + UdpSocket::new(self.reactor.clone(), addr.into()).await + } + + fn set_address(&self, mut addr: SocketAddr) -> SocketAddr { if addr.ip().is_unspecified() { addr.set_ip(match self.ip_addr.address() { diff --git a/easytier/src/gateway/tokio_smoltcp/socket.rs b/easytier/src/gateway/tokio_smoltcp/socket.rs index f46f80f1..1b1e6fcf 100644 --- a/easytier/src/gateway/tokio_smoltcp/socket.rs +++ b/easytier/src/gateway/tokio_smoltcp/socket.rs @@ -2,6 +2,7 @@ use super::{reactor::Reactor, socket_allocator::SocketHandle}; use futures::future::{self, poll_fn}; use futures::{ready, Stream}; pub use smoltcp::socket::tcp; +use smoltcp::socket::udp; use smoltcp::wire::{IpAddress, IpEndpoint}; use std::mem::replace; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; @@ -247,3 +248,86 @@ impl AsyncWrite for TcpStream { Poll::Pending } } + +/// A UDP socket. +pub struct UdpSocket { + handle: SocketHandle, + reactor: Arc<Reactor>, + local_addr: SocketAddr, +} + +impl UdpSocket { + pub(super) async fn new( + reactor: Arc<Reactor>, + local_endpoint: IpEndpoint, + ) -> io::Result<UdpSocket> { + let handle = reactor.socket_allocator().new_udp_socket(); + { + let mut socket = reactor.get_socket::<udp::Socket>(*handle); + socket.bind(local_endpoint).map_err(map_err)?; + } + + let local_addr = ep2sa(&local_endpoint); + + Ok(UdpSocket { + handle, + reactor, + local_addr, + }) + } + /// Note that on multiple calls to a poll_* method in the send direction, only the Waker from the Context passed to the most recent call will be scheduled to receive a wakeup.
+ pub fn poll_send_to( + &self, + cx: &mut Context<'_>, + buf: &[u8], + target: SocketAddr, + ) -> Poll<io::Result<usize>> { + let mut socket = self.reactor.get_socket::<udp::Socket>(*self.handle); + let target_ip: IpEndpoint = target.into(); + + match socket.send_slice(buf, target_ip) { + // the buffer is full + Err(udp::SendError::BufferFull) => {} + r => { + r.map_err(map_err)?; + self.reactor.notify(); + return Poll::Ready(Ok(buf.len())); + } + } + + socket.register_send_waker(cx.waker()); + Poll::Pending + } + /// See note on `poll_send_to` + pub async fn send_to(&self, buf: &[u8], target: SocketAddr) -> io::Result<usize> { + poll_fn(|cx| self.poll_send_to(cx, buf, target)).await + } + /// Note that on multiple calls to a poll_* method in the recv direction, only the Waker from the Context passed to the most recent call will be scheduled to receive a wakeup. + pub fn poll_recv_from( + &self, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll<io::Result<(usize, SocketAddr)>> { + let mut socket = self.reactor.get_socket::<udp::Socket>(*self.handle); + + match socket.recv_slice(buf) { + // the buffer is empty + Err(udp::RecvError::Exhausted) => {} + r => { + let (size, metadata) = r.map_err(map_err)?; + self.reactor.notify(); + return Poll::Ready(Ok((size, ep2sa(&metadata.endpoint)))); + } + } + + socket.register_recv_waker(cx.waker()); + Poll::Pending + } + /// See note on `poll_recv_from` + pub async fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> { + poll_fn(|cx| self.poll_recv_from(cx, buf)).await + } + pub fn local_addr(&self) -> io::Result<SocketAddr> { + Ok(self.local_addr) + } +} diff --git a/easytier/src/gateway/tokio_smoltcp/socket_allocator.rs b/easytier/src/gateway/tokio_smoltcp/socket_allocator.rs index a95602f7..934d40f8 100644 --- a/easytier/src/gateway/tokio_smoltcp/socket_allocator.rs +++ b/easytier/src/gateway/tokio_smoltcp/socket_allocator.rs @@ -1,7 +1,7 @@ use parking_lot::Mutex; use smoltcp::{ iface::{SocketHandle as InnerSocketHandle, SocketSet}, - socket::tcp, + socket::{tcp, udp}, time::Duration, }; use 
std::{ @@ -14,6 +14,11 @@ use std::{ pub struct BufferSize { pub tcp_rx_size: usize, pub tcp_tx_size: usize, + + pub udp_rx_size: usize, + pub udp_tx_size: usize, + pub udp_rx_meta_size: usize, + pub udp_tx_meta_size: usize, } impl Default for BufferSize { @@ -21,6 +26,11 @@ impl Default for BufferSize { BufferSize { tcp_rx_size: 8192, tcp_tx_size: 8192, + + udp_rx_size: 8192, + udp_tx_size: 8192, + udp_rx_meta_size: 32, + udp_tx_meta_size: 32, } } } @@ -59,6 +69,26 @@ impl SocketAlloctor { tcp } + + pub fn new_udp_socket(&self) -> SocketHandle { + let mut set = self.sockets.lock(); + let handle = set.add(self.alloc_udp_socket()); + SocketHandle::new(handle, self.sockets.clone()) + } + + fn alloc_udp_socket(&self) -> udp::Socket<'static> { + let rx_buffer = udp::PacketBuffer::new( + vec![udp::PacketMetadata::EMPTY; self.buffer_size.udp_rx_meta_size], + vec![0; self.buffer_size.udp_rx_size], + ); + let tx_buffer = udp::PacketBuffer::new( + vec![udp::PacketMetadata::EMPTY; self.buffer_size.udp_tx_meta_size], + vec![0; self.buffer_size.udp_tx_size], + ); + let udp = udp::Socket::new(rx_buffer, tx_buffer); + + udp + } } pub struct SocketHandle(InnerSocketHandle, SharedSocketSet); diff --git a/easytier/src/proto/common.proto b/easytier/src/proto/common.proto index 2494bd0e..6d211832 100644 --- a/easytier/src/proto/common.proto +++ b/easytier/src/proto/common.proto @@ -155,3 +155,14 @@ message PeerFeatureFlag { bool kcp_input = 3; bool no_relay_kcp = 4; } + +enum SocketType { + TCP = 0; + UDP = 1; +} + +message PortForwardConfigPb { + SocketAddr bind_addr = 1; + SocketAddr dst_addr = 2; + SocketType socket_type = 3; +} diff --git a/easytier/src/proto/common.rs b/easytier/src/proto/common.rs index a40e4711..47aff39b 100644 --- a/easytier/src/proto/common.rs +++ b/easytier/src/proto/common.rs @@ -190,6 +190,12 @@ impl From for std::net::SocketAddr { } } +impl ToString for SocketAddr { + fn to_string(&self) -> String { + std::net::SocketAddr::from(self.clone()).to_string() 
+ } +} + impl TryFrom for CompressorAlgo { type Error = anyhow::Error; diff --git a/easytier/src/tests/three_node.rs b/easytier/src/tests/three_node.rs index 5fcbc1f4..b13f8376 100644 --- a/easytier/src/tests/three_node.rs +++ b/easytier/src/tests/three_node.rs @@ -4,13 +4,14 @@ use std::{ time::Duration, }; +use rand::Rng; use tokio::{net::UdpSocket, task::JoinSet}; use super::*; use crate::{ common::{ - config::{ConfigLoader, NetworkIdentity, TomlConfigLoader}, + config::{ConfigLoader, NetworkIdentity, PortForwardConfig, TomlConfigLoader}, netns::{NetNS, ROOT_NETNS_NAME}, }, instance::instance::Instance, @@ -890,3 +891,120 @@ pub async fn manual_reconnector(#[values(true, false)] is_foreign: bool) { ) .await; } + +#[rstest::rstest] +#[tokio::test] +#[serial_test::serial] +pub async fn port_forward_test( + #[values(true, false)] no_tun: bool, + #[values(64, 1900)] buf_size: u64, +) { + prepare_linux_namespaces(); + + let _insts = init_three_node_ex( + "udp", + |cfg| { + if cfg.get_inst_name() == "inst1" { + cfg.set_port_forwards(vec![ + // test port forward to other virtual node + PortForwardConfig { + bind_addr: "0.0.0.0:23456".parse().unwrap(), + dst_addr: "10.144.144.3:23456".parse().unwrap(), + proto: "tcp".to_string(), + }, + // test port forward to subnet proxy + PortForwardConfig { + bind_addr: "0.0.0.0:23457".parse().unwrap(), + dst_addr: "10.1.2.4:23457".parse().unwrap(), + proto: "tcp".to_string(), + }, + // test udp port forward to other virtual node + PortForwardConfig { + bind_addr: "0.0.0.0:23458".parse().unwrap(), + dst_addr: "10.144.144.3:23458".parse().unwrap(), + proto: "udp".to_string(), + }, + // test udp port forward to subnet proxy + PortForwardConfig { + bind_addr: "0.0.0.0:23459".parse().unwrap(), + dst_addr: "10.1.2.4:23459".parse().unwrap(), + proto: "udp".to_string(), + }, + ]); + } else if cfg.get_inst_name() == "inst3" { + cfg.add_proxy_cidr("10.1.2.0/24".parse().unwrap()); + } + let mut flags = cfg.get_flags(); + flags.no_tun = 
no_tun; + cfg.set_flags(flags); + cfg + }, + false, + ) + .await; + + use crate::tunnel::{ + common::tests::_tunnel_pingpong_netns, tcp::TcpTunnelListener, udp::UdpTunnelConnector, + udp::UdpTunnelListener, + }; + + let tcp_listener = TcpTunnelListener::new("tcp://0.0.0.0:23456".parse().unwrap()); + let tcp_connector = TcpTunnelConnector::new("tcp://127.0.0.1:23456".parse().unwrap()); + + let mut buf = vec![0; buf_size as usize]; + rand::thread_rng().fill(&mut buf[..]); + + _tunnel_pingpong_netns( + tcp_listener, + tcp_connector, + NetNS::new(Some("net_c".into())), + NetNS::new(Some("net_a".into())), + buf, + ) + .await; + + let tcp_listener = TcpTunnelListener::new("tcp://0.0.0.0:23457".parse().unwrap()); + let tcp_connector = TcpTunnelConnector::new("tcp://127.0.0.1:23457".parse().unwrap()); + + let mut buf = vec![0; buf_size as usize]; + rand::thread_rng().fill(&mut buf[..]); + + _tunnel_pingpong_netns( + tcp_listener, + tcp_connector, + NetNS::new(Some("net_d".into())), + NetNS::new(Some("net_a".into())), + buf, + ) + .await; + + let udp_listener = UdpTunnelListener::new("udp://0.0.0.0:23458".parse().unwrap()); + let udp_connector = UdpTunnelConnector::new("udp://127.0.0.1:23458".parse().unwrap()); + + let mut buf = vec![0; buf_size as usize]; + rand::thread_rng().fill(&mut buf[..]); + + _tunnel_pingpong_netns( + udp_listener, + udp_connector, + NetNS::new(Some("net_c".into())), + NetNS::new(Some("net_a".into())), + buf, + ) + .await; + + let udp_listener = UdpTunnelListener::new("udp://0.0.0.0:23459".parse().unwrap()); + let udp_connector = UdpTunnelConnector::new("udp://127.0.0.1:23459".parse().unwrap()); + + let mut buf = vec![0; buf_size as usize]; + rand::thread_rng().fill(&mut buf[..]); + + _tunnel_pingpong_netns( + udp_listener, + udp_connector, + NetNS::new(Some("net_d".into())), + NetNS::new(Some("net_a".into())), + buf, + ) + .await; +} diff --git a/easytier/src/tunnel/common.rs b/easytier/src/tunnel/common.rs index b834d15e..7d197a32 100644 --- 
a/easytier/src/tunnel/common.rs +++ b/easytier/src/tunnel/common.rs @@ -454,6 +454,7 @@ pub mod tests { let Ok(msg) = item else { continue; }; + tracing::debug!(?msg, "recv a msg, try echo back"); if let Err(_) = send.send(msg).await { break; }