From 2ec88da82329db26344679573783106bb210d60d Mon Sep 17 00:00:00 2001 From: "Sijie.Sun" Date: Tue, 29 Jul 2025 09:30:47 +0800 Subject: [PATCH] cli for port forward and tcp whitelist (#1165) --- easytier/src/common/acl_processor.rs | 143 +++++++++- easytier/src/common/config.rs | 37 ++- easytier/src/easytier-cli.rs | 391 ++++++++++++++++++++++++++- easytier/src/easytier-core.rs | 125 +-------- easytier/src/gateway/socks5.rs | 100 +++++-- easytier/src/instance/instance.rs | 103 ++++++- easytier/src/peers/rpc_service.rs | 60 +++- easytier/src/proto/cli.proto | 40 +++ 8 files changed, 828 insertions(+), 171 deletions(-) diff --git a/easytier/src/common/acl_processor.rs b/easytier/src/common/acl_processor.rs index de101d89..60863073 100644 --- a/easytier/src/common/acl_processor.rs +++ b/easytier/src/common/acl_processor.rs @@ -6,8 +6,9 @@ use std::{ time::{Duration, SystemTime, UNIX_EPOCH}, }; -use crate::common::token_bucket::TokenBucket; +use crate::common::{config::ConfigLoader, global_ctx::ArcGlobalCtx, token_bucket::TokenBucket}; use crate::proto::acl::*; +use anyhow::Context as _; use dashmap::DashMap; use tokio::task::JoinSet; @@ -993,6 +994,146 @@ impl AclStatKey { } } +pub struct AclRuleBuilder { + pub acl: Option, + pub tcp_whitelist: Vec, + pub udp_whitelist: Vec, + pub whitelist_priority: Option, +} + +impl AclRuleBuilder { + fn parse_port_list(port_list: &[String]) -> anyhow::Result> { + let mut ports = Vec::new(); + + for port_spec in port_list { + if port_spec.contains('-') { + // Handle port range like "8000-9000" + let parts: Vec<&str> = port_spec.split('-').collect(); + if parts.len() != 2 { + return Err(anyhow::anyhow!("Invalid port range format: {}", port_spec)); + } + + let start: u16 = parts[0] + .parse() + .with_context(|| format!("Invalid start port in range: {}", port_spec))?; + let end: u16 = parts[1] + .parse() + .with_context(|| format!("Invalid end port in range: {}", port_spec))?; + + if start > end { + return Err(anyhow::anyhow!( + "Start port must be <= end port in range: {}", + port_spec + )); + } + + // acl can handle port range + ports.push(port_spec.clone()); + } else { + // Handle single port + let port: u16 = port_spec + .parse() + .with_context(|| format!("Invalid port number: {}", port_spec))?; + ports.push(port.to_string()); + } + } + + Ok(ports) + } + + fn generate_acl_from_whitelists(&mut self) -> anyhow::Result<()> { + if self.tcp_whitelist.is_empty() && self.udp_whitelist.is_empty() { + return Ok(()); + } + + // Create inbound chain for whitelist rules + let mut inbound_chain = Chain { + name: "inbound_whitelist".to_string(), + chain_type: ChainType::Inbound as i32, + description: "Auto-generated inbound whitelist from CLI".to_string(), + enabled: true, + rules: vec![], + default_action: Action::Drop as i32, // Default deny + }; + + let mut rule_priority = self.whitelist_priority.unwrap_or(1000u32); + + // Add TCP whitelist rules + if !self.tcp_whitelist.is_empty() { + let tcp_ports = Self::parse_port_list(&self.tcp_whitelist)?; + let tcp_rule = Rule { + name: "tcp_whitelist".to_string(), + description: "Auto-generated TCP whitelist rule".to_string(), + priority: rule_priority, + enabled: true, + protocol: Protocol::Tcp as i32, + ports: tcp_ports, + source_ips: vec![], + destination_ips: vec![], + source_ports: vec![], + action: Action::Allow as i32, + rate_limit: 0, + burst_limit: 0, + stateful: true, + }; + inbound_chain.rules.push(tcp_rule); + rule_priority -= 1; + } + + // Add UDP whitelist rules + if !self.udp_whitelist.is_empty() { + let udp_ports = Self::parse_port_list(&self.udp_whitelist)?; + let udp_rule = Rule { + name: "udp_whitelist".to_string(), + description: "Auto-generated UDP whitelist rule".to_string(), + priority: rule_priority, + enabled: true, + protocol: Protocol::Udp as i32, + ports: udp_ports, + source_ips: vec![], + destination_ips: vec![], + source_ports: vec![], + action: Action::Allow as i32, + rate_limit: 0, + burst_limit: 0, + stateful: false, + }; + inbound_chain.rules.push(udp_rule); + } + + if self.acl.is_none() { + self.acl = Some(Acl::default()); + } + + let acl = self.acl.as_mut().unwrap(); + + if let Some(ref mut acl_v1) = acl.acl_v1 { + acl_v1.chains.push(inbound_chain); + } else { + acl.acl_v1 = Some(AclV1 { + chains: vec![inbound_chain], + }); + } + + Ok(()) + } + + fn do_build(mut self) -> anyhow::Result> { + self.generate_acl_from_whitelists()?; + Ok(self.acl.clone()) + } + + pub fn build(global_ctx: &ArcGlobalCtx) -> anyhow::Result> { + let builder = AclRuleBuilder { + acl: global_ctx.config.get_acl(), + tcp_whitelist: global_ctx.config.get_tcp_whitelist(), + udp_whitelist: global_ctx.config.get_udp_whitelist(), + whitelist_priority: None, + }; + builder.do_build() + } +} + #[derive(Debug, Clone, Copy)] pub enum AclStatType { Total, diff --git a/easytier/src/common/config.rs b/easytier/src/common/config.rs index a2d84957..e9eeedb8 100644 --- a/easytier/src/common/config.rs +++ b/easytier/src/common/config.rs @@ -122,6 +122,12 @@ pub trait ConfigLoader: Send + Sync { fn get_acl(&self) -> Option; fn set_acl(&self, acl: Option); + fn get_tcp_whitelist(&self) -> Vec; + fn set_tcp_whitelist(&self, whitelist: Vec); + + fn get_udp_whitelist(&self) -> Vec; + fn set_udp_whitelist(&self, whitelist: Vec); + fn dump(&self) -> String; } @@ -230,7 +236,7 @@ pub struct VpnPortalConfig { pub wireguard_listen: SocketAddr, } -#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)] +#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq, Hash)] pub struct PortForwardConfig { pub bind_addr: SocketAddr, pub dst_addr: SocketAddr, @@ -299,6 +305,9 @@ struct Config { flags_struct: Option, acl: Option, + + tcp_whitelist: Option>, + udp_whitelist: Option>, } #[derive(Debug, Clone)] @@ -665,6 +674,32 @@ impl ConfigLoader for TomlConfigLoader { self.config.lock().unwrap().acl = acl; } + fn get_tcp_whitelist(&self) -> Vec { + self.config + .lock() + .unwrap() + .tcp_whitelist + .clone() + .unwrap_or_default() + } + + fn set_tcp_whitelist(&self, whitelist: Vec) { + self.config.lock().unwrap().tcp_whitelist = Some(whitelist); + } + + fn get_udp_whitelist(&self) -> Vec { + self.config + .lock() + .unwrap() + .udp_whitelist + .clone() + .unwrap_or_default() + } + + fn set_udp_whitelist(&self, whitelist: Vec) { + self.config.lock().unwrap().udp_whitelist = Some(whitelist); + } + fn dump(&self) -> String { let default_flags_json = serde_json::to_string(&gen_default_flags()).unwrap(); let default_flags_hashmap = diff --git a/easytier/src/easytier-cli.rs b/easytier/src/easytier-cli.rs index 9d2c0240..608d6169 100644 --- a/easytier/src/easytier-cli.rs +++ b/easytier/src/easytier-cli.rs @@ -22,23 +22,25 @@ use tokio::time::timeout; use easytier::{ common::{ + config::PortForwardConfig, constants::EASYTIER_VERSION, stun::{StunInfoCollector, StunInfoCollectorTrait}, }, proto::{ cli::{ - list_peer_route_pair, AclManageRpc, AclManageRpcClientFactory, ConnectorManageRpc, - ConnectorManageRpcClientFactory, DumpRouteRequest, GetAclStatsRequest, - GetVpnPortalInfoRequest, ListConnectorRequest, ListForeignNetworkRequest, - ListGlobalForeignNetworkRequest, ListMappedListenerRequest, ListPeerRequest, - ListPeerResponse, ListRouteRequest, ListRouteResponse, ManageMappedListenerRequest, - MappedListenerManageAction, MappedListenerManageRpc, - MappedListenerManageRpcClientFactory, NodeInfo, PeerManageRpc, - PeerManageRpcClientFactory, ShowNodeInfoRequest, TcpProxyEntryState, + list_peer_route_pair, AclManageRpc, AclManageRpcClientFactory, AddPortForwardRequest, + ConnectorManageRpc, ConnectorManageRpcClientFactory, DumpRouteRequest, + GetAclStatsRequest, GetVpnPortalInfoRequest, GetWhitelistRequest, ListConnectorRequest, + ListForeignNetworkRequest, ListGlobalForeignNetworkRequest, ListMappedListenerRequest, + ListPeerRequest, ListPeerResponse, ListPortForwardRequest, ListRouteRequest, + ListRouteResponse, ManageMappedListenerRequest, MappedListenerManageAction, + MappedListenerManageRpc, MappedListenerManageRpcClientFactory, NodeInfo, PeerManageRpc, + PeerManageRpcClientFactory, PortForwardManageRpc, PortForwardManageRpcClientFactory, + RemovePortForwardRequest, SetWhitelistRequest, ShowNodeInfoRequest, TcpProxyEntryState, TcpProxyEntryTransportType, TcpProxyRpc, TcpProxyRpcClientFactory, VpnPortalRpc, VpnPortalRpcClientFactory, }, - common::NatType, + common::{NatType, SocketType}, peer_rpc::{GetGlobalPeerMapRequest, PeerCenterRpc, PeerCenterRpcClientFactory}, rpc_impl::standalone::StandAloneClient, rpc_types::controller::BaseController, @@ -96,6 +98,10 @@ enum SubCommand { Proxy, #[command(about = "show ACL rules statistics")] Acl(AclArgs), + #[command(about = "manage port forwarding")] + PortForward(PortForwardArgs), + #[command(about = "manage TCP/UDP whitelist")] + Whitelist(WhitelistArgs), #[command(about = t!("core_clap.generate_completions").to_string())] GenAutocomplete { shell: Shell }, } @@ -193,6 +199,62 @@ enum AclSubCommand { Stats, } +#[derive(Args, Debug)] +struct PortForwardArgs { + #[command(subcommand)] + sub_command: Option, +} + +#[derive(Subcommand, Debug)] +enum PortForwardSubCommand { + /// Add port forward rule + Add { + #[arg(help = "Protocol (tcp/udp)")] + protocol: String, + #[arg(help = "Local bind address (e.g., 0.0.0.0:8080)")] + bind_addr: String, + #[arg(help = "Destination address (e.g., 10.1.1.1:80)")] + dst_addr: String, + }, + /// Remove port forward rule + Remove { + #[arg(help = "Protocol (tcp/udp)")] + protocol: String, + #[arg(help = "Local bind address (e.g., 0.0.0.0:8080)")] + bind_addr: String, + #[arg(help = "Optional Destination address (e.g., 10.1.1.1:80)")] + dst_addr: Option, + }, + /// List port forward rules + List, +} + +#[derive(Args, Debug)] +struct WhitelistArgs { + #[command(subcommand)] + sub_command: Option, +} + +#[derive(Subcommand, Debug)] +enum WhitelistSubCommand { + /// Set TCP port whitelist + SetTcp { + #[arg(help = "TCP ports (e.g., 80,443,8000-9000)")] + ports: String, + }, + /// Set UDP port whitelist + SetUdp { + #[arg(help = "UDP ports (e.g., 53,5000-6000)")] + ports: String, + }, + /// Clear TCP whitelist + ClearTcp, + /// Clear UDP whitelist + ClearUdp, + /// Show current whitelist configuration + Show, +} + #[derive(Args, Debug)] struct ServiceArgs { #[arg(short, long, default_value = env!("CARGO_PKG_NAME"), help = "service name")] @@ -340,6 +402,18 @@ impl CommandHandler<'_> { .with_context(|| "failed to get vpn portal client")?) } + async fn get_port_forward_manager_client( + &self, + ) -> Result>, Error> { + Ok(self + .client + .lock() + .unwrap() + .scoped_client::>("".to_string()) + .await + .with_context(|| "failed to get port forward manager client")?) + } + async fn list_peers(&self) -> Result { let client = self.get_peer_manager_client().await?; let request = ListPeerRequest::default(); @@ -788,6 +862,265 @@ impl CommandHandler<'_> { } Ok(url) } + + async fn handle_port_forward_add( + &self, + protocol: &str, + bind_addr: &str, + dst_addr: &str, + ) -> Result<(), Error> { + let bind_addr: std::net::SocketAddr = bind_addr + .parse() + .with_context(|| format!("Invalid bind address: {}", bind_addr))?; + let dst_addr: std::net::SocketAddr = dst_addr + .parse() + .with_context(|| format!("Invalid destination address: {}", dst_addr))?; + + if protocol != "tcp" && protocol != "udp" { + return Err(anyhow::anyhow!("Protocol must be 'tcp' or 'udp'")); + } + + let client = self.get_port_forward_manager_client().await?; + let request = AddPortForwardRequest { + cfg: Some( + PortForwardConfig { + proto: protocol.to_string(), + bind_addr: bind_addr.into(), + dst_addr: dst_addr.into(), + } + .into(), + ), + }; + + client + .add_port_forward(BaseController::default(), request) + .await?; + println!( + "Port forward rule added: {} {} -> {}", + protocol, bind_addr, dst_addr + ); + Ok(()) + } + + async fn handle_port_forward_remove( + &self, + protocol: &str, + bind_addr: &str, + dst_addr: Option<&str>, + ) -> Result<(), Error> { + let bind_addr: std::net::SocketAddr = bind_addr + .parse() + .with_context(|| format!("Invalid bind address: {}", bind_addr))?; + + if protocol != "tcp" && protocol != "udp" { + return Err(anyhow::anyhow!("Protocol must be 'tcp' or 'udp'")); + } + + let client = self.get_port_forward_manager_client().await?; + let request = RemovePortForwardRequest { + cfg: Some( + PortForwardConfig { + proto: protocol.to_string(), + bind_addr: bind_addr.into(), + dst_addr: dst_addr + .map(|s| s.parse::().unwrap()) + .map(Into::into) + .unwrap_or("0.0.0.0:0".parse::().unwrap().into()), + } + .into(), + ), + }; + + client + .remove_port_forward(BaseController::default(), request) + .await?; + println!("Port forward rule removed: {} {}", protocol, bind_addr); + Ok(()) + } + + async fn handle_port_forward_list(&self) -> Result<(), Error> { + let client = self.get_port_forward_manager_client().await?; + let request = ListPortForwardRequest::default(); + let response = client + .list_port_forward(BaseController::default(), request) + .await?; + + if self.verbose || *self.output_format == OutputFormat::Json { + println!("{}", serde_json::to_string_pretty(&response)?); + return Ok(()); + } + + #[derive(tabled::Tabled, serde::Serialize)] + struct PortForwardTableItem { + protocol: String, + bind_addr: String, + dst_addr: String, + } + + let items: Vec = response + .cfgs + .into_iter() + .map(|rule| PortForwardTableItem { + protocol: format!( + "{:?}", + SocketType::try_from(rule.socket_type).unwrap_or(SocketType::Tcp) + ), + bind_addr: rule + .bind_addr + .map(|addr| addr.to_string()) + .unwrap_or_default(), + dst_addr: rule + .dst_addr + .map(|addr| addr.to_string()) + .unwrap_or_default(), + }) + .collect(); + + print_output(&items, self.output_format)?; + Ok(()) + } + + async fn handle_whitelist_set_tcp(&self, ports: &str) -> Result<(), Error> { + let tcp_ports = Self::parse_port_list(ports)?; + let client = self.get_acl_manager_client().await?; + + // Get current UDP ports to preserve them + let current = client + .get_whitelist(BaseController::default(), GetWhitelistRequest::default()) + .await?; + let request = SetWhitelistRequest { + tcp_ports, + udp_ports: current.udp_ports, + }; + + client + .set_whitelist(BaseController::default(), request) + .await?; + println!("TCP whitelist updated: {}", ports); + Ok(()) + } + + async fn handle_whitelist_set_udp(&self, ports: &str) -> Result<(), Error> { + let udp_ports = Self::parse_port_list(ports)?; + let client = self.get_acl_manager_client().await?; + + // Get current TCP ports to preserve them + let current = client + .get_whitelist(BaseController::default(), GetWhitelistRequest::default()) + .await?; + let request = SetWhitelistRequest { + tcp_ports: current.tcp_ports, + udp_ports, + }; + + client + .set_whitelist(BaseController::default(), request) + .await?; + println!("UDP whitelist updated: {}", ports); + Ok(()) + } + + async fn handle_whitelist_clear_tcp(&self) -> Result<(), Error> { + let client = self.get_acl_manager_client().await?; + + // Get current UDP ports to preserve them + let current = client + .get_whitelist(BaseController::default(), GetWhitelistRequest::default()) + .await?; + let request = SetWhitelistRequest { + tcp_ports: vec![], + udp_ports: current.udp_ports, + }; + + client + .set_whitelist(BaseController::default(), request) + .await?; + println!("TCP whitelist cleared"); + Ok(()) + } + + async fn handle_whitelist_clear_udp(&self) -> Result<(), Error> { + let client = self.get_acl_manager_client().await?; + + // Get current TCP ports to preserve them + let current = client + .get_whitelist(BaseController::default(), GetWhitelistRequest::default()) + .await?; + let request = SetWhitelistRequest { + tcp_ports: current.tcp_ports, + udp_ports: vec![], + }; + + client + .set_whitelist(BaseController::default(), request) + .await?; + println!("UDP whitelist cleared"); + Ok(()) + } + + async fn handle_whitelist_show(&self) -> Result<(), Error> { + let client = self.get_acl_manager_client().await?; + let request = GetWhitelistRequest::default(); + let response = client + .get_whitelist(BaseController::default(), request) + .await?; + + if self.verbose || *self.output_format == OutputFormat::Json { + println!("{}", serde_json::to_string_pretty(&response)?); + return Ok(()); + } + + println!( + "TCP Whitelist: {}", + if response.tcp_ports.is_empty() { + "None".to_string() + } else { + response.tcp_ports.join(", ") + } + ); + + println!( + "UDP Whitelist: {}", + if response.udp_ports.is_empty() { + "None".to_string() + } else { + response.udp_ports.join(", ") + } + ); + + Ok(()) + } + + fn parse_port_list(ports_str: &str) -> Result, Error> { + let mut ports = Vec::new(); + for port_spec in ports_str.split(',') { + let port_spec = port_spec.trim(); + if port_spec.contains('-') { + // Handle port range + let parts: Vec<&str> = port_spec.split('-').collect(); + if parts.len() != 2 { + return Err(anyhow::anyhow!("Invalid port range: {}", port_spec)); + } + let start: u16 = parts[0] + .parse() + .with_context(|| format!("Invalid start port: {}", parts[0]))?; + let end: u16 = parts[1] + .parse() + .with_context(|| format!("Invalid end port: {}", parts[1]))?; + if start > end { + return Err(anyhow::anyhow!("Invalid port range: start > end")); + } + ports.push(format!("{}-{}", start, end)); + } else { + // Handle single port + let port: u16 = port_spec + .parse() + .with_context(|| format!("Invalid port number: {}", port_spec))?; + ports.push(port.to_string()); + } + } + Ok(ports) + } } #[derive(Debug)] @@ -1494,6 +1827,46 @@ async fn main() -> Result<(), Error> { handler.handle_acl_stats().await?; } }, + SubCommand::PortForward(port_forward_args) => match &port_forward_args.sub_command { + Some(PortForwardSubCommand::Add { + protocol, + bind_addr, + dst_addr, + }) => { + handler + .handle_port_forward_add(protocol, bind_addr, dst_addr) + .await?; + } + Some(PortForwardSubCommand::Remove { + protocol, + bind_addr, + dst_addr, + }) => { + handler + .handle_port_forward_remove(protocol, bind_addr, dst_addr.as_deref()) + .await?; + } + Some(PortForwardSubCommand::List) | None => { + handler.handle_port_forward_list().await?; + } + }, + SubCommand::Whitelist(whitelist_args) => match &whitelist_args.sub_command { + Some(WhitelistSubCommand::SetTcp { ports }) => { + handler.handle_whitelist_set_tcp(ports).await?; + } + Some(WhitelistSubCommand::SetUdp { ports }) => { + handler.handle_whitelist_set_udp(ports).await?; + } + Some(WhitelistSubCommand::ClearTcp) => { + handler.handle_whitelist_clear_tcp().await?; + } + Some(WhitelistSubCommand::ClearUdp) => { + handler.handle_whitelist_clear_udp().await?; + } + Some(WhitelistSubCommand::Show) | None => { + handler.handle_whitelist_show().await?; + } + }, SubCommand::GenAutocomplete { shell } => { let mut cmd = Cli::command(); easytier::print_completions(shell, &mut cmd, "easytier-cli"); diff --git a/easytier/src/easytier-core.rs b/easytier/src/easytier-core.rs index 509a1639..27dea312 100644 --- a/easytier/src/easytier-core.rs +++ b/easytier/src/easytier-core.rs @@ -29,10 +29,7 @@ use easytier::{ connector::create_connector_by_url, instance_manager::NetworkInstanceManager, launcher::{add_proxy_network_to_config, ConfigSource}, - proto::{ - acl::{Acl, AclV1, Action, Chain, ChainType, Protocol, Rule}, - common::{CompressionAlgoPb, NatType}, - }, + proto::common::{CompressionAlgoPb, NatType}, tunnel::{IpVersion, PROTO_PORT_OFFSET}, utils::{init_logger, setup_panic_handler}, web_client, @@ -622,115 +619,6 @@ impl NetworkOptions { false } - fn parse_port_list(port_list: &[String]) -> anyhow::Result> { - let mut ports = Vec::new(); - - for port_spec in port_list { - if port_spec.contains('-') { - // Handle port range like "8000-9000" - let parts: Vec<&str> = port_spec.split('-').collect(); - if parts.len() != 2 { - return Err(anyhow::anyhow!("Invalid port range format: {}", port_spec)); - } - - let start: u16 = parts[0] - .parse() - .with_context(|| format!("Invalid start port in range: {}", port_spec))?; - let end: u16 = parts[1] - .parse() - .with_context(|| format!("Invalid end port in range: {}", port_spec))?; - - if start > end { - return Err(anyhow::anyhow!( - "Start port must be <= end port in range: {}", - port_spec - )); - } - - // acl can handle port range - ports.push(port_spec.clone()); - } else { - // Handle single port - let port: u16 = port_spec - .parse() - .with_context(|| format!("Invalid port number: {}", port_spec))?; - ports.push(port.to_string()); - } - } - - Ok(ports) - } - - fn generate_acl_from_whitelists(&self) -> anyhow::Result> { - if self.tcp_whitelist.is_empty() && self.udp_whitelist.is_empty() { - return Ok(None); - } - - let mut acl = Acl { - acl_v1: Some(AclV1 { chains: vec![] }), - }; - - let acl_v1 = acl.acl_v1.as_mut().unwrap(); - - // Create inbound chain for whitelist rules - let mut inbound_chain = Chain { - name: "inbound_whitelist".to_string(), - chain_type: ChainType::Inbound as i32, - description: "Auto-generated inbound whitelist from CLI".to_string(), - enabled: true, - rules: vec![], - default_action: Action::Drop as i32, // Default deny - }; - - let mut rule_priority = 1000u32; - - // Add TCP whitelist rules - if !self.tcp_whitelist.is_empty() { - let tcp_ports = Self::parse_port_list(&self.tcp_whitelist)?; - let tcp_rule = Rule { - name: "tcp_whitelist".to_string(), - description: "Auto-generated TCP whitelist rule".to_string(), - priority: rule_priority, - enabled: true, - protocol: Protocol::Tcp as i32, - ports: tcp_ports, - source_ips: vec![], - destination_ips: vec![], - source_ports: vec![], - action: Action::Allow as i32, - rate_limit: 0, - burst_limit: 0, - stateful: true, - }; - inbound_chain.rules.push(tcp_rule); - rule_priority -= 1; - } - - // Add UDP whitelist rules - if !self.udp_whitelist.is_empty() { - let udp_ports = Self::parse_port_list(&self.udp_whitelist)?; - let udp_rule = Rule { - name: "udp_whitelist".to_string(), - description: "Auto-generated UDP whitelist rule".to_string(), - priority: rule_priority, - enabled: true, - protocol: Protocol::Udp as i32, - ports: udp_ports, - source_ips: vec![], - destination_ips: vec![], - source_ports: vec![], - action: Action::Allow as i32, - rate_limit: 0, - burst_limit: 0, - stateful: false, - }; - inbound_chain.rules.push(udp_rule); - } - - acl_v1.chains.push(inbound_chain); - Ok(Some(acl)) - } - fn merge_into(&self, cfg: &mut TomlConfigLoader) -> anyhow::Result<()> { if self.hostname.is_some() { cfg.set_hostname(self.hostname.clone()); @@ -988,10 +876,13 @@ impl NetworkOptions { cfg.set_exit_nodes(self.exit_nodes.clone()); } - // Handle port whitelists by generating ACL configuration - if let Some(acl) = self.generate_acl_from_whitelists()? { - cfg.set_acl(Some(acl)); - } + let mut old_tcp_whitelist = cfg.get_tcp_whitelist(); + old_tcp_whitelist.extend(self.tcp_whitelist.clone()); + cfg.set_tcp_whitelist(old_tcp_whitelist); + + let mut old_udp_whitelist = cfg.get_udp_whitelist(); + old_udp_whitelist.extend(self.udp_whitelist.clone()); + cfg.set_udp_whitelist(old_udp_whitelist); Ok(()) } diff --git a/easytier/src/gateway/socks5.rs b/easytier/src/gateway/socks5.rs index 7d33350c..4ca0f659 100644 --- a/easytier/src/gateway/socks5.rs +++ b/easytier/src/gateway/socks5.rs @@ -6,6 +6,7 @@ use std::{ use crossbeam::atomic::AtomicCell; use kcp_sys::{endpoint::KcpEndpoint, stream::KcpStream}; +use tokio_util::sync::{CancellationToken, DropGuard}; use crate::{ common::{ @@ -432,6 +433,8 @@ pub struct Socks5Server { udp_forward_task: Arc>>, kcp_endpoint: Mutex>>, + + cancel_tokens: DashMap, } #[async_trait::async_trait] @@ -531,6 +534,8 @@ impl Socks5Server { udp_forward_task: Arc::new(DashMap::new()), kcp_endpoint: Mutex::new(None), + + cancel_tokens: DashMap::new(), }) } @@ -614,10 +619,9 @@ impl Socks5Server { need_start = true; }; - for port_forward in self.global_ctx.config.get_port_forwards() { - self.add_port_forward(port_forward).await?; - need_start = true; - } + let cfgs = self.global_ctx.config.get_port_forwards(); + self.reload_port_forwards(&cfgs).await?; + need_start = need_start || cfgs.len() > 0; if need_start { self.peer_manager @@ -630,6 +634,26 @@ impl Socks5Server { Ok(()) } + pub async fn reload_port_forwards(&self, cfgs: &Vec) -> Result<(), Error> { + // remove entries not in new cfg + self.cancel_tokens.retain(|k, _| { + cfgs.iter().any(|cfg| { + if cfg.dst_addr.ip().is_unspecified() { + k.bind_addr == cfg.bind_addr && k.proto == cfg.proto + } else { + k == cfg + } + }) + }); + // add new ones + for cfg in cfgs { + if !self.cancel_tokens.contains_key(cfg) { + self.add_port_forward(cfg.clone()).await?; + } + } + Ok(()) + } + async fn handle_port_forward_connection( mut incoming_socket: tokio::net::TcpStream, connector: Box + Send>, @@ -660,12 +684,10 @@ impl Socks5Server { pub async fn add_port_forward(&self, cfg: PortForwardConfig) -> Result<(), Error> { match cfg.proto.to_lowercase().as_str() { "tcp" => { - self.add_tcp_port_forward(cfg.bind_addr, cfg.dst_addr) - .await?; + self.add_tcp_port_forward(&cfg).await?; } "udp" => { - self.add_udp_port_forward(cfg.bind_addr, cfg.dst_addr) - .await?; + self.add_udp_port_forward(&cfg).await?; } _ => { return Err(anyhow::anyhow!( @@ -680,11 +702,12 @@ impl Socks5Server { Ok(()) } - pub async fn add_tcp_port_forward( - &self, - bind_addr: SocketAddr, - dst_addr: SocketAddr, - ) -> Result<(), Error> { + pub fn remove_port_forward(&self, cfg: PortForwardConfig) { + let _ = self.cancel_tokens.remove(&cfg); + } + + pub async fn add_tcp_port_forward(&self, cfg: &PortForwardConfig) -> Result<(), Error> { + let (bind_addr, dst_addr) = (cfg.bind_addr, cfg.dst_addr); let listener = bind_tcp_socket(bind_addr, self.global_ctx.net_ns.clone())?; let net = self.net.clone(); @@ -693,14 +716,26 @@ impl Socks5Server { let forward_tasks = tasks.clone(); let kcp_endpoint = self.kcp_endpoint.lock().await.clone(); let peer_mgr = Arc::downgrade(&self.peer_manager.clone()); + let cancel_token = CancellationToken::new(); + self.cancel_tokens + .insert(cfg.clone(), cancel_token.clone().drop_guard()); self.tasks.lock().unwrap().spawn(async move { loop { - let (incoming_socket, addr) = match listener.accept().await { - Ok(result) => result, - Err(err) => { - tracing::error!("port forward accept error = {:?}", err); - continue; + let (incoming_socket, addr) = select! { + biased; + _ = cancel_token.cancelled() => { + tracing::info!("port forward for {:?} cancelled", bind_addr); + break; + } + res = listener.accept() => { + match res { + Ok(result) => result, + Err(err) => { + tracing::error!("port forward accept error = {:?}", err); + continue; + } + } } }; @@ -747,11 +782,8 @@ impl Socks5Server { } #[tracing::instrument(name = "add_udp_port_forward", skip(self))] - pub async fn add_udp_port_forward( - &self, - bind_addr: SocketAddr, - dst_addr: SocketAddr, - ) -> Result<(), Error> { + pub async fn add_udp_port_forward(&self, cfg: &PortForwardConfig) -> Result<(), Error> { + let (bind_addr, dst_addr) = (cfg.bind_addr, cfg.dst_addr); let socket = Arc::new(bind_udp_socket(bind_addr, self.global_ctx.net_ns.clone())?); let entries = self.entries.clone(); @@ -759,16 +791,28 @@ impl Socks5Server { let net = self.net.clone(); let udp_client_map = self.udp_client_map.clone(); let udp_forward_task = self.udp_forward_task.clone(); + let cancel_token = CancellationToken::new(); + self.cancel_tokens + .insert(cfg.clone(), cancel_token.clone().drop_guard()); self.tasks.lock().unwrap().spawn(async move { loop { // we set the max buffer size of smoltcp to 8192, so we need to use a buffer size that is less than 8192 here. let mut buf = vec![0u8; 8192]; - let (len, addr) = match socket.recv_from(&mut buf).await { - Ok(result) => result, - Err(err) => { - tracing::error!("udp port forward recv error = {:?}", err); - continue; + let (len, addr) = select! { + biased; + _ = cancel_token.cancelled() => { + tracing::info!("udp port forward for {:?} cancelled", bind_addr); + break; + } + res = socket.recv_from(&mut buf) => { + match res { + Ok(result) => result, + Err(err) => { + tracing::error!("udp port forward recv error = {:?}", err); + continue; + } + } } }; diff --git a/easytier/src/instance/instance.rs b/easytier/src/instance/instance.rs index 70dccb1f..43747b9b 100644 --- a/easytier/src/instance/instance.rs +++ b/easytier/src/instance/instance.rs @@ -10,6 +10,7 @@ use cidr::{IpCidr, Ipv4Inet}; use tokio::{sync::Mutex, task::JoinSet}; use tokio_util::sync::CancellationToken; +use crate::common::acl_processor::AclRuleBuilder; use crate::common::config::ConfigLoader; use crate::common::error::Error; use crate::common::global_ctx::{ArcGlobalCtx, GlobalCtx, GlobalCtxEvent}; @@ -29,13 +30,15 @@ use crate::peers::peer_manager::{PeerManager, RouteAlgoType}; use crate::peers::rpc_service::PeerManagerRpcService; use crate::peers::{create_packet_recv_chan, recv_packet_from_chan, PacketRecvChanReceiver}; use crate::proto::cli::VpnPortalRpc; -use crate::proto::cli::{GetVpnPortalInfoRequest, GetVpnPortalInfoResponse, VpnPortalInfo}; use crate::proto::cli::{ - ListMappedListenerRequest, ListMappedListenerResponse, ManageMappedListenerRequest, - ManageMappedListenerResponse, MappedListener, MappedListenerManageAction, - MappedListenerManageRpc, + AddPortForwardRequest, AddPortForwardResponse, ListMappedListenerRequest, + ListMappedListenerResponse, ListPortForwardRequest, ListPortForwardResponse, + ManageMappedListenerRequest, ManageMappedListenerResponse, MappedListener, + MappedListenerManageAction, MappedListenerManageRpc, PortForwardManageRpc, + RemovePortForwardRequest, RemovePortForwardResponse, }; -use crate::proto::common::TunnelInfo; +use crate::proto::cli::{GetVpnPortalInfoRequest, GetVpnPortalInfoResponse, VpnPortalInfo}; +use crate::proto::common::{PortForwardConfigPb, TunnelInfo}; use crate::proto::peer_rpc::PeerCenterRpcServer; use crate::proto::rpc_impl::standalone::{RpcServerHook, StandAloneServer}; use crate::proto::rpc_types; @@ -609,9 +612,9 @@ impl Instance { } } - if let Some(acl) = self.global_ctx.config.get_acl() { - self.global_ctx.get_acl_filter().reload_rules(Some(&acl)); - } + self.global_ctx + .get_acl_filter() + .reload_rules(AclRuleBuilder::build(&self.global_ctx)?.as_ref()); // run after tun device created, so listener can bind to tun device, which may be required by win 10 self.ip_proxy = Some(IpProxy::new( @@ -790,6 +793,85 @@ impl Instance { MappedListenerManagerRpcService(self.global_ctx.clone()) } + fn get_port_forward_manager_rpc_service( + &self, + ) -> impl PortForwardManageRpc + Clone { + #[derive(Clone)] + pub struct PortForwardManagerRpcService { + global_ctx: ArcGlobalCtx, + socks5_server: Weak, + } + + #[async_trait::async_trait] + impl PortForwardManageRpc for PortForwardManagerRpcService { + type Controller = BaseController; + + async fn add_port_forward( + &self, + _: BaseController, + request: AddPortForwardRequest, + ) -> Result { + let Some(socks5_server) = self.socks5_server.upgrade() else { + return Err(anyhow::anyhow!("socks5 server not available").into()); + }; + if let Some(cfg) = request.cfg { + tracing::info!("Port forward rule added: {:?}", cfg); + let mut current_forwards = self.global_ctx.config.get_port_forwards(); + current_forwards.push(cfg.into()); + self.global_ctx + .config + .set_port_forwards(current_forwards.clone()); + socks5_server + .reload_port_forwards(¤t_forwards) + .await + .with_context(|| "Failed to reload port forwards")?; + } + Ok(AddPortForwardResponse {}) + } + + async fn remove_port_forward( + &self, + _: BaseController, + request: RemovePortForwardRequest, + ) -> Result { + let Some(socks5_server) = self.socks5_server.upgrade() else { + return Err(anyhow::anyhow!("socks5 server not available").into()); + }; + let Some(cfg) = request.cfg else { + return Err(anyhow::anyhow!("port forward config is empty").into()); + }; + let cfg = cfg.into(); + let mut current_forwards = self.global_ctx.config.get_port_forwards(); + current_forwards.retain(|e| *e != cfg); + self.global_ctx + .config + .set_port_forwards(current_forwards.clone()); + socks5_server + .reload_port_forwards(¤t_forwards) + .await + .with_context(|| "Failed to reload port forwards")?; + + tracing::info!("Port forward rule removed: {:?}", cfg); + Ok(RemovePortForwardResponse {}) + } + + async fn list_port_forward( + &self, + _: BaseController, + _request: ListPortForwardRequest, + ) -> Result { + let forwards = self.global_ctx.config.get_port_forwards(); + let cfgs: Vec = forwards.into_iter().map(Into::into).collect(); + Ok(ListPortForwardResponse { cfgs }) + } + } + + PortForwardManagerRpcService { + global_ctx: self.global_ctx.clone(), + socks5_server: Arc::downgrade(&self.socks5_server), + } + } + async fn run_rpc_server(&mut self) -> Result<(), Error> { let Some(_) = self.global_ctx.config.get_rpc_portal() else { tracing::info!("rpc server not enabled, because rpc_portal is not set."); @@ -803,6 +885,7 @@ impl Instance { let peer_center = self.peer_center.clone(); let vpn_portal_rpc = self.get_vpn_portal_rpc_service(); let mapped_listener_manager_rpc = self.get_mapped_listener_manager_rpc_service(); + let port_forward_manager_rpc = self.get_port_forward_manager_rpc_service(); let s = self.rpc_server.as_mut().unwrap(); let peer_mgr_rpc_service = PeerManagerRpcService::new(peer_mgr.clone()); @@ -823,6 +906,10 @@ impl Instance { MappedListenerManageRpcServer::new(mapped_listener_manager_rpc), "", ); + s.registry().register( + PortForwardManageRpcServer::new(port_forward_manager_rpc), + "", + ); if let Some(ip_proxy) = self.ip_proxy.as_ref() { s.registry().register( diff --git a/easytier/src/peers/rpc_service.rs b/easytier/src/peers/rpc_service.rs index e9913cca..3c83f1cb 100644 --- a/easytier/src/peers/rpc_service.rs +++ b/easytier/src/peers/rpc_service.rs @@ -1,13 +1,18 @@ use std::sync::Arc; -use crate::proto::{ - cli::{ - AclManageRpc, DumpRouteRequest, DumpRouteResponse, GetAclStatsRequest, GetAclStatsResponse, - ListForeignNetworkRequest, ListForeignNetworkResponse, ListGlobalForeignNetworkRequest, - ListGlobalForeignNetworkResponse, ListPeerRequest, ListPeerResponse, ListRouteRequest, - ListRouteResponse, PeerInfo, PeerManageRpc, ShowNodeInfoRequest, ShowNodeInfoResponse, +use crate::{ + common::acl_processor::AclRuleBuilder, + proto::{ + cli::{ + AclManageRpc, DumpRouteRequest, DumpRouteResponse, GetAclStatsRequest, + GetAclStatsResponse, GetWhitelistRequest, GetWhitelistResponse, + ListForeignNetworkRequest, ListForeignNetworkResponse, ListGlobalForeignNetworkRequest, + ListGlobalForeignNetworkResponse, ListPeerRequest, ListPeerResponse, ListRouteRequest, + ListRouteResponse, PeerInfo, PeerManageRpc, SetWhitelistRequest, SetWhitelistResponse, + ShowNodeInfoRequest, ShowNodeInfoResponse, + }, + rpc_types::{self, controller::BaseController}, }, - rpc_types::{self, controller::BaseController}, }; use super::peer_manager::PeerManager; @@ -153,4 +158,45 @@ impl AclManageRpc for PeerManagerRpcService { acl_stats: Some(acl_stats), }) } + + async fn set_whitelist( + &self, + _: BaseController, + request: SetWhitelistRequest, + ) -> Result { + tracing::info!( + "Setting whitelist - TCP: {:?}, UDP: {:?}", + request.tcp_ports, + request.udp_ports + ); + + let global_ctx = self.peer_manager.get_global_ctx(); + + global_ctx.config.set_tcp_whitelist(request.tcp_ports); + global_ctx.config.set_udp_whitelist(request.udp_ports); + global_ctx + .get_acl_filter() + .reload_rules(AclRuleBuilder::build(&global_ctx)?.as_ref()); + + Ok(SetWhitelistResponse {}) + } + + async fn get_whitelist( + &self, + _: BaseController, + _request: GetWhitelistRequest, + ) -> Result { + let global_ctx = self.peer_manager.get_global_ctx(); + let tcp_ports = global_ctx.config.get_tcp_whitelist(); + let udp_ports = global_ctx.config.get_udp_whitelist(); + tracing::info!( + "Getting whitelist - TCP: {:?}, UDP: {:?}", + tcp_ports, + udp_ports + ); + Ok(GetWhitelistResponse { + tcp_ports, + udp_ports, + }) + } } diff --git a/easytier/src/proto/cli.proto b/easytier/src/proto/cli.proto index 847e48af..fbf21006 100644 --- a/easytier/src/proto/cli.proto +++ b/easytier/src/proto/cli.proto @@ -261,4 +261,44 @@ message GetAclStatsResponse { service AclManageRpc { rpc GetAclStats(GetAclStatsRequest) returns (GetAclStatsResponse); + rpc SetWhitelist(SetWhitelistRequest) returns (SetWhitelistResponse); + rpc GetWhitelist(GetWhitelistRequest) returns (GetWhitelistResponse); +} + +message SetWhitelistRequest { + repeated string tcp_ports = 1; + repeated string udp_ports = 2; +} + +message SetWhitelistResponse {} + +message GetWhitelistRequest {} + +message GetWhitelistResponse { + repeated string tcp_ports = 1; + repeated string udp_ports = 2; +} + +message AddPortForwardRequest { + common.PortForwardConfigPb cfg = 1; +} + +message AddPortForwardResponse {} + +message RemovePortForwardRequest { + common.PortForwardConfigPb cfg = 1; +} + +message RemovePortForwardResponse {} + +message ListPortForwardRequest {} + +message ListPortForwardResponse { + repeated common.PortForwardConfigPb cfgs = 1; +} + +service PortForwardManageRpc { + rpc AddPortForward(AddPortForwardRequest) returns (AddPortForwardResponse); + rpc RemovePortForward(RemovePortForwardRequest) returns (RemovePortForwardResponse); + rpc ListPortForward(ListPortForwardRequest) returns (ListPortForwardResponse); }