From 3467890270097b521d3667d52f940f7b492c7341 Mon Sep 17 00:00:00 2001 From: "Sijie.Sun" Date: Wed, 24 Apr 2024 23:12:46 +0800 Subject: [PATCH] zero copy tunnel (#55) make tunnel zero copy, for better performance. remove most of the locks in io path. introduce quic tunnel prepare for encryption --- Cargo.lock | 320 ++++++- easytier/Cargo.toml | 12 + easytier/proto/cli.proto | 18 + easytier/src/common/error.rs | 11 +- easytier/src/common/stun.rs | 38 +- easytier/src/connector/manual.rs | 31 +- easytier/src/connector/mod.rs | 12 +- easytier/src/connector/udp_hole_punch.rs | 18 +- easytier/src/easytier-cli.rs | 1 + easytier/src/easytier-core.rs | 1 + easytier/src/gateway/icmp_proxy.rs | 138 +-- easytier/src/gateway/tcp_proxy.rs | 190 ++-- easytier/src/gateway/udp_proxy.rs | 56 +- easytier/src/instance/instance.rs | 72 +- easytier/src/instance/listeners.rs | 23 +- easytier/src/lib.rs | 1 + easytier/src/peers/foreign_network_client.rs | 24 +- easytier/src/peers/foreign_network_manager.rs | 89 +- easytier/src/peers/mod.rs | 19 +- easytier/src/peers/peer.rs | 28 +- easytier/src/peers/peer_conn_ping.rs | 219 +++++ easytier/src/peers/peer_manager.rs | 181 ++-- easytier/src/peers/peer_map.rs | 60 +- easytier/src/peers/peer_ospf_route.rs | 33 +- easytier/src/peers/peer_rip_route.rs | 35 +- easytier/src/peers/peer_rpc.rs | 131 +-- easytier/src/peers/tests.rs | 2 +- easytier/src/peers/zc_peer_conn.rs | 748 ++++++++++++++++ easytier/src/tests/mod.rs | 2 +- easytier/src/tests/three_node.rs | 183 +++- easytier/src/tunnel/buf.rs | 92 ++ easytier/src/tunnel/common.rs | 539 +++++++++++ easytier/src/tunnel/filter.rs | 362 ++++++++ easytier/src/tunnel/mod.rs | 196 ++++ easytier/src/tunnel/mpsc.rs | 180 ++++ easytier/src/tunnel/packet_def.rs | 340 +++++++ easytier/src/tunnel/quic.rs | 226 +++++ easytier/src/tunnel/ring.rs | 427 +++++++++ easytier/src/tunnel/stats.rs | 95 ++ easytier/src/tunnel/tcp.rs | 200 +++++ easytier/src/tunnel/udp.rs | 838 ++++++++++++++++++ easytier/src/tunnel/wireguard.rs | 827 +++++++++++++++++ easytier/src/tunnels/mod.rs | 12 +- easytier/src/vpn_portal/wireguard.rs | 162 ++-- 44 files changed, 6504 insertions(+), 688 deletions(-) create mode 100644 easytier/src/peers/peer_conn_ping.rs create mode 100644 easytier/src/peers/zc_peer_conn.rs create mode 100644 easytier/src/tunnel/buf.rs create mode 100644 easytier/src/tunnel/common.rs create mode 100644 easytier/src/tunnel/filter.rs create mode 100644 easytier/src/tunnel/mod.rs create mode 100644 easytier/src/tunnel/mpsc.rs create mode 100644 easytier/src/tunnel/packet_def.rs create mode 100644 easytier/src/tunnel/quic.rs create mode 100644 easytier/src/tunnel/ring.rs create mode 100644 easytier/src/tunnel/stats.rs create mode 100644 easytier/src/tunnel/tcp.rs create mode 100644 easytier/src/tunnel/udp.rs create mode 100644 easytier/src/tunnel/wireguard.rs diff --git a/Cargo.lock b/Cargo.lock index 4229a9e4..1b68679d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -142,6 +142,15 @@ version = "1.0.79" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "080e9890a082662b09c1ad45f567faeeb47f22b5fb23895fbe1e651e718e25ca" +[[package]] +name = "async-event" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4172595da7ffb68640606be5723e35a353555f2829e9209437627a003725bbdb" +dependencies = [ + "loom", +] + [[package]] name = "async-recursion" version = "1.1.0" @@ -219,6 +228,12 @@ dependencies = [ "critical-section", ] +[[package]] +name = "atomicbox" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8a9a3820bc9e9aaf60c8389c2a4808548599f4ff254ce6bdb608ac3631d4ad76" + [[package]] name = "auto_impl" version = "1.1.0" @@ -309,6 +324,12 @@ version = "0.21.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567" +[[package]] +name = "base64" +version = "0.22.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9475866fec1451be56a3c2400fd081ff546538961565ccb5b7142cbd22bc7a51" + [[package]] name = "base64ct" version = "1.6.0" @@ -381,7 +402,7 @@ dependencies = [ "nix 0.25.1", "parking_lot", "rand_core 0.6.4", - "ring", + "ring 0.16.20", "tracing", "untrusted 0.9.0", "x25519-dalek", @@ -1078,6 +1099,26 @@ version = "2.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7e962a19be5cfc3f3bf6dd8f61eb50107f356ad6270fbb3ed41476571db78be5" +[[package]] +name = "defguard_wireguard_rs" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ba16f17698d4b389907310af018b0c3a80b025bba9c38d947cbc6dd70921743" +dependencies = [ + "base64 0.21.7", + "libc", + "log", + "netlink-packet-core", + "netlink-packet-generic", + "netlink-packet-route", + "netlink-packet-utils", + "netlink-packet-wireguard", + "netlink-sys", + "nix 0.27.1", + "serde", + "thiserror", +] + [[package]] name = "deprecate-until" version = "0.1.1" @@ -1149,6 +1190,16 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "diatomic-waker" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "28025fb55a9d815acf7b0877555f437254f373036eec6ed265116c7a5c0825e9" +dependencies = [ + "loom", + "waker-fn", +] + [[package]] name = "digest" version = "0.10.7" @@ -1228,17 +1279,20 @@ dependencies = [ "async-recursion", "async-stream", "async-trait", + "atomicbox", "auto_impl", "base64 0.21.7", "boringtun", "bytecodec", "byteorder", + "bytes", "chrono", "cidr", "clap", "crossbeam", "crossbeam-queue", "dashmap", + "defguard_wireguard_rs", "derivative", "futures", "gethostname", @@ -1249,19 +1303,24 @@ dependencies = [ "once_cell", "pathfinding", "percent-encoding", + "pin-project-lite", "pnet", "postcard", "prost", "public-ip", + "quinn", "rand 0.8.5", + "rcgen", "reqwest", "rkyv", "rstest", + "rustls", "serde", "serial_test", "socket2 0.5.5", "stun_codec", "tabled", + "tachyonix", "tarpc", "thiserror", "time", @@ -1279,6 +1338,7 @@ dependencies = [ "url", "uuid", "windows-sys 0.52.0", + "zerocopy", "zip", ] @@ -2691,6 +2751,80 @@ dependencies = [ "jni-sys", ] +[[package]] +name = "netlink-packet-core" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72724faf704479d67b388da142b186f916188505e7e0b26719019c525882eda4" +dependencies = [ + "anyhow", + "byteorder", + "netlink-packet-utils", +] + +[[package]] +name = "netlink-packet-generic" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1cd7eb8ad331c84c6b8cb7f685b448133e5ad82e1ffd5acafac374af4a5a308b" +dependencies = [ + "anyhow", + "byteorder", + "netlink-packet-core", + "netlink-packet-utils", +] + +[[package]] +name = "netlink-packet-route" +version = "0.17.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "053998cea5a306971f88580d0829e90f270f940befd7cf928da179d4187a5a66" +dependencies = [ + "anyhow", + "bitflags 1.3.2", + "byteorder", + "libc", + "netlink-packet-core", + "netlink-packet-utils", +] + +[[package]] +name = "netlink-packet-utils" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ede8a08c71ad5a95cdd0e4e52facd37190977039a4704eb82a283f713747d34" +dependencies = [ + "anyhow", + "byteorder", + "paste", + "thiserror", +] + +[[package]] +name = "netlink-packet-wireguard" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60b25b050ff1f6a1e23c6777b72db22790fe5b6b5ccfd3858672587a79876c8f" +dependencies = [ + "anyhow", + "byteorder", + "libc", + "log", + "netlink-packet-generic", + "netlink-packet-utils", +] + +[[package]] +name = "netlink-sys" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "416060d346fbaf1f23f9512963e3e878f1a78e707cb699ba9215761754244307" +dependencies = [ + "bytes", + "libc", + "log", +] + [[package]] name = "network-interface" version = "1.1.1" @@ -3011,6 +3145,12 @@ dependencies = [ "subtle", ] +[[package]] +name = "paste" +version = "1.0.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "de3145af08024dea9fa9914f381a17b8fc6034dfb00f3a84013f7ff43f29ed4c" + [[package]] name = "pathdiff" version = "0.2.1" @@ -3044,6 +3184,16 @@ dependencies = [ "sha2", ] +[[package]] +name = "pem" +version = "3.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e459365e590736a54c3fa561947c84837534b8e9af6fc5bf781307e82658fae" +dependencies = [ + "base64 0.22.0", + "serde", +] + [[package]] name = "percent-encoding" version = "2.3.1" @@ -3574,6 +3724,54 @@ dependencies = [ "memchr", ] +[[package]] +name = "quinn" +version = "0.10.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8cc2c5017e4b43d5995dcea317bc46c1e09404c0a9664d2908f7f02dfe943d75" +dependencies = [ + "bytes", + "pin-project-lite", + "quinn-proto", + "quinn-udp", + "rustc-hash", + "rustls", + "thiserror", + "tokio", + "tracing", +] + +[[package]] +name = "quinn-proto" +version = "0.10.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "141bf7dfde2fbc246bfd3fe12f2455aa24b0fbd9af535d8c86c7bd1381ff2b1a" +dependencies = [ + "bytes", + "rand 0.8.5", + "ring 0.16.20", + "rustc-hash", + "rustls", + "rustls-native-certs", + "slab", + "thiserror", + "tinyvec", + "tracing", +] + +[[package]] +name = "quinn-udp" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "055b4e778e8feb9f93c4e439f71dc2156ef13360b432b799e179a8c4cdf0b1d7" +dependencies = [ + "bytes", + "libc", + "socket2 0.5.5", + "tracing", + "windows-sys 0.48.0", +] + [[package]] name = "quote" version = "1.0.35" @@ -3686,6 +3884,18 @@ version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f2ff9a1f06a88b01621b7ae906ef0211290d1c8a168a15542486a8f61c0833b9" +[[package]] +name = "rcgen" +version = "0.11.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "52c4f3084aa3bc7dfbba4eff4fab2a54db4324965d8872ab933565e6fbd83bc6" +dependencies = [ + "pem", + "ring 0.16.20", + "time", + "yasna", +] + [[package]] name = "redox_syscall" version = "0.4.1" @@ -3820,6 +4030,21 @@ dependencies = [ "winapi", ] +[[package]] +name = "ring" +version = "0.17.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c17fa4cb658e3583423e915b9f3acc01cceaee1860e33d59ebae66adc3a2dc0d" +dependencies = [ + "cc", + "cfg-if", + "getrandom 0.2.12", + "libc", + "spin 0.9.8", + "untrusted 0.9.0", + "windows-sys 0.52.0", +] + [[package]] name = "rkyv" version = "0.7.43" @@ -3912,6 +4137,30 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "rustls" +version = "0.21.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7fecbfb7b1444f477b345853b1fce097a2c6fb637b2bfb87e6bc5db0f043fae4" +dependencies = [ + "log", + "ring 0.17.8", + "rustls-webpki", + "sct", +] + +[[package]] +name = "rustls-native-certs" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a9aace74cb666635c918e9c12bc0d348266037aa8eb599b5cba565709a8dff00" +dependencies = [ + "openssl-probe", + "rustls-pemfile", + "schannel", + "security-framework", +] + [[package]] name = "rustls-pemfile" version = "1.0.4" @@ -3921,6 +4170,16 @@ dependencies = [ "base64 0.21.7", ] +[[package]] +name = "rustls-webpki" +version = "0.101.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b6275d1ee7a1cd780b64aca7726599a1dbc893b1e64144529e55c3c2f745765" +dependencies = [ + "ring 0.17.8", + "untrusted 0.9.0", +] + [[package]] name = "rustversion" version = "1.0.14" @@ -3963,6 +4222,16 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" +[[package]] +name = "sct" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "da046153aa2352493d6cb7da4b6e5c0c057d8a1d0a9aa8560baffdd945acd414" +dependencies = [ + "ring 0.17.8", + "untrusted 0.9.0", +] + [[package]] name = "seahash" version = "4.1.0" @@ -4487,6 +4756,19 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "tachyonix" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "64e0bf82be3359dbefbfea621d6365db00e1d7846561daad2ea74cc4cb4c9604" +dependencies = [ + "async-event", + "crossbeam-utils", + "diatomic-waker", + "futures-core", + "loom", +] + [[package]] name = "tao" version = "0.16.8" @@ -5504,6 +5786,12 @@ dependencies = [ "libc", ] +[[package]] +name = "waker-fn" +version = "1.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f3c4517f54858c779bbcbf228f4fca63d121bf85fbecb2dc578cdf4a39395690" + [[package]] name = "walkdir" version = "2.5.0" @@ -6192,6 +6480,36 @@ dependencies = [ "rustix", ] +[[package]] +name = "yasna" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e17bb3549cc1321ae1296b9cdc2698e2b6cb1992adfa19a8c72e5b7a738f44cd" +dependencies = [ + "time", +] + +[[package]] +name = "zerocopy" +version = "0.7.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "74d4d3961e53fa4c9a25a8637fc2bfaf2595b3d3ae34875568a5cf64787716be" +dependencies = [ + "byteorder", + "zerocopy-derive", +] + +[[package]] +name = "zerocopy-derive" +version = "0.7.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ce1b18ccd8e73a9321186f97e46f9f04b778851177567b1975109d26a08d2a6" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.48", +] + [[package]] name = "zeroize" version = "1.7.0" diff --git a/easytier/Cargo.toml b/easytier/Cargo.toml index ca975f07..2049129a 100644 --- a/easytier/Cargo.toml +++ b/easytier/Cargo.toml @@ -58,6 +58,17 @@ async-trait = "0.1.74" dashmap = "5.5.3" timedmap = "=1.0.1" +# for full-path zero-copy +zerocopy = { version = "0.7.32", features = ["derive", "simd"] } +bytes = "1.5.0" +pin-project-lite = "0.2.13" +atomicbox = "0.4.0" +tachyonix = "0.2.1" + +quinn = { version = "0.10.2" } +rustls = { version = "0.21.0", features = ["dangerous_configuration"] } +rcgen = "0.11.1" + # for tap device tun = { version = "0.6.1", features = ["async"] } # for net ns @@ -148,3 +159,4 @@ zip = "0.6.6" [dev-dependencies] serial_test = "3.0.0" rstest = "0.18.2" +defguard_wireguard_rs = "0.4.2" diff --git a/easytier/proto/cli.proto b/easytier/proto/cli.proto index 08fd85fd..6d3267e5 100644 --- a/easytier/proto/cli.proto +++ b/easytier/proto/cli.proto @@ -157,3 +157,21 @@ message GetVpnPortalInfoResponse { service VpnPortalRpc { rpc GetVpnPortalInfo (GetVpnPortalInfoRequest) returns (GetVpnPortalInfoResponse); } + +message HandshakeRequest { + uint32 magic = 1; + uint32 my_peer_id = 2; + uint32 version = 3; + repeated string features = 4; + string network_name = 5; + string network_secret = 6; +} + +message TaRpcPacket { + uint32 from_peer = 1; + uint32 to_peer = 2; + uint32 service_id = 3; + uint32 transact_id = 4; + bool is_req = 5; + bytes content = 6; +} diff --git a/easytier/src/common/error.rs b/easytier/src/common/error.rs index c48f780f..efa7b706 100644 --- a/easytier/src/common/error.rs +++ b/easytier/src/common/error.rs @@ -2,7 +2,7 @@ use std::{io, result}; use thiserror::Error; -use crate::tunnels; +use crate::{tunnel, tunnels}; use super::PeerId; @@ -38,6 +38,15 @@ pub enum Error { Unknown, #[error("anyhow error: {0}")] AnyhowError(#[from] anyhow::Error), + + #[error("wait resp error: {0}")] + WaitRespError(String), + + #[error("tunnel error")] + TunnelErr(#[from] tunnel::TunnelError), + + #[error("message decode error: {0}")] + MessageDecodeError(String), } pub type Result = result::Result; diff --git a/easytier/src/common/stun.rs b/easytier/src/common/stun.rs index 60fefcc0..6527d138 100644 --- a/easytier/src/common/stun.rs +++ b/easytier/src/common/stun.rs @@ -87,7 +87,7 @@ impl Stun { pub fn new(stun_server: SocketAddr) -> Self { Self { stun_server, - req_repeat: 5, + req_repeat: 1, resp_timeout: Duration::from_millis(3000), } } @@ -208,6 +208,7 @@ impl Stun { let mut tids = vec![]; for _ in 0..self.req_repeat { let tid = rand::random::(); + // let tid = 1; let mut buf = [0u8; 28]; // memset buf unsafe { std::ptr::write_bytes(buf.as_mut_ptr(), 0, buf.len()) }; @@ -511,30 +512,17 @@ mod tests { #[tokio::test] async fn test_stun_bind_request() { // miwifi / qq seems not correctly responde to change_ip and change_port, they always try to change the src ip and port. - let mut ips = HostResolverIter::new(vec!["stun1.l.google.com:19302".to_string()]); - let stun = Stun::new(ips.next().await.unwrap()); - // let stun = Stun::new("180.235.108.91:3478".to_string()); - // let stun = Stun::new("193.22.2.248:3478".to_string()); - // let stun = Stun::new("stun.chat.bilibili.com:3478".to_string()); - // let stun = Stun::new("stun.miwifi.com:3478".to_string()); - - // github actions are port restricted nat, so we only test last one. - - // let rs = stun.bind_request(12345, true, true).await.unwrap(); - // assert!(rs.ip_changed); - // assert!(rs.port_changed); - - // let rs = stun.bind_request(12345, true, false).await.unwrap(); - // assert!(rs.ip_changed); - // assert!(!rs.port_changed); - - // let rs = stun.bind_request(12345, false, true).await.unwrap(); - // assert!(!rs.ip_changed); - // assert!(rs.port_changed); - - let rs = stun.bind_request(12345, false, false).await.unwrap(); - assert!(!rs.ip_changed); - assert!(!rs.port_changed); + // let mut ips = HostResolverIter::new(vec!["stun1.l.google.com:19302".to_string()]); + let mut ips_ = HostResolverIter::new(vec!["stun.canets.org:3478".to_string()]); + let mut ips = vec![]; + while let Some(ip) = ips_.next().await { + ips.push(ip); + } + println!("ip: {:?}", ips); + for ip in ips.iter() { + let stun = Stun::new(ip.clone()); + let _rs = stun.bind_request(12345, true, true).await; + } } #[tokio::test] diff --git a/easytier/src/connector/manual.rs b/easytier/src/connector/manual.rs index f8f61992..8c7342d9 100644 --- a/easytier/src/connector/manual.rs +++ b/easytier/src/connector/manual.rs @@ -7,7 +7,9 @@ use tokio::{ time::timeout, }; -use crate::{common::PeerId, peers::peer_conn::PeerConnId, rpc as easytier_rpc}; +use crate::{ + common::PeerId, peers::zc_peer_conn::PeerConnId, rpc as easytier_rpc, tunnel::TunnelConnector, +}; use crate::{ common::{ @@ -21,13 +23,13 @@ use crate::{ connector_manage_rpc_server::ConnectorManageRpc, Connector, ConnectorStatus, ListConnectorRequest, ManageConnectorRequest, }, - tunnels::{Tunnel, TunnelConnector}, use_global_var, }; use super::create_connector_by_url; -type ConnectorMap = Arc>>; +type MutexConnector = Arc>>; +type ConnectorMap = Arc>; #[derive(Debug, Clone)] struct ReconnResult { @@ -81,12 +83,13 @@ impl ManualConnectorManager { pub fn add_connector(&self, connector: T) where - T: TunnelConnector + Send + Sync + 'static, + T: TunnelConnector + 'static, { log::info!("add_connector: {}", connector.remote_url()); - self.data - .connectors - .insert(connector.remote_url().into(), Box::new(connector)); + self.data.connectors.insert( + connector.remote_url().into(), + Arc::new(Mutex::new(Box::new(connector))), + ); } pub async fn add_connector_by_url(&self, url: &str) -> Result<(), Error> { @@ -254,7 +257,7 @@ impl ManualConnectorManager { async fn conn_reconnect( data: Arc, dead_url: String, - connector: Box, + connector: MutexConnector, ) -> Result { let connector = Arc::new(Mutex::new(Some(connector))); let net_ns = data.net_ns.clone(); @@ -269,15 +272,17 @@ impl ManualConnectorManager { let mut locked = connector_clone.lock().await; let conn = locked.as_mut().unwrap(); // TODO: should support set v6 here, use url in connector array - set_bind_addr_for_peer_connector(conn, true, &ip_collector).await; + set_bind_addr_for_peer_connector(conn.lock().await.as_mut(), true, &ip_collector).await; data_clone .global_ctx - .issue_event(GlobalCtxEvent::Connecting(conn.remote_url().clone())); + .issue_event(GlobalCtxEvent::Connecting( + conn.lock().await.remote_url().clone(), + )); let _g = net_ns.guard(); log::info!("reconnect try connect... conn: {:?}", conn); - let tunnel = conn.connect().await?; + let tunnel = conn.lock().await.connect().await?; log::info!("reconnect get tunnel succ: {:?}", tunnel); assert_eq!( url_clone, @@ -359,7 +364,7 @@ mod tests { use crate::{ peers::tests::create_mock_peer_manager, set_global_var, - tunnels::{Tunnel, TunnelError}, + tunnel::{Tunnel, TunnelError}, }; use super::*; @@ -379,7 +384,7 @@ mod tests { } async fn connect(&mut self) -> Result, TunnelError> { tokio::time::sleep(std::time::Duration::from_millis(10)).await; - Err(TunnelError::CommonError("fake error".into())) + Err(TunnelError::InvalidPacket("fake error".into())) } } diff --git a/easytier/src/connector/mod.rs b/easytier/src/connector/mod.rs index 5e11f061..7df015dc 100644 --- a/easytier/src/connector/mod.rs +++ b/easytier/src/connector/mod.rs @@ -5,10 +5,10 @@ use std::{ use crate::{ common::{error::Error, global_ctx::ArcGlobalCtx, network::IPCollector}, - tunnels::{ - ring_tunnel::RingTunnelConnector, - tcp_tunnel::TcpTunnelConnector, - udp_tunnel::UdpTunnelConnector, + tunnel::{ + ring::RingTunnelConnector, + tcp::TcpTunnelConnector, + udp::UdpTunnelConnector, wireguard::{WgConfig, WgTunnelConnector}, TunnelConnector, }, @@ -19,7 +19,7 @@ pub mod manual; pub mod udp_hole_punch; async fn set_bind_addr_for_peer_connector( - connector: &mut impl TunnelConnector, + connector: &mut (impl TunnelConnector + ?Sized), is_ipv4: bool, ip_collector: &Arc, ) { @@ -45,7 +45,7 @@ async fn set_bind_addr_for_peer_connector( pub async fn create_connector_by_url( url: &str, global_ctx: &ArcGlobalCtx, -) -> Result, Error> { +) -> Result, Error> { let url = url::Url::parse(url).map_err(|_| Error::InvalidUrl(url.to_owned()))?; match url.scheme() { "tcp" => { diff --git a/easytier/src/connector/udp_hole_punch.rs b/easytier/src/connector/udp_hole_punch.rs index 016528f3..53f138e3 100644 --- a/easytier/src/connector/udp_hole_punch.rs +++ b/easytier/src/connector/udp_hole_punch.rs @@ -2,20 +2,21 @@ use std::{net::SocketAddr, sync::Arc}; use anyhow::Context; use crossbeam::atomic::AtomicCell; -use rand::{seq::SliceRandom, Rng, SeedableRng}; +use rand::{seq::SliceRandom, SeedableRng}; use tokio::{net::UdpSocket, sync::Mutex, task::JoinSet}; use tracing::Instrument; use crate::{ common::{ constants, error::Error, global_ctx::ArcGlobalCtx, join_joinset_background, - rkyv_util::encode_to_bytes, stun::StunInfoCollectorTrait, PeerId, + stun::StunInfoCollectorTrait, PeerId, }, peers::peer_manager::PeerManager, rpc::NatType, - tunnels::{ + tunnel::{ common::setup_sokcet2, - udp_tunnel::{UdpPacket, UdpTunnelConnector, UdpTunnelListener}, + packet_def::ZCPacketType, + udp::{new_hole_punch_packet, UdpTunnelConnector, UdpTunnelListener}, Tunnel, TunnelConnCounter, TunnelListener, }, }; @@ -149,15 +150,10 @@ impl UdpHolePunchService for UdpHolePunchRpcServer { self.tasks.lock().unwrap().spawn(async move { for _ in 0..10 { tracing::info!(?local_mapped_addr, "sending hole punching packet"); - // generate a 128 bytes vec with random data - let mut rng = rand::rngs::StdRng::from_entropy(); - let mut buf = vec![0u8; 128]; - rng.fill(&mut buf[..]); - let udp_packet = UdpPacket::new_hole_punch_packet(buf); - let udp_packet_bytes = encode_to_bytes::<_, 256>(&udp_packet); + let udp_packet = new_hole_punch_packet(); let _ = socket - .send_to(udp_packet_bytes.as_ref(), local_mapped_addr) + .send_to(&udp_packet.into_bytes(ZCPacketType::UDP), local_mapped_addr) .await; tokio::time::sleep(std::time::Duration::from_millis(300)).await; } diff --git a/easytier/src/easytier-cli.rs b/easytier/src/easytier-cli.rs index e4366fc9..25ff5ae8 100644 --- a/easytier/src/easytier-cli.rs +++ b/easytier/src/easytier-cli.rs @@ -9,6 +9,7 @@ use utils::{list_peer_route_pair, PeerRoutePair}; mod arch; mod common; mod rpc; +mod tunnel; mod tunnels; mod utils; diff --git a/easytier/src/easytier-core.rs b/easytier/src/easytier-core.rs index cd65c7a6..fdc091fb 100644 --- a/easytier/src/easytier-core.rs +++ b/easytier/src/easytier-core.rs @@ -16,6 +16,7 @@ mod instance; mod peer_center; mod peers; mod rpc; +mod tunnel; mod tunnels; mod vpn_portal; diff --git a/easytier/src/gateway/icmp_proxy.rs b/easytier/src/gateway/icmp_proxy.rs index 03ed0b05..02dfa116 100644 --- a/easytier/src/gateway/icmp_proxy.rs +++ b/easytier/src/gateway/icmp_proxy.rs @@ -16,12 +16,13 @@ use tokio::{ sync::{mpsc::UnboundedSender, Mutex}, task::JoinSet, }; -use tokio_util::bytes::Bytes; + use tracing::Instrument; use crate::{ common::{error::Error, global_ctx::ArcGlobalCtx, PeerId}, - peers::{packet, peer_manager::PeerManager, PeerPacketFilter}, + peers::{peer_manager::PeerManager, PeerPacketFilter}, + tunnel::packet_def::{PacketType, ZCPacket}, }; use super::CidrSet; @@ -78,11 +79,7 @@ fn socket_recv(socket: &Socket, buf: &mut [MaybeUninit]) -> Result<(usize, I Ok((size, addr)) } -fn socket_recv_loop( - socket: Socket, - nat_table: IcmpNatTable, - sender: UnboundedSender, -) { +fn socket_recv_loop(socket: Socket, nat_table: IcmpNatTable, sender: UnboundedSender) { let mut buf = [0u8; 4096]; let data: &mut [MaybeUninit] = unsafe { std::mem::transmute(&mut buf[12..]) }; @@ -126,13 +123,14 @@ fn socket_recv_loop( ipv4_packet.set_destination(dest_ip); ipv4_packet.set_checksum(ipv4::checksum(&ipv4_packet.to_immutable())); - let peer_packet = packet::Packet::new_data_packet( - v.my_peer_id, - v.src_peer_id, - &ipv4_packet.to_immutable().packet(), + let mut p = ZCPacket::new_with_payload(ipv4_packet.packet()); + p.fill_peer_manager_hdr( + v.my_peer_id.into(), + v.src_peer_id.into(), + PacketType::Data as u8, ); - if let Err(e) = sender.send(peer_packet) { + if let Err(e) = sender.send(p) { tracing::error!("send icmp packet to peer failed: {:?}, may exiting..", e); break; } @@ -141,61 +139,12 @@ fn socket_recv_loop( #[async_trait::async_trait] impl PeerPacketFilter for IcmpProxy { - async fn try_process_packet_from_peer( - &self, - packet: &packet::ArchivedPacket, - _: &Bytes, - ) -> Option<()> { - let _ = self.global_ctx.get_ipv4()?; - - if packet.packet_type != packet::PacketType::Data { - return None; - }; - - let ipv4 = Ipv4Packet::new(&packet.payload.as_bytes())?; - - if ipv4.get_version() != 4 || ipv4.get_next_level_protocol() != IpNextHeaderProtocols::Icmp - { + async fn try_process_packet_from_peer(&self, packet: ZCPacket) -> Option { + if let Some(_) = self.try_handle_peer_packet(&packet).await { return None; + } else { + return Some(packet); } - - if !self.cidr_set.contains_v4(ipv4.get_destination()) { - return None; - } - - let icmp_packet = icmp::echo_request::EchoRequestPacket::new(&ipv4.payload())?; - - if icmp_packet.get_icmp_type() != IcmpTypes::EchoRequest { - // drop it because we do not support other icmp types - tracing::trace!("unsupported icmp type: {:?}", icmp_packet.get_icmp_type()); - return Some(()); - } - - let icmp_id = icmp_packet.get_identifier(); - let icmp_seq = icmp_packet.get_sequence_number(); - - let key = IcmpNatKey { - dst_ip: ipv4.get_destination().into(), - icmp_id, - icmp_seq, - }; - - let value = IcmpNatEntry::new( - packet.from_peer.into(), - packet.to_peer.into(), - ipv4.get_source().into(), - ) - .ok()?; - - if let Some(old) = self.nat_table.insert(key, value) { - tracing::info!("icmp nat table entry replaced: {:?}", old); - } - - if let Err(e) = self.send_icmp_packet(ipv4.get_destination(), &icmp_packet) { - tracing::error!("send icmp packet failed: {:?}", e); - } - - Some(()) } } @@ -262,8 +211,9 @@ impl IcmpProxy { self.tasks.lock().await.spawn( async move { while let Some(msg) = receiver.recv().await { - let to_peer_id = msg.to_peer.into(); - let ret = peer_manager.send_msg(msg.into(), to_peer_id).await; + let hdr = msg.peer_manager_header().unwrap(); + let to_peer_id = hdr.to_peer_id.into(); + let ret = peer_manager.send_msg(msg, to_peer_id).await; if ret.is_err() { tracing::error!("send icmp packet to peer failed: {:?}", ret); } @@ -290,4 +240,58 @@ impl IcmpProxy { Ok(()) } + + async fn try_handle_peer_packet(&self, packet: &ZCPacket) -> Option<()> { + let _ = self.global_ctx.get_ipv4()?; + let hdr = packet.peer_manager_header().unwrap(); + + if hdr.packet_type != PacketType::Data as u8 { + return None; + }; + + let ipv4 = Ipv4Packet::new(&packet.payload())?; + + if ipv4.get_version() != 4 || ipv4.get_next_level_protocol() != IpNextHeaderProtocols::Icmp + { + return None; + } + + if !self.cidr_set.contains_v4(ipv4.get_destination()) { + return None; + } + + let icmp_packet = icmp::echo_request::EchoRequestPacket::new(&ipv4.payload())?; + + if icmp_packet.get_icmp_type() != IcmpTypes::EchoRequest { + // drop it because we do not support other icmp types + tracing::trace!("unsupported icmp type: {:?}", icmp_packet.get_icmp_type()); + return Some(()); + } + + let icmp_id = icmp_packet.get_identifier(); + let icmp_seq = icmp_packet.get_sequence_number(); + + let key = IcmpNatKey { + dst_ip: ipv4.get_destination().into(), + icmp_id, + icmp_seq, + }; + + let value = IcmpNatEntry::new( + hdr.from_peer_id.into(), + hdr.to_peer_id.into(), + ipv4.get_source().into(), + ) + .ok()?; + + if let Some(old) = self.nat_table.insert(key, value) { + tracing::info!("icmp nat table entry replaced: {:?}", old); + } + + if let Err(e) = self.send_icmp_packet(ipv4.get_destination(), &icmp_packet) { + tracing::error!("send icmp packet failed: {:?}", e); + } + + Some(()) + } } diff --git a/easytier/src/gateway/tcp_proxy.rs b/easytier/src/gateway/tcp_proxy.rs index f5abd524..63452a30 100644 --- a/easytier/src/gateway/tcp_proxy.rs +++ b/easytier/src/gateway/tcp_proxy.rs @@ -2,7 +2,9 @@ use crossbeam::atomic::AtomicCell; use dashmap::DashMap; use pnet::packet::ip::IpNextHeaderProtocols; use pnet::packet::ipv4::{Ipv4Packet, MutableIpv4Packet}; -use pnet::packet::tcp::{ipv4_checksum, MutableTcpPacket}; +use pnet::packet::tcp::{ipv4_checksum, MutableTcpPacket, TcpPacket}; +use pnet::packet::MutablePacket; +use pnet::packet::Packet; use std::net::{IpAddr, Ipv4Addr, SocketAddr, SocketAddrV4}; use std::sync::atomic::AtomicU16; use std::sync::Arc; @@ -11,16 +13,16 @@ use tokio::io::copy_bidirectional; use tokio::net::{TcpListener, TcpSocket, TcpStream}; use tokio::sync::Mutex; use tokio::task::JoinSet; -use tokio_util::bytes::{Bytes, BytesMut}; use tracing::Instrument; use crate::common::error::Result; use crate::common::global_ctx::GlobalCtx; use crate::common::join_joinset_background; use crate::common::netns::NetNS; -use crate::peers::packet::{self, ArchivedPacket}; + use crate::peers::peer_manager::PeerManager; use crate::peers::{NicPacketFilter, PeerPacketFilter}; +use crate::tunnel::packet_def::{PacketType, ZCPacket}; use super::CidrSet; @@ -83,98 +85,37 @@ pub struct TcpProxy { #[async_trait::async_trait] impl PeerPacketFilter for TcpProxy { - async fn try_process_packet_from_peer(&self, packet: &ArchivedPacket, _: &Bytes) -> Option<()> { - let ipv4_addr = self.global_ctx.get_ipv4()?; - - if packet.packet_type != packet::PacketType::Data { - return None; - }; - - let payload_bytes = packet.payload.as_bytes(); - - let ipv4 = Ipv4Packet::new(payload_bytes)?; - if ipv4.get_version() != 4 || ipv4.get_next_level_protocol() != IpNextHeaderProtocols::Tcp { + async fn try_process_packet_from_peer(&self, mut packet: ZCPacket) -> Option { + if let Some(_) = self.try_handle_peer_packet(&mut packet).await { + if let Err(e) = self.peer_manager.get_nic_channel().send(packet).await { + tracing::error!("send to nic failed: {:?}", e); + } return None; + } else { + Some(packet) } - - if !self.cidr_set.contains_v4(ipv4.get_destination()) { - return None; - } - - tracing::trace!(ipv4 = ?ipv4, cidr_set = ?self.cidr_set, "proxy tcp packet received"); - - let mut packet_buffer = BytesMut::with_capacity(payload_bytes.len()); - packet_buffer.extend_from_slice(&payload_bytes.to_vec()); - - let (ip_buffer, tcp_buffer) = - packet_buffer.split_at_mut(ipv4.get_header_length() as usize * 4); - - let mut ip_packet = MutableIpv4Packet::new(ip_buffer).unwrap(); - let mut tcp_packet = MutableTcpPacket::new(tcp_buffer).unwrap(); - - let is_tcp_syn = tcp_packet.get_flags() & pnet::packet::tcp::TcpFlags::SYN != 0; - if is_tcp_syn { - let source_ip = ip_packet.get_source(); - let source_port = tcp_packet.get_source(); - let src = SocketAddr::V4(SocketAddrV4::new(source_ip, source_port)); - - let dest_ip = ip_packet.get_destination(); - let dest_port = tcp_packet.get_destination(); - let dst = SocketAddr::V4(SocketAddrV4::new(dest_ip, dest_port)); - - let old_val = self - .syn_map - .insert(src, Arc::new(NatDstEntry::new(src, dst))); - tracing::trace!(src = ?src, dst = ?dst, old_entry = ?old_val, "tcp syn received"); - } - - ip_packet.set_destination(ipv4_addr); - tcp_packet.set_destination(self.get_local_port()); - Self::update_ipv4_packet_checksum(&mut ip_packet, &mut tcp_packet); - - tracing::trace!(ip_packet = ?ip_packet, tcp_packet = ?tcp_packet, "tcp packet forwarded"); - - if let Err(e) = self - .peer_manager - .get_nic_channel() - .send(packet_buffer.freeze()) - .await - { - tracing::error!("send to nic failed: {:?}", e); - } - - Some(()) } } #[async_trait::async_trait] impl NicPacketFilter for TcpProxy { - async fn try_process_packet_from_nic(&self, mut data: BytesMut) -> BytesMut { + async fn try_process_packet_from_nic(&self, zc_packet: &mut ZCPacket) { let Some(my_ipv4) = self.global_ctx.get_ipv4() else { - return data; + return; }; - let header_len = { - let Some(ipv4) = &Ipv4Packet::new(&data[..]) else { - return data; - }; - - if ipv4.get_version() != 4 - || ipv4.get_source() != my_ipv4 - || ipv4.get_next_level_protocol() != IpNextHeaderProtocols::Tcp - { - return data; - } - - ipv4.get_header_length() as usize * 4 - }; - - let (ip_buffer, tcp_buffer) = data.split_at_mut(header_len); - let mut ip_packet = MutableIpv4Packet::new(ip_buffer).unwrap(); - let mut tcp_packet = MutableTcpPacket::new(tcp_buffer).unwrap(); + let data = zc_packet.payload(); + let ip_packet = Ipv4Packet::new(data).unwrap(); + if ip_packet.get_version() != 4 + || ip_packet.get_source() != my_ipv4 + || ip_packet.get_next_level_protocol() != IpNextHeaderProtocols::Tcp + { + return; + } + let tcp_packet = TcpPacket::new(ip_packet.payload()).unwrap(); if tcp_packet.get_source() != self.get_local_port() { - return data; + return; } let dst_addr = SocketAddr::V4(SocketAddrV4::new( @@ -187,7 +128,7 @@ impl NicPacketFilter for TcpProxy { entry } else { let Some(syn_entry) = self.syn_map.get(&dst_addr) else { - return data; + return; }; syn_entry }; @@ -199,13 +140,18 @@ impl NicPacketFilter for TcpProxy { panic!("v4 nat entry src ip is not v4"); }; + let mut ip_packet = MutableIpv4Packet::new(zc_packet.mut_payload()).unwrap(); ip_packet.set_source(ip); + let dst = ip_packet.get_destination(); + + let mut tcp_packet = MutableTcpPacket::new(ip_packet.payload_mut()).unwrap(); tcp_packet.set_source(nat_entry.dst.port()); - Self::update_ipv4_packet_checksum(&mut ip_packet, &mut tcp_packet); + + Self::update_tcp_packet_checksum(&mut tcp_packet, &ip, &dst); + drop(tcp_packet); + Self::update_ip_packet_checksum(&mut ip_packet); tracing::trace!(dst_addr = ?dst_addr, nat_entry = ?nat_entry, packet = ?ip_packet, "tcp packet after modified"); - - data } } @@ -226,17 +172,20 @@ impl TcpProxy { }) } - fn update_ipv4_packet_checksum( - ipv4_packet: &mut MutableIpv4Packet, + fn update_tcp_packet_checksum( tcp_packet: &mut MutableTcpPacket, + ipv4_src: &Ipv4Addr, + ipv4_dst: &Ipv4Addr, ) { tcp_packet.set_checksum(ipv4_checksum( &tcp_packet.to_immutable(), - &ipv4_packet.get_source(), - &ipv4_packet.get_destination(), + ipv4_src, + ipv4_dst, )); + } - ipv4_packet.set_checksum(pnet::packet::ipv4::checksum(&ipv4_packet.to_immutable())); + fn update_ip_packet_checksum(ip_packet: &mut MutableIpv4Packet) { + ip_packet.set_checksum(pnet::packet::ipv4::checksum(&ip_packet.to_immutable())); } pub async fn start(self: &Arc) -> Result<()> { @@ -302,6 +251,7 @@ impl TcpProxy { tracing::error!("tcp connection from unknown source: {:?}", socket_addr); continue; }; + tracing::info!(?socket_addr, "tcp connection accepted for proxy"); assert_eq!(entry.state.load(), NatDstEntryState::SynReceived); let entry_clone = entry.clone(); @@ -404,4 +354,60 @@ impl TcpProxy { pub fn get_local_port(&self) -> u16 { self.local_port.load(std::sync::atomic::Ordering::Relaxed) } + + async fn try_handle_peer_packet(&self, packet: &mut ZCPacket) -> Option<()> { + let ipv4_addr = self.global_ctx.get_ipv4()?; + let hdr = packet.peer_manager_header().unwrap(); + + if hdr.packet_type != PacketType::Data as u8 { + return None; + }; + + let payload_bytes = packet.mut_payload(); + + let ipv4 = Ipv4Packet::new(payload_bytes)?; + if ipv4.get_version() != 4 || ipv4.get_next_level_protocol() != IpNextHeaderProtocols::Tcp { + return None; + } + + if !self.cidr_set.contains_v4(ipv4.get_destination()) { + return None; + } + + tracing::info!(ipv4 = ?ipv4, cidr_set = ?self.cidr_set, "proxy tcp packet received"); + + let ip_packet = Ipv4Packet::new(payload_bytes).unwrap(); + let tcp_packet = TcpPacket::new(ip_packet.payload()).unwrap(); + + let is_tcp_syn = tcp_packet.get_flags() & pnet::packet::tcp::TcpFlags::SYN != 0; + if is_tcp_syn { + let source_ip = ip_packet.get_source(); + let source_port = tcp_packet.get_source(); + let src = SocketAddr::V4(SocketAddrV4::new(source_ip, source_port)); + + let dest_ip = ip_packet.get_destination(); + let dest_port = tcp_packet.get_destination(); + let dst = SocketAddr::V4(SocketAddrV4::new(dest_ip, dest_port)); + + let old_val = self + .syn_map + .insert(src, Arc::new(NatDstEntry::new(src, dst))); + tracing::trace!(src = ?src, dst = ?dst, old_entry = ?old_val, "tcp syn received"); + } + + let mut ip_packet = MutableIpv4Packet::new(payload_bytes).unwrap(); + ip_packet.set_destination(ipv4_addr); + let source = ip_packet.get_source(); + + let mut tcp_packet = MutableTcpPacket::new(ip_packet.payload_mut()).unwrap(); + tcp_packet.set_destination(self.get_local_port()); + + Self::update_tcp_packet_checksum(&mut tcp_packet, &source, &ipv4_addr); + drop(tcp_packet); + Self::update_ip_packet_checksum(&mut ip_packet); + + tracing::info!(?source, ?ipv4_addr, ?packet, "tcp packet after modified"); + + Some(()) + } } diff --git a/easytier/src/gateway/udp_proxy.rs b/easytier/src/gateway/udp_proxy.rs index 4a5ebf6a..c6119380 100644 --- a/easytier/src/gateway/udp_proxy.rs +++ b/easytier/src/gateway/udp_proxy.rs @@ -21,12 +21,12 @@ use tokio::{ time::timeout, }; -use tokio_util::bytes::Bytes; use tracing::Level; use crate::{ common::{error::Error, global_ctx::ArcGlobalCtx, PeerId}, - peers::{packet, peer_manager::PeerManager, PeerPacketFilter}, + peers::{peer_manager::PeerManager, PeerPacketFilter}, + tunnel::packet_def::{PacketType, ZCPacket}, tunnels::common::setup_sokcet2, }; @@ -79,7 +79,7 @@ impl UdpNatEntry { async fn compose_ipv4_packet( self: &Arc, - packet_sender: &mut UnboundedSender, + packet_sender: &mut UnboundedSender, buf: &mut [u8], src_v4: &SocketAddrV4, payload_len: usize, @@ -140,13 +140,10 @@ impl UdpNatEntry { tracing::trace!(?ipv4_packet, "udp nat packet response send"); - let peer_packet = packet::Packet::new_data_packet( - self.my_peer_id, - self.src_peer_id, - &ipv4_packet.to_immutable().packet(), - ); + let mut p = ZCPacket::new_with_payload(ipv4_packet.packet()); + p.fill_peer_manager_hdr(self.my_peer_id, self.src_peer_id, PacketType::Data as u8); - if let Err(e) = packet_sender.send(peer_packet) { + if let Err(e) = packet_sender.send(p) { tracing::error!("send icmp packet to peer failed: {:?}, may exiting..", e); return Err(Error::AnyhowError(e.into())); } @@ -158,7 +155,7 @@ impl UdpNatEntry { Ok(()) } - async fn forward_task(self: Arc, mut packet_sender: UnboundedSender) { + async fn forward_task(self: Arc, mut packet_sender: UnboundedSender) { let mut buf = [0u8; 8192]; let mut udp_body: &mut [u8] = unsafe { std::mem::transmute(&mut buf[20 + 8..]) }; let mut ip_id = 1; @@ -220,31 +217,25 @@ pub struct UdpProxy { nat_table: Arc>>, - sender: UnboundedSender, - receiver: Mutex>>, + sender: UnboundedSender, + receiver: Mutex>>, tasks: Mutex>, } -#[async_trait::async_trait] -impl PeerPacketFilter for UdpProxy { - async fn try_process_packet_from_peer( - &self, - packet: &packet::ArchivedPacket, - _: &Bytes, - ) -> Option<()> { +impl UdpProxy { + async fn try_handle_packet(&self, packet: &ZCPacket) -> Option<()> { if self.cidr_set.is_empty() { return None; } let _ = self.global_ctx.get_ipv4()?; - - if packet.packet_type != packet::PacketType::Data { + let hdr = packet.peer_manager_header().unwrap(); + if hdr.packet_type != PacketType::Data as u8 { return None; }; - let ipv4 = Ipv4Packet::new(packet.payload.as_bytes())?; - + let ipv4 = Ipv4Packet::new(packet.payload())?; if ipv4.get_version() != 4 || ipv4.get_next_level_protocol() != IpNextHeaderProtocols::Udp { return None; } @@ -272,8 +263,8 @@ impl PeerPacketFilter for UdpProxy { tracing::info!(?packet, ?ipv4, ?udp_packet, "udp nat table entry created"); let _g = self.global_ctx.net_ns.guard(); Ok(Arc::new(UdpNatEntry::new( - packet.from_peer.into(), - packet.to_peer.into(), + hdr.from_peer_id.get(), + hdr.to_peer_id.get(), nat_key.src_socket, )?)) }) @@ -316,6 +307,17 @@ impl PeerPacketFilter for UdpProxy { } } +#[async_trait::async_trait] +impl PeerPacketFilter for UdpProxy { + async fn try_process_packet_from_peer(&self, packet: ZCPacket) -> Option { + if let Some(_) = self.try_handle_packet(&packet).await { + return None; + } else { + return Some(packet); + } + } +} + impl UdpProxy { pub fn new( global_ctx: ArcGlobalCtx, @@ -362,9 +364,9 @@ impl UdpProxy { let peer_manager = self.peer_manager.clone(); self.tasks.lock().await.spawn(async move { while let Some(msg) = receiver.recv().await { - let to_peer_id: PeerId = msg.to_peer.into(); + let to_peer_id: PeerId = msg.peer_manager_header().unwrap().to_peer_id.get(); tracing::trace!(?msg, ?to_peer_id, "udp nat packet response send"); - let ret = peer_manager.send_msg(msg.into(), to_peer_id).await; + let ret = peer_manager.send_msg(msg, to_peer_id).await; if ret.is_err() { tracing::error!("send icmp packet to peer failed: {:?}", ret); } diff --git a/easytier/src/instance/instance.rs b/easytier/src/instance/instance.rs index 317665a7..a39c208b 100644 --- a/easytier/src/instance/instance.rs +++ b/easytier/src/instance/instance.rs @@ -3,12 +3,12 @@ use std::net::Ipv4Addr; use std::sync::{Arc, Weak}; use anyhow::Context; -use futures::StreamExt; +use futures::{SinkExt, StreamExt}; use pnet::packet::ethernet::EthernetPacket; use pnet::packet::ipv4::Ipv4Packet; +use bytes::BytesMut; use tokio::{sync::Mutex, task::JoinSet}; -use tokio_util::bytes::{Bytes, BytesMut}; use tonic::transport::Server; use crate::common::config::ConfigLoader; @@ -22,15 +22,15 @@ use crate::gateway::icmp_proxy::IcmpProxy; use crate::gateway::tcp_proxy::TcpProxy; use crate::gateway::udp_proxy::UdpProxy; use crate::peer_center::instance::PeerCenterInstance; -use crate::peers::peer_conn::PeerConnId; use crate::peers::peer_manager::{PeerManager, RouteAlgoType}; use crate::peers::rpc_service::PeerManagerRpcService; +use crate::peers::zc_peer_conn::PeerConnId; +use crate::peers::PacketRecvChanReceiver; use crate::rpc::vpn_portal_rpc_server::VpnPortalRpc; use crate::rpc::{GetVpnPortalInfoRequest, GetVpnPortalInfoResponse, VpnPortalInfo}; -use crate::tunnels::SinkItem; -use crate::vpn_portal::{self, VpnPortal}; +use crate::tunnel::packet_def::ZCPacket; -use tokio_stream::wrappers::ReceiverStream; +use crate::vpn_portal::{self, VpnPortal}; use super::listeners::ListenerManager; use super::virtual_nic; @@ -70,7 +70,7 @@ pub struct Instance { id: uuid::Uuid, virtual_nic: Option>, - peer_packet_receiver: Option>, + peer_packet_receiver: Option, tasks: JoinSet<()>, @@ -133,7 +133,7 @@ impl Instance { id, virtual_nic: None, - peer_packet_receiver: Some(ReceiverStream::new(peer_packet_receiver)), + peer_packet_receiver: Some(peer_packet_receiver), tasks: JoinSet::new(), peer_manager, @@ -167,7 +167,11 @@ impl Instance { ?ret, "[USER_PACKET] recv new packet from tun device and forward to peers." ); - let send_ret = mgr.send_msg_ipv4(ret, dst_ipv4).await; + + // TODO: use zero-copy + let send_ret = mgr + .send_msg_ipv4(ZCPacket::new_with_payload(ret.as_ref()), dst_ipv4) + .await; if send_ret.is_err() { tracing::trace!(?send_ret, "[USER_PACKET] send_msg_ipv4 failed") } @@ -209,23 +213,23 @@ impl Instance { fn do_forward_peers_to_nic( tasks: &mut JoinSet<()>, nic: Arc, - channel: Option>, + channel: Option, ) { tasks.spawn(async move { - let send = nic.pin_send_stream(); - let channel = channel.unwrap(); - let ret = channel - .map(|packet| { - log::trace!( - "[USER_PACKET] forward packet from peers to nic. packet: {:?}", - packet - ); - Ok(packet) - }) - .forward(send) - .await; - if ret.is_err() { - panic!("do_forward_tunnel_to_nic"); + let mut send = nic.pin_send_stream(); + let mut channel = channel.unwrap(); + while let Some(packet) = channel.recv().await { + tracing::trace!( + "[USER_PACKET] forward packet from peers to nic. packet: {:?}", + packet + ); + let mut b = BytesMut::new(); + b.extend_from_slice(packet.payload()); + + let ret = send.send(b.freeze()).await; + if ret.is_err() { + panic!("do_forward_tunnel_to_nic"); + } } }); } @@ -300,17 +304,25 @@ impl Instance { self.add_initial_peers().await?; - if let Some(_) = self.global_ctx.get_vpn_portal_cidr() { - self.vpn_portal - .lock() - .await - .start(self.get_global_ctx(), self.get_peer_manager()) - .await?; + if self.global_ctx.get_vpn_portal_cidr().is_some() { + self.run_vpn_portal().await?; } Ok(()) } + pub async fn run_vpn_portal(&mut self) -> Result<(), Error> { + if self.global_ctx.get_vpn_portal_cidr().is_none() { + return Err(anyhow::anyhow!("vpn portal cidr not set.").into()); + } + self.vpn_portal + .lock() + .await + .start(self.get_global_ctx(), self.get_peer_manager()) + .await?; + Ok(()) + } + pub fn get_peer_manager(&self) -> Arc { self.peer_manager.clone() } diff --git a/easytier/src/instance/listeners.rs b/easytier/src/instance/listeners.rs index 58860226..ee8e1850 100644 --- a/easytier/src/instance/listeners.rs +++ b/easytier/src/instance/listeners.rs @@ -11,10 +11,10 @@ use crate::{ netns::NetNS, }, peers::peer_manager::PeerManager, - tunnels::{ - ring_tunnel::RingTunnelListener, - tcp_tunnel::TcpTunnelListener, - udp_tunnel::UdpTunnelListener, + tunnel::{ + ring::RingTunnelListener, + tcp::TcpTunnelListener, + udp::UdpTunnelListener, wireguard::{WgConfig, WgTunnelListener}, Tunnel, TunnelListener, }, @@ -155,7 +155,7 @@ mod tests { use crate::{ common::global_ctx::tests::get_mock_global_ctx, - tunnels::{ring_tunnel::RingTunnelConnector, TunnelConnector}, + tunnel::{packet_def::ZCPacket, ring::RingTunnelConnector, TunnelConnector}, }; use super::*; @@ -165,9 +165,12 @@ mod tests { #[async_trait] impl TunnelHandlerForListener for MockListenerHandler { - async fn handle_tunnel(&self, _tunnel: Box) -> Result<(), Error> { + async fn handle_tunnel(&self, tunnel: Box) -> Result<(), Error> { let data = "abc"; - _tunnel.pin_sink().send(data.into()).await.unwrap(); + let (_recv, mut send) = tunnel.split(); + + let zc_packet = ZCPacket::new_with_payload(data.as_bytes()); + send.send(zc_packet).await.unwrap(); Err(Error::Unknown) } } @@ -187,7 +190,11 @@ mod tests { let connect_once = |ring_id| async move { let tunnel = RingTunnelConnector::new(ring_id).connect().await.unwrap(); - assert_eq!(tunnel.pin_stream().next().await.unwrap().unwrap(), "abc"); + let (mut recv, _send) = tunnel.split(); + assert_eq!( + recv.next().await.unwrap().unwrap().payload(), + "abc".as_bytes() + ); tunnel }; diff --git a/easytier/src/lib.rs b/easytier/src/lib.rs index 2a380341..bbba5700 100644 --- a/easytier/src/lib.rs +++ b/easytier/src/lib.rs @@ -8,6 +8,7 @@ pub mod instance; pub mod peer_center; pub mod peers; pub mod rpc; +pub mod tunnel; pub mod tunnels; pub mod utils; pub mod vpn_portal; diff --git a/easytier/src/peers/foreign_network_client.rs b/easytier/src/peers/foreign_network_client.rs index 1137ef32..b2763638 100644 --- a/easytier/src/peers/foreign_network_client.rs +++ b/easytier/src/peers/foreign_network_client.rs @@ -4,23 +4,23 @@ use std::{ }; use dashmap::DashMap; -use tokio::{ - sync::{mpsc, Mutex}, - task::JoinSet, -}; -use tokio_util::bytes::Bytes; +use tokio::{sync::Mutex, task::JoinSet}; -use crate::common::{ - error::Error, - global_ctx::{ArcGlobalCtx, NetworkIdentity}, - PeerId, +use crate::{ + common::{ + error::Error, + global_ctx::{ArcGlobalCtx, NetworkIdentity}, + PeerId, + }, + tunnel::packet_def::ZCPacket, }; use super::{ foreign_network_manager::{ForeignNetworkServiceClient, FOREIGN_NETWORK_SERVICE_ID}, - peer_conn::PeerConn, peer_map::PeerMap, peer_rpc::PeerRpcManager, + zc_peer_conn::PeerConn, + PacketRecvChan, }; pub struct ForeignNetworkClient { @@ -37,7 +37,7 @@ pub struct ForeignNetworkClient { impl ForeignNetworkClient { pub fn new( global_ctx: ArcGlobalCtx, - packet_sender_to_mgr: mpsc::Sender, + packet_sender_to_mgr: PacketRecvChan, peer_rpc: Arc, my_peer_id: PeerId, ) -> Self { @@ -148,7 +148,7 @@ impl ForeignNetworkClient { self.next_hop.get(&peer_id).map(|v| v.clone()) } - pub async fn send_msg(&self, msg: Bytes, peer_id: PeerId) -> Result<(), Error> { + pub async fn send_msg(&self, msg: ZCPacket, peer_id: PeerId) -> Result<(), Error> { if let Some(next_hop) = self.get_next_hop(peer_id) { let ret = self.peer_map.send_msg_directly(msg, next_hop).await; if ret.is_err() { diff --git a/easytier/src/peers/foreign_network_manager.rs b/easytier/src/peers/foreign_network_manager.rs index 1d959f58..6ceecf9d 100644 --- a/easytier/src/peers/foreign_network_manager.rs +++ b/easytier/src/peers/foreign_network_manager.rs @@ -15,19 +15,21 @@ use tokio::{ }, task::JoinSet, }; -use tokio_util::bytes::Bytes; -use crate::common::{ - error::Error, - global_ctx::{ArcGlobalCtx, GlobalCtxEvent, NetworkIdentity}, - PeerId, +use crate::{ + common::{ + error::Error, + global_ctx::{ArcGlobalCtx, GlobalCtxEvent, NetworkIdentity}, + PeerId, + }, + tunnel::packet_def::{PacketType, ZCPacket}, }; use super::{ - packet::{self}, - peer_conn::PeerConn, peer_map::PeerMap, peer_rpc::{PeerRpcManager, PeerRpcManagerTransport}, + zc_peer_conn::PeerConn, + PacketRecvChan, PacketRecvChanReceiver, }; struct ForeignNetworkEntry { @@ -38,7 +40,7 @@ struct ForeignNetworkEntry { impl ForeignNetworkEntry { fn new( network: NetworkIdentity, - packet_sender: mpsc::Sender, + packet_sender: PacketRecvChan, global_ctx: ArcGlobalCtx, my_peer_id: PeerId, ) -> Self { @@ -53,7 +55,7 @@ struct ForeignNetworkManagerData { } impl ForeignNetworkManagerData { - async fn send_msg(&self, msg: Bytes, dst_peer_id: PeerId) -> Result<(), Error> { + async fn send_msg(&self, msg: ZCPacket, dst_peer_id: PeerId) -> Result<(), Error> { let network_name = self .peer_network_map .get(&dst_peer_id) @@ -94,7 +96,7 @@ struct RpcTransport { my_peer_id: PeerId, data: Arc, - packet_recv: Mutex>, + packet_recv: Mutex>, } #[async_trait::async_trait] @@ -103,11 +105,11 @@ impl PeerRpcManagerTransport for RpcTransport { self.my_peer_id } - async fn send(&self, msg: Bytes, dst_peer_id: PeerId) -> Result<(), Error> { + async fn send(&self, msg: ZCPacket, dst_peer_id: PeerId) -> Result<(), Error> { self.data.send_msg(msg, dst_peer_id).await } - async fn recv(&self) -> Result { + async fn recv(&self) -> Result { if let Some(o) = self.packet_recv.lock().await.recv().await { Ok(o) } else { @@ -138,14 +140,14 @@ impl ForeignNetworkService for Arc { pub struct ForeignNetworkManager { my_peer_id: PeerId, global_ctx: ArcGlobalCtx, - packet_sender_to_mgr: mpsc::Sender, + packet_sender_to_mgr: PacketRecvChan, - packet_sender: mpsc::Sender, - packet_recv: Mutex>>, + packet_sender: PacketRecvChan, + packet_recv: Mutex>, data: Arc, rpc_mgr: Arc, - rpc_transport_sender: UnboundedSender, + rpc_transport_sender: UnboundedSender, tasks: Mutex>, } @@ -154,7 +156,7 @@ impl ForeignNetworkManager { pub fn new( my_peer_id: PeerId, global_ctx: ArcGlobalCtx, - packet_sender_to_mgr: mpsc::Sender, + packet_sender_to_mgr: PacketRecvChan, ) -> Self { // recv packet from all foreign networks let (packet_sender, packet_recv) = mpsc::channel(1000); @@ -242,12 +244,15 @@ impl ForeignNetworkManager { self.tasks.lock().await.spawn(async move { while let Some(packet_bytes) = recv.recv().await { - let packet = packet::Packet::decode(&packet_bytes); - let from_peer_id = packet.from_peer.into(); - let to_peer_id = packet.to_peer.into(); + let Some(hdr) = packet_bytes.peer_manager_header() else { + tracing::warn!("invalid packet, skip"); + continue; + }; + let from_peer_id = hdr.from_peer_id.get(); + let to_peer_id = hdr.to_peer_id.get(); if to_peer_id == my_node_id { - if packet.packet_type == packet::PacketType::TaRpc { - rpc_sender.send(packet_bytes.clone()).unwrap(); + if hdr.packet_type == PacketType::TaRpc as u8 { + rpc_sender.send(packet_bytes).unwrap(); continue; } if let Err(e) = sender_to_mgr.send(packet_bytes).await { @@ -343,6 +348,27 @@ mod tests { peer_mgr } + #[tokio::test] + async fn foreign_network_basic() { + let pm_center = create_mock_peer_manager_with_mock_stun(crate::rpc::NatType::Unknown).await; + tracing::debug!("pm_center: {:?}", pm_center.my_peer_id()); + + let pma_net1 = create_mock_peer_manager_for_foreign_network("net1").await; + let pmb_net1 = create_mock_peer_manager_for_foreign_network("net1").await; + tracing::debug!( + "pma_net1: {:?}, pmb_net1: {:?}", + pma_net1.my_peer_id(), + pmb_net1.my_peer_id() + ); + connect_peer_manager(pma_net1.clone(), pm_center.clone()).await; + connect_peer_manager(pmb_net1.clone(), pm_center.clone()).await; + wait_route_appear(pma_net1.clone(), pmb_net1.clone()) + .await + .unwrap(); + assert_eq!(1, pma_net1.list_routes().await.len()); + assert_eq!(1, pmb_net1.list_routes().await.len()); + } + #[tokio::test] async fn test_foreign_network_manager() { let pm_center = create_mock_peer_manager_with_mock_stun(crate::rpc::NatType::Unknown).await; @@ -350,11 +376,23 @@ mod tests { create_mock_peer_manager_with_mock_stun(crate::rpc::NatType::Unknown).await; connect_peer_manager(pm_center.clone(), pm_center2.clone()).await; + tracing::debug!( + "pm_center: {:?}, pm_center2: {:?}", + pm_center.my_peer_id(), + pm_center2.my_peer_id() + ); + let pma_net1 = create_mock_peer_manager_for_foreign_network("net1").await; let pmb_net1 = create_mock_peer_manager_for_foreign_network("net1").await; connect_peer_manager(pma_net1.clone(), pm_center.clone()).await; connect_peer_manager(pmb_net1.clone(), pm_center.clone()).await; + tracing::debug!( + "pma_net1: {:?}, pmb_net1: {:?}", + pma_net1.my_peer_id(), + pmb_net1.my_peer_id() + ); + let now = std::time::Instant::now(); let mut succ = false; while now.elapsed().as_secs() < 10 { @@ -399,8 +437,15 @@ mod tests { .unwrap(); assert_eq!(2, pmc_net1.list_routes().await.len()); + tracing::debug!("pmc_net1: {:?}", pmc_net1.my_peer_id()); + let pma_net2 = create_mock_peer_manager_for_foreign_network("net2").await; let pmb_net2 = create_mock_peer_manager_for_foreign_network("net2").await; + tracing::debug!( + "pma_net2: {:?}, pmb_net2: {:?}", + pma_net2.my_peer_id(), + pmb_net2.my_peer_id() + ); connect_peer_manager(pma_net2.clone(), pm_center.clone()).await; connect_peer_manager(pmb_net2.clone(), pm_center.clone()).await; wait_route_appear(pma_net2.clone(), pmb_net2.clone()) diff --git a/easytier/src/peers/mod.rs b/easytier/src/peers/mod.rs index be7b4d83..0cf06a84 100644 --- a/easytier/src/peers/mod.rs +++ b/easytier/src/peers/mod.rs @@ -1,6 +1,7 @@ pub mod packet; pub mod peer; -pub mod peer_conn; +// pub mod peer_conn; +pub mod peer_conn_ping; pub mod peer_manager; pub mod peer_map; pub mod peer_ospf_route; @@ -8,6 +9,7 @@ pub mod peer_rip_route; pub mod peer_rpc; pub mod route_trait; pub mod rpc_service; +pub mod zc_peer_conn; pub mod foreign_network_client; pub mod foreign_network_manager; @@ -15,25 +17,24 @@ pub mod foreign_network_manager; #[cfg(test)] pub mod tests; -use tokio_util::bytes::{Bytes, BytesMut}; +use crate::tunnel::packet_def::ZCPacket; #[async_trait::async_trait] #[auto_impl::auto_impl(Arc)] pub trait PeerPacketFilter { - async fn try_process_packet_from_peer( - &self, - _packet: &packet::ArchivedPacket, - _data: &Bytes, - ) -> Option<()> { - None + async fn try_process_packet_from_peer(&self, _zc_packet: ZCPacket) -> Option { + Some(_zc_packet) } } #[async_trait::async_trait] #[auto_impl::auto_impl(Arc)] pub trait NicPacketFilter { - async fn try_process_packet_from_nic(&self, data: BytesMut) -> BytesMut; + async fn try_process_packet_from_nic(&self, data: &mut ZCPacket); } type BoxPeerPacketFilter = Box; type BoxNicPacketFilter = Box; + +pub type PacketRecvChan = tokio::sync::mpsc::Sender; +pub type PacketRecvChanReceiver = tokio::sync::mpsc::Receiver; diff --git a/easytier/src/peers/peer.rs b/easytier/src/peers/peer.rs index fea4a39d..87d4a6f2 100644 --- a/easytier/src/peers/peer.rs +++ b/easytier/src/peers/peer.rs @@ -7,16 +7,22 @@ use tokio::{ sync::{mpsc, Mutex}, task::JoinHandle, }; -use tokio_util::bytes::Bytes; + use tracing::Instrument; -use super::peer_conn::{PeerConn, PeerConnId}; -use crate::common::{ - error::Error, - global_ctx::{ArcGlobalCtx, GlobalCtxEvent}, - PeerId, +use super::{ + zc_peer_conn::{PeerConn, PeerConnId}, + PacketRecvChan, }; use crate::rpc::PeerConnInfo; +use crate::{ + common::{ + error::Error, + global_ctx::{ArcGlobalCtx, GlobalCtxEvent}, + PeerId, + }, + tunnel::packet_def::ZCPacket, +}; type ArcPeerConn = Arc>; type ConnMap = Arc>; @@ -26,7 +32,7 @@ pub struct Peer { conns: ConnMap, global_ctx: ArcGlobalCtx, - packet_recv_chan: mpsc::Sender, + packet_recv_chan: PacketRecvChan, close_event_sender: mpsc::Sender, close_event_listener: JoinHandle<()>, @@ -37,7 +43,7 @@ pub struct Peer { impl Peer { pub fn new( peer_node_id: PeerId, - packet_recv_chan: mpsc::Sender, + packet_recv_chan: PacketRecvChan, global_ctx: ArcGlobalCtx, ) -> Self { let conns: ConnMap = Arc::new(DashMap::new()); @@ -106,7 +112,7 @@ impl Peer { .insert(conn.get_conn_id(), Arc::new(Mutex::new(conn))); } - pub async fn send_msg(&self, msg: Bytes) -> Result<(), Error> { + pub async fn send_msg(&self, msg: ZCPacket) -> Result<(), Error> { let Some(conn) = self.conns.iter().next() else { return Err(Error::PeerNoConnectionError(self.peer_node_id)); }; @@ -157,8 +163,8 @@ mod tests { use crate::{ common::{global_ctx::tests::get_mock_global_ctx, new_peer_id}, - peers::peer_conn::PeerConn, - tunnels::ring_tunnel::create_ring_tunnel_pair, + peers::zc_peer_conn::PeerConn, + tunnel::ring::create_ring_tunnel_pair, }; use super::Peer; diff --git a/easytier/src/peers/peer_conn_ping.rs b/easytier/src/peers/peer_conn_ping.rs new file mode 100644 index 00000000..e06c2a2f --- /dev/null +++ b/easytier/src/peers/peer_conn_ping.rs @@ -0,0 +1,219 @@ +use std::{ + sync::{ + atomic::{AtomicU32, Ordering}, + Arc, + }, + time::Duration, +}; + +use tokio::{sync::broadcast, task::JoinSet, time::timeout}; + +use crate::{ + common::{error::Error, PeerId}, + tunnel::{ + mpsc::MpscTunnelSender, + packet_def::{PacketType, ZCPacket}, + stats::WindowLatency, + TunnelError, + }, +}; + +pub struct PeerConnPinger { + my_peer_id: PeerId, + peer_id: PeerId, + sink: MpscTunnelSender, + ctrl_sender: broadcast::Sender, + latency_stats: Arc, + loss_rate_stats: Arc, + tasks: JoinSet>, +} + +impl std::fmt::Debug for PeerConnPinger { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("PeerConnPinger") + .field("my_peer_id", &self.my_peer_id) + .field("peer_id", &self.peer_id) + .finish() + } +} + +impl PeerConnPinger { + pub fn new( + my_peer_id: PeerId, + peer_id: PeerId, + sink: MpscTunnelSender, + ctrl_sender: broadcast::Sender, + latency_stats: Arc, + loss_rate_stats: Arc, + ) -> Self { + Self { + my_peer_id, + peer_id, + sink, + tasks: JoinSet::new(), + latency_stats, + ctrl_sender, + loss_rate_stats, + } + } + + fn new_ping_packet(my_node_id: PeerId, peer_id: PeerId, seq: u32) -> ZCPacket { + let mut packet = ZCPacket::new_with_payload(&seq.to_le_bytes()); + packet.fill_peer_manager_hdr(my_node_id, peer_id, PacketType::Ping as u8); + packet + } + + async fn do_pingpong_once( + my_node_id: PeerId, + peer_id: PeerId, + sink: &mut MpscTunnelSender, + receiver: &mut broadcast::Receiver, + seq: u32, + ) -> Result { + // should add seq here. so latency can be calculated more accurately + let req = Self::new_ping_packet(my_node_id, peer_id, seq); + sink.send(req).await?; + + let now = std::time::Instant::now(); + // wait until we get a pong packet in ctrl_resp_receiver + let resp = timeout(Duration::from_secs(1), async { + loop { + match receiver.recv().await { + Ok(p) => { + let payload = p.payload(); + let Ok(seq_buf) = payload[0..4].try_into() else { + tracing::debug!("pingpong recv invalid packet, continue"); + continue; + }; + let resp_seq = u32::from_le_bytes(seq_buf); + if resp_seq == seq { + break; + } + } + Err(e) => { + return Err(Error::WaitRespError(format!( + "wait ping response error: {:?}", + e + ))); + } + } + } + Ok(()) + }) + .await; + + tracing::trace!(?resp, "wait ping response done"); + + if resp.is_err() { + return Err(Error::WaitRespError( + "wait ping response timeout".to_owned(), + )); + } + + if resp.as_ref().unwrap().is_err() { + return Err(resp.unwrap().err().unwrap()); + } + + Ok(now.elapsed().as_micros()) + } + + pub async fn pingpong(&mut self) { + let sink = self.sink.clone(); + let my_node_id = self.my_peer_id; + let peer_id = self.peer_id; + let latency_stats = self.latency_stats.clone(); + + let (ping_res_sender, mut ping_res_receiver) = tokio::sync::mpsc::channel(100); + + let stopped = Arc::new(AtomicU32::new(0)); + + // generate a pingpong task every 200ms + let mut pingpong_tasks = JoinSet::new(); + let ctrl_resp_sender = self.ctrl_sender.clone(); + let stopped_clone = stopped.clone(); + self.tasks.spawn(async move { + let mut req_seq = 0; + loop { + let receiver = ctrl_resp_sender.subscribe(); + let ping_res_sender = ping_res_sender.clone(); + + if stopped_clone.load(Ordering::Relaxed) != 0 { + return Ok(()); + } + + while pingpong_tasks.len() > 5 { + pingpong_tasks.join_next().await; + } + + let mut sink = sink.clone(); + pingpong_tasks.spawn(async move { + let mut receiver = receiver.resubscribe(); + let pingpong_once_ret = Self::do_pingpong_once( + my_node_id, + peer_id, + &mut sink, + &mut receiver, + req_seq, + ) + .await; + + if let Err(e) = ping_res_sender.send(pingpong_once_ret).await { + tracing::info!(?e, "pingpong task send result error, exit.."); + }; + }); + + req_seq = req_seq.wrapping_add(1); + tokio::time::sleep(Duration::from_millis(1000)).await; + } + }); + + // one with 1% precision + let loss_rate_stats_1 = WindowLatency::new(100); + // one with 20% precision, so we can fast fail this conn. + let loss_rate_stats_20 = WindowLatency::new(5); + + let mut counter: u64 = 0; + + while let Some(ret) = ping_res_receiver.recv().await { + counter += 1; + + if let Ok(lat) = ret { + latency_stats.record_latency(lat as u32); + + loss_rate_stats_1.record_latency(0); + loss_rate_stats_20.record_latency(0); + } else { + loss_rate_stats_1.record_latency(1); + loss_rate_stats_20.record_latency(1); + } + + let loss_rate_20: f64 = loss_rate_stats_20.get_latency_us(); + let loss_rate_1: f64 = loss_rate_stats_1.get_latency_us(); + + tracing::trace!( + ?ret, + ?self, + ?loss_rate_1, + ?loss_rate_20, + "pingpong task recv pingpong_once result" + ); + + if (counter > 5 && loss_rate_20 > 0.74) || (counter > 150 && loss_rate_1 > 0.20) { + tracing::warn!( + ?ret, + ?self, + ?loss_rate_1, + ?loss_rate_20, + "pingpong loss rate too high, closing" + ); + break; + } + + self.loss_rate_stats + .store((loss_rate_1 * 100.0) as u32, Ordering::Relaxed); + } + + stopped.store(1, Ordering::Relaxed); + ping_res_receiver.close(); + } +} diff --git a/easytier/src/peers/peer_manager.rs b/easytier/src/peers/peer_manager.rs index c2f30fe6..2a65336e 100644 --- a/easytier/src/peers/peer_manager.rs +++ b/easytier/src/peers/peer_manager.rs @@ -5,6 +5,7 @@ use std::{ }; use async_trait::async_trait; + use futures::StreamExt; use tokio::{ @@ -15,30 +16,30 @@ use tokio::{ task::JoinSet, }; use tokio_stream::wrappers::ReceiverStream; -use tokio_util::bytes::{Bytes, BytesMut}; +use tokio_util::bytes::Bytes; use crate::{ - common::{ - error::Error, global_ctx::ArcGlobalCtx, rkyv_util::extract_bytes_from_archived_string, - PeerId, - }, + common::{error::Error, global_ctx::ArcGlobalCtx, PeerId}, peers::{ - packet, peer_conn::PeerConn, peer_rpc::PeerRpcManagerTransport, - route_trait::RouteInterface, PeerPacketFilter, + packet, peer_rpc::PeerRpcManagerTransport, route_trait::RouteInterface, + zc_peer_conn::PeerConn, PeerPacketFilter, + }, + tunnel::{ + packet_def::{PacketType, ZCPacket}, + SinkItem, Tunnel, TunnelConnector, }, - tunnels::{SinkItem, Tunnel, TunnelConnector}, }; use super::{ foreign_network_client::ForeignNetworkClient, foreign_network_manager::ForeignNetworkManager, - peer_conn::PeerConnId, peer_map::PeerMap, peer_ospf_route::PeerRoute, peer_rip_route::BasicRoute, peer_rpc::PeerRpcManager, route_trait::{ArcRoute, Route}, - BoxNicPacketFilter, BoxPeerPacketFilter, + zc_peer_conn::PeerConnId, + BoxNicPacketFilter, BoxPeerPacketFilter, PacketRecvChanReceiver, }; struct RpcTransport { @@ -46,8 +47,8 @@ struct RpcTransport { peers: Weak, foreign_peers: Mutex>>, - packet_recv: Mutex>, - peer_rpc_tspt_sender: UnboundedSender, + packet_recv: Mutex>, + peer_rpc_tspt_sender: UnboundedSender, } #[async_trait::async_trait] @@ -56,7 +57,7 @@ impl PeerRpcManagerTransport for RpcTransport { self.my_peer_id } - async fn send(&self, msg: Bytes, dst_peer_id: PeerId) -> Result<(), Error> { + async fn send(&self, msg: ZCPacket, dst_peer_id: PeerId) -> Result<(), Error> { let foreign_peers = self .foreign_peers .lock() @@ -67,21 +68,30 @@ impl PeerRpcManagerTransport for RpcTransport { .ok_or(Error::Unknown)?; let peers = self.peers.upgrade().ok_or(Error::Unknown)?; - let ret = peers.send_msg(msg.clone(), dst_peer_id).await; - - if matches!(ret, Err(Error::RouteError(..))) && foreign_peers.has_next_hop(dst_peer_id) { - tracing::info!( + if let Some(gateway_id) = peers.get_gateway_peer_id(dst_peer_id).await { + tracing::trace!( + ?dst_peer_id, + ?gateway_id, + ?self.my_peer_id, + "send msg to peer via gateway", + ); + peers.send_msg_directly(msg, gateway_id).await + } else if foreign_peers.has_next_hop(dst_peer_id) { + tracing::debug!( ?dst_peer_id, ?self.my_peer_id, "failed to send msg to peer, try foreign network", ); - return foreign_peers.send_msg(msg, dst_peer_id).await; + foreign_peers.send_msg(msg, dst_peer_id).await + } else { + Err(Error::RouteError(Some(format!( + "peermgr RpcTransport no route for dst_peer_id: {}", + dst_peer_id + )))) } - - ret } - async fn recv(&self) -> Result { + async fn recv(&self) -> Result { if let Some(o) = self.packet_recv.lock().await.recv().await { Ok(o) } else { @@ -110,7 +120,7 @@ pub struct PeerManager { tasks: Arc>>, - packet_recv: Arc>>>, + packet_recv: Arc>>, peers: Arc, @@ -261,17 +271,20 @@ impl PeerManager { self.tasks.lock().await.spawn(async move { log::trace!("start_peer_recv"); while let Some(ret) = recv.next().await { - log::trace!("peer recv a packet...: {:?}", ret); - let packet = packet::Packet::decode(&ret); - let from_peer_id: PeerId = packet.from_peer.into(); - let to_peer_id: PeerId = packet.to_peer.into(); + let Some(hdr) = ret.peer_manager_header() else { + tracing::warn!(?ret, "invalid packet, skip"); + continue; + }; + tracing::trace!(?hdr, ?ret, "peer recv a packet..."); + let from_peer_id = hdr.from_peer_id.get(); + let to_peer_id = hdr.to_peer_id.get(); if to_peer_id != my_peer_id { log::trace!( "need forward: to_peer_id: {:?}, my_peer_id: {:?}", to_peer_id, my_peer_id ); - let ret = peers.send_msg(ret.clone(), to_peer_id).await; + let ret = peers.send_msg(ret, to_peer_id).await; if ret.is_err() { log::error!( "forward packet error: {:?}, dst: {:?}, from: {:?}", @@ -282,15 +295,21 @@ impl PeerManager { } } else { let mut processed = false; + let mut zc_packet = Some(ret); + let mut idx = 0; for pipeline in pipe_line.read().await.iter().rev() { - if let Some(_) = pipeline.try_process_packet_from_peer(&packet, &ret).await - { + tracing::debug!(?zc_packet, ?idx, "try_process_packet_from_peer"); + idx += 1; + zc_packet = pipeline + .try_process_packet_from_peer(zc_packet.unwrap()) + .await; + if zc_packet.is_none() { processed = true; break; } } if !processed { - tracing::error!("unexpected packet: {:?}", ret); + tracing::error!(?zc_packet, "unhandled packet"); } } } @@ -321,20 +340,15 @@ impl PeerManager { } #[async_trait::async_trait] impl PeerPacketFilter for NicPacketProcessor { - async fn try_process_packet_from_peer( - &self, - packet: &packet::ArchivedPacket, - data: &Bytes, - ) -> Option<()> { - if packet.packet_type == packet::PacketType::Data { + async fn try_process_packet_from_peer(&self, packet: ZCPacket) -> Option { + let hdr = packet.peer_manager_header().unwrap(); + if hdr.packet_type == PacketType::Data as u8 { + tracing::trace!(?packet, "send packet to nic channel"); // TODO: use a function to get the body ref directly for zero copy - self.nic_channel - .send(extract_bytes_from_archived_string(data, &packet.payload)) - .await - .unwrap(); - Some(()) - } else { + self.nic_channel.send(packet).await.unwrap(); None + } else { + Some(packet) } } } @@ -345,21 +359,18 @@ impl PeerManager { // for peer rpc packet struct PeerRpcPacketProcessor { - peer_rpc_tspt_sender: UnboundedSender, + peer_rpc_tspt_sender: UnboundedSender, } #[async_trait::async_trait] impl PeerPacketFilter for PeerRpcPacketProcessor { - async fn try_process_packet_from_peer( - &self, - packet: &packet::ArchivedPacket, - data: &Bytes, - ) -> Option<()> { - if packet.packet_type == packet::PacketType::TaRpc { - self.peer_rpc_tspt_sender.send(data.clone()).unwrap(); - Some(()) - } else { + async fn try_process_packet_from_peer(&self, packet: ZCPacket) -> Option { + let hdr = packet.peer_manager_header().unwrap(); + if hdr.packet_type == PacketType::TaRpc as u8 { + self.peer_rpc_tspt_sender.send(packet).unwrap(); None + } else { + Some(packet) } } } @@ -401,7 +412,7 @@ impl PeerManager { async fn send_route_packet( &self, msg: Bytes, - route_id: u8, + _route_id: u8, dst_peer_id: PeerId, ) -> Result<(), Error> { let foreign_client = self @@ -409,15 +420,17 @@ impl PeerManager { .upgrade() .ok_or(Error::Unknown)?; let peer_map = self.peers.upgrade().ok_or(Error::Unknown)?; - - let packet_bytes: Bytes = - packet::Packet::new_route_packet(self.my_peer_id, dst_peer_id, route_id, &msg) - .into(); + let mut zc_packet = ZCPacket::new_with_payload(&msg); + zc_packet.fill_peer_manager_hdr( + self.my_peer_id, + dst_peer_id, + PacketType::Route as u8, + ); if foreign_client.has_next_hop(dst_peer_id) { - return foreign_client.send_msg(packet_bytes, dst_peer_id).await; + foreign_client.send_msg(zc_packet, dst_peer_id).await + } else { + peer_map.send_msg_directly(zc_packet, dst_peer_id).await } - - peer_map.send_msg_directly(packet_bytes, dst_peer_id).await } fn my_peer_id(&self) -> PeerId { self.my_peer_id @@ -450,18 +463,17 @@ impl PeerManager { self.get_route().list_routes().await } - async fn run_nic_packet_process_pipeline(&self, mut data: BytesMut) -> BytesMut { + async fn run_nic_packet_process_pipeline(&self, data: &mut ZCPacket) { for pipeline in self.nic_packet_process_pipeline.read().await.iter().rev() { - data = pipeline.try_process_packet_from_nic(data).await; + pipeline.try_process_packet_from_nic(data).await; } - data } - pub async fn send_msg(&self, msg: Bytes, dst_peer_id: PeerId) -> Result<(), Error> { + pub async fn send_msg(&self, msg: ZCPacket, dst_peer_id: PeerId) -> Result<(), Error> { self.peers.send_msg(msg, dst_peer_id).await } - pub async fn send_msg_ipv4(&self, msg: BytesMut, ipv4_addr: Ipv4Addr) -> Result<(), Error> { + pub async fn send_msg_ipv4(&self, mut msg: ZCPacket, ipv4_addr: Ipv4Addr) -> Result<(), Error> { log::trace!( "do send_msg in peer manager, msg: {:?}, ipv4_addr: {}", msg, @@ -487,25 +499,34 @@ impl PeerManager { return Ok(()); } - let msg = self.run_nic_packet_process_pipeline(msg).await; + self.run_nic_packet_process_pipeline(&mut msg).await; let mut errs: Vec = vec![]; - for peer_id in dst_peers.iter() { - let msg: Bytes = - packet::Packet::new_data_packet(self.my_peer_id, peer_id.clone(), &msg).into(); - let send_ret = self.peers.send_msg(msg.clone(), *peer_id).await; + let mut msg = Some(msg); + let total_dst_peers = dst_peers.len(); + for i in 0..total_dst_peers { + let mut msg = if i == total_dst_peers - 1 { + msg.take().unwrap() + } else { + msg.clone().unwrap() + }; - if matches!(send_ret, Err(Error::RouteError(..))) - && self.foreign_network_client.has_next_hop(*peer_id) - { - let foreign_send_ret = self.foreign_network_client.send_msg(msg, *peer_id).await; - if foreign_send_ret.is_ok() { - continue; + let peer_id = &dst_peers[i]; + + msg.fill_peer_manager_hdr(self.my_peer_id, *peer_id, packet::PacketType::Data as u8); + + if let Some(gateway) = self.peers.get_gateway_peer_id(*peer_id).await { + if let Err(e) = self.peers.send_msg_directly(msg.clone(), gateway).await { + errs.push(e); + } + } else if self.foreign_network_client.has_next_hop(*peer_id) { + if let Err(e) = self + .foreign_network_client + .send_msg(msg.clone(), *peer_id) + .await + { + errs.push(e); } - } - - if let Err(send_ret) = send_ret { - errs.push(send_ret); } } diff --git a/easytier/src/peers/peer_map.rs b/easytier/src/peers/peer_map.rs index cda479ba..26508ab9 100644 --- a/easytier/src/peers/peer_map.rs +++ b/easytier/src/peers/peer_map.rs @@ -2,8 +2,7 @@ use std::{net::Ipv4Addr, sync::Arc}; use anyhow::Context; use dashmap::DashMap; -use tokio::sync::{mpsc, RwLock}; -use tokio_util::bytes::Bytes; +use tokio::sync::RwLock; use crate::{ common::{ @@ -12,29 +11,27 @@ use crate::{ PeerId, }, rpc::PeerConnInfo, + tunnel::packet_def::ZCPacket, tunnels::TunnelError, }; use super::{ peer::Peer, - peer_conn::{PeerConn, PeerConnId}, route_trait::ArcRoute, + zc_peer_conn::{PeerConn, PeerConnId}, + PacketRecvChan, }; pub struct PeerMap { global_ctx: ArcGlobalCtx, my_peer_id: PeerId, peer_map: DashMap>, - packet_send: mpsc::Sender, + packet_send: PacketRecvChan, routes: RwLock>, } impl PeerMap { - pub fn new( - packet_send: mpsc::Sender, - global_ctx: ArcGlobalCtx, - my_peer_id: PeerId, - ) -> Self { + pub fn new(packet_send: PacketRecvChan, global_ctx: ArcGlobalCtx, my_peer_id: PeerId) -> Self { PeerMap { global_ctx, my_peer_id, @@ -72,7 +69,7 @@ impl PeerMap { self.peer_map.contains_key(&peer_id) } - pub async fn send_msg_directly(&self, msg: Bytes, dst_peer_id: PeerId) -> Result<(), Error> { + pub async fn send_msg_directly(&self, msg: ZCPacket, dst_peer_id: PeerId) -> Result<(), Error> { if dst_peer_id == self.my_peer_id { return Ok(self .packet_send @@ -87,48 +84,53 @@ impl PeerMap { } None => { log::error!("no peer for dst_peer_id: {}", dst_peer_id); - return Err(Error::RouteError(None)); + return Err(Error::RouteError(Some(format!( + "peer map sengmsg directly no connected dst_peer_id: {}", + dst_peer_id + )))); } } Ok(()) } - pub async fn send_msg(&self, msg: Bytes, dst_peer_id: PeerId) -> Result<(), Error> { + pub async fn get_gateway_peer_id(&self, dst_peer_id: PeerId) -> Option { if dst_peer_id == self.my_peer_id { - return Ok(self - .packet_send - .send(msg) - .await - .with_context(|| "send msg to self failed")?); + return Some(dst_peer_id); + } + + if self.has_peer(dst_peer_id) { + return Some(dst_peer_id); } // get route info - let mut gateway_peer_id = None; for route in self.routes.read().await.iter() { - gateway_peer_id = route.get_next_hop(dst_peer_id).await; - if gateway_peer_id.is_none() { - continue; - } else { - break; + if let Some(gateway_peer_id) = route.get_next_hop(dst_peer_id).await { + // for foreign network, gateway_peer_id may not connect to me + if self.has_peer(gateway_peer_id) { + return Some(gateway_peer_id); + } } } - if gateway_peer_id.is_none() && self.has_peer(dst_peer_id) { - gateway_peer_id = Some(dst_peer_id); - } + None + } - let Some(gateway_peer_id) = gateway_peer_id else { + pub async fn send_msg(&self, msg: ZCPacket, dst_peer_id: PeerId) -> Result<(), Error> { + let Some(gateway_peer_id) = self.get_gateway_peer_id(dst_peer_id).await else { tracing::trace!( "no gateway for dst_peer_id: {}, peers: {:?}, my_peer_id: {}", dst_peer_id, self.peer_map.iter().map(|v| *v.key()).collect::>(), self.my_peer_id ); - return Err(Error::RouteError(None)); + return Err(Error::RouteError(Some(format!( + "peer map sengmsg no gateway for dst_peer_id: {}", + dst_peer_id + )))); }; - self.send_msg_directly(msg.clone(), gateway_peer_id).await?; + self.send_msg_directly(msg, gateway_peer_id).await?; return Ok(()); } diff --git a/easytier/src/peers/peer_ospf_route.rs b/easytier/src/peers/peer_ospf_route.rs index 76703904..ac581726 100644 --- a/easytier/src/peers/peer_ospf_route.rs +++ b/easytier/src/peers/peer_ospf_route.rs @@ -203,6 +203,7 @@ struct SyncRouteInfoResponse { trait RouteService { async fn sync_route_info( my_peer_id: PeerId, + my_session_id: SessionId, is_initiator: bool, peer_infos: Option>, conn_bitmap: Option, @@ -547,6 +548,15 @@ impl SyncRouteSession { self.we_are_initiator.store(is_initiator, Ordering::Relaxed); self.need_sync_initiator_info.store(true, Ordering::Relaxed); } + + fn update_dst_session_id(&self, session_id: SessionId) { + if session_id != self.dst_session_id.load(Ordering::Relaxed) { + tracing::warn!(?self, ?session_id, "session id mismatch, clear saved info."); + self.dst_session_id.store(session_id, Ordering::Relaxed); + self.dst_saved_conn_bitmap_version.clear(); + self.dst_saved_peer_info_versions.clear(); + } + } } struct PeerRouteServiceImpl { @@ -794,6 +804,7 @@ impl PeerRouteServiceImpl { .sync_route_info( rpc_ctx, my_peer_id, + session.my_session_id.load(Ordering::Relaxed), session.we_are_initiator.load(Ordering::Relaxed), peer_infos.clone(), conn_bitmap.clone(), @@ -814,19 +825,7 @@ impl PeerRouteServiceImpl { .need_sync_initiator_info .store(false, Ordering::Relaxed); - if ret.session_id != session.dst_session_id.load(Ordering::Relaxed) { - tracing::warn!( - ?ret, - ?my_peer_id, - ?dst_peer_id, - "session id mismatch, clear saved info." - ); - session - .dst_session_id - .store(ret.session_id, Ordering::Relaxed); - session.dst_saved_conn_bitmap_version.clear(); - session.dst_saved_peer_info_versions.clear(); - } + session.update_dst_session_id(ret.session_id); if let Some(peer_infos) = &peer_infos { session.update_dst_saved_peer_info_version(&peer_infos); @@ -864,6 +863,7 @@ impl RouteService for RouteSessionManager { self, _: tarpc::context::Context, from_peer_id: PeerId, + from_session_id: SessionId, is_initiator: bool, peer_infos: Option>, conn_bitmap: Option, @@ -877,6 +877,8 @@ impl RouteService for RouteSessionManager { session.rpc_rx_count.fetch_add(1, Ordering::Relaxed); + session.update_dst_session_id(from_session_id); + if let Some(peer_infos) = &peer_infos { service_impl.synced_route_info.update_peer_infos( my_peer_id, @@ -1383,9 +1385,8 @@ mod tests { let i_a = get_is_initiator(&r_a, p_b.my_peer_id()); let i_b = get_is_initiator(&r_b, p_a.my_peer_id()); - assert_ne!(i_a.0, i_a.1); - assert_ne!(i_b.0, i_b.1); - assert_ne!(i_a.0, i_b.0); + assert_eq!(i_a.0, i_b.1); + assert_eq!(i_b.0, i_a.1); drop(r_b); drop(p_b); diff --git a/easytier/src/peers/peer_rip_route.rs b/easytier/src/peers/peer_rip_route.rs index 26b52079..fda9e5a3 100644 --- a/easytier/src/peers/peer_rip_route.rs +++ b/easytier/src/peers/peer_rip_route.rs @@ -15,14 +15,12 @@ use tracing::Instrument; use crate::{ common::{error::Error, global_ctx::ArcGlobalCtx, stun::StunInfoCollectorTrait, PeerId}, - peers::{ - packet, - route_trait::{Route, RouteInterfaceBox}, - }, + peers::route_trait::{Route, RouteInterfaceBox}, rpc::{NatType, StunInfo}, + tunnel::packet_def::{PacketType, ZCPacket}, }; -use super::{packet::CtrlPacketPayload, PeerPacketFilter}; +use super::PeerPacketFilter; const SEND_ROUTE_PERIOD_SEC: u64 = 60; const SEND_ROUTE_FAST_REPLY_SEC: u64 = 5; @@ -625,26 +623,15 @@ impl Route for BasicRoute { #[async_trait::async_trait] impl PeerPacketFilter for BasicRoute { - async fn try_process_packet_from_peer( - &self, - packet: &packet::ArchivedPacket, - _data: &Bytes, - ) -> Option<()> { - if packet.packet_type == packet::PacketType::RoutePacket { - let CtrlPacketPayload::RoutePacket(route_packet) = - CtrlPacketPayload::from_packet(packet) - else { - return None; - }; - - self.handle_route_packet( - packet.from_peer.into(), - route_packet.body.into_boxed_slice().into(), - ) - .await; - Some(()) - } else { + async fn try_process_packet_from_peer(&self, packet: ZCPacket) -> Option { + let hdr = packet.peer_manager_header().unwrap(); + if hdr.packet_type == PacketType::Route as u8 { + let b = packet.payload().to_vec(); + self.handle_route_packet(hdr.from_peer_id.get(), b.into()) + .await; None + } else { + Some(packet) } } } diff --git a/easytier/src/peers/peer_rpc.rs b/easytier/src/peers/peer_rpc.rs index cfea4cd6..1c8368ad 100644 --- a/easytier/src/peers/peer_rpc.rs +++ b/easytier/src/peers/peer_rpc.rs @@ -2,22 +2,22 @@ use std::sync::{atomic::AtomicU32, Arc}; use dashmap::DashMap; use futures::{SinkExt, StreamExt}; -use rkyv::Deserialize; +use prost::Message; + use tarpc::{server::Channel, transport::channel::UnboundedChannel}; use tokio::{ sync::mpsc::{self, UnboundedSender}, task::JoinSet, }; -use tokio_util::bytes::Bytes; + use tracing::Instrument; use crate::{ common::{error::Error, PeerId}, - peers::packet::Packet, + rpc::TaRpcPacket, + tunnel::packet_def::{PacketType, ZCPacket}, }; -use super::packet::CtrlPacketPayload; - type PeerRpcServiceId = u32; type PeerRpcTransactId = u32; @@ -25,11 +25,11 @@ type PeerRpcTransactId = u32; #[auto_impl::auto_impl(Arc)] pub trait PeerRpcManagerTransport: Send + Sync + 'static { fn my_peer_id(&self) -> PeerId; - async fn send(&self, msg: Bytes, dst_peer_id: PeerId) -> Result<(), Error>; - async fn recv(&self) -> Result; + async fn send(&self, msg: ZCPacket, dst_peer_id: PeerId) -> Result<(), Error>; + async fn recv(&self) -> Result; } -type PacketSender = UnboundedSender; +type PacketSender = UnboundedSender; struct PeerRpcEndPoint { peer_id: PeerId, @@ -63,16 +63,6 @@ impl std::fmt::Debug for PeerRpcManager { } } -#[derive(Debug)] -struct TaRpcPacketInfo { - from_peer: PeerId, - to_peer: PeerId, - service_id: PeerRpcServiceId, - transact_id: PeerRpcTransactId, - is_req: bool, - content: Vec, -} - impl PeerRpcManager { pub fn new(tspt: impl PeerRpcManagerTransport) -> Self { Self { @@ -100,7 +90,7 @@ impl PeerRpcManager { let tspt = self.tspt.clone(); let creator = Box::new(move |peer_id: PeerId| { let mut tasks = JoinSet::new(); - let (packet_sender, mut packet_receiver) = mpsc::unbounded_channel::(); + let (packet_sender, mut packet_receiver) = mpsc::unbounded_channel(); let (mut client_transport, server_transport) = tarpc::transport::channel::unbounded(); let server = tarpc::server::BaseChannel::with_defaults(server_transport); @@ -122,7 +112,7 @@ impl PeerRpcManager { continue; }; - tracing::trace!(resp = ?resp, "recv packet from client"); + tracing::debug!(resp = ?resp, "server recv packet from service provider"); if resp.is_err() { tracing::warn!(err = ?resp.err(), "[PEER RPC MGR] client_transport in server side got channel error, ignore it."); @@ -136,7 +126,7 @@ impl PeerRpcManager { continue; } - let msg = Packet::new_tarpc_packet( + let msg = Self::build_rpc_packet( tspt.my_peer_id(), cur_req_peer_id, service_id, @@ -145,12 +135,13 @@ impl PeerRpcManager { serialized_resp.unwrap(), ); - if let Err(e) = tspt.send(msg.into(), peer_id).await { + if let Err(e) = tspt.send(msg, peer_id).await { tracing::error!(error = ?e, peer_id = ?peer_id, service_id = ?service_id, "send resp to peer failed"); } } Some(packet) = packet_receiver.recv() => { let info = Self::parse_rpc_packet(&packet); + tracing::debug!(?info, "server recv packet from peer"); if let Err(e) = info { tracing::error!(error = ?e, packet = ?packet, "parse rpc packet failed"); continue; @@ -168,7 +159,7 @@ impl PeerRpcManager { } assert_eq!(info.service_id, service_id); - cur_req_peer_id = Some(packet.from_peer.clone().into()); + cur_req_peer_id = Some(info.from_peer); cur_transact_id = info.transact_id; tracing::trace!("recv packet from peer, packet: {:?}", packet); @@ -219,19 +210,33 @@ impl PeerRpcManager { ) } - fn parse_rpc_packet(packet: &Packet) -> Result { - let ctrl_packet_payload = CtrlPacketPayload::from_packet2(&packet); - match &ctrl_packet_payload { - CtrlPacketPayload::TaRpc(id, tid, is_req, body) => Ok(TaRpcPacketInfo { - from_peer: packet.from_peer.into(), - to_peer: packet.to_peer.into(), - service_id: *id, - transact_id: *tid, - is_req: *is_req, - content: body.clone(), - }), - _ => Err(Error::ShellCommandError("invalid packet".to_owned())), - } + fn parse_rpc_packet(packet: &ZCPacket) -> Result { + let payload = packet.payload(); + TaRpcPacket::decode(payload).map_err(|e| Error::MessageDecodeError(e.to_string())) + } + + fn build_rpc_packet( + from_peer: PeerId, + to_peer: PeerId, + service_id: PeerRpcServiceId, + transact_id: PeerRpcTransactId, + is_req: bool, + content: Vec, + ) -> ZCPacket { + let packet = TaRpcPacket { + from_peer, + to_peer, + service_id, + transact_id, + is_req, + content, + }; + let mut buf = Vec::new(); + packet.encode(&mut buf).unwrap(); + + let mut zc_packet = ZCPacket::new_with_payload(&buf); + zc_packet.fill_peer_manager_hdr(from_peer, to_peer, PacketType::TaRpc as u8); + zc_packet } pub fn run(&self) { @@ -245,9 +250,9 @@ impl PeerRpcManager { tracing::warn!("peer rpc transport read aborted, exiting"); break; }; - let packet = Packet::decode(&o); - let packet: Packet = packet.deserialize(&mut rkyv::Infallible).unwrap(); - let info = Self::parse_rpc_packet(&packet).unwrap(); + + let info = Self::parse_rpc_packet(&o).unwrap(); + tracing::debug!(?info, "recv rpc packet from peer"); if info.is_req { if !service_registry.contains_key(&info.service_id) { @@ -265,15 +270,15 @@ impl PeerRpcManager { service_registry.get(&info.service_id).unwrap()(info.from_peer) }); - endpoint.packet_sender.send(packet).unwrap(); + endpoint.packet_sender.send(o).unwrap(); } else { if let Some(a) = client_resp_receivers.get(&PeerRpcClientCtxKey( info.from_peer, info.service_id, info.transact_id, )) { - log::trace!("recv resp: {:?}", packet); - if let Err(e) = a.send(packet) { + tracing::trace!("recv resp: {:?}", info); + if let Err(e) = a.send(o) { tracing::error!(error = ?e, "send resp to client failed"); } } else { @@ -297,7 +302,8 @@ impl PeerRpcManager { Fut: std::future::Future, { let mut tasks = JoinSet::new(); - let (packet_sender, mut packet_receiver) = mpsc::unbounded_channel::(); + let (packet_sender, mut packet_receiver) = mpsc::unbounded_channel(); + let (client_transport, server_transport) = tarpc::transport::channel::unbounded::(); @@ -321,7 +327,7 @@ impl PeerRpcManager { continue; } - let a = Packet::new_tarpc_packet( + let packet = Self::build_rpc_packet( tspt.my_peer_id(), dst_peer_id, service_id, @@ -330,7 +336,9 @@ impl PeerRpcManager { a.unwrap(), ); - if let Err(e) = tspt.send(a.into(), dst_peer_id).await { + tracing::debug!(?packet, "client send rpc packet to peer"); + + if let Err(e) = tspt.send(packet, dst_peer_id).await { tracing::error!(error = ?e, dst_peer_id = ?dst_peer_id, "send to peer failed"); } } @@ -342,11 +350,12 @@ impl PeerRpcManager { while let Some(packet) = packet_receiver.recv().await { tracing::trace!("tunnel recv: {:?}", packet); - let info = PeerRpcManager::parse_rpc_packet(&packet); + let info = Self::parse_rpc_packet(&packet); if let Err(e) = info { tracing::error!(error = ?e, "parse rpc packet failed"); continue; } + tracing::debug!(?info, "client recv rpc packet from peer"); let decoded = postcard::from_bytes(&info.unwrap().content.as_slice()); if let Err(e) = decoded { @@ -381,8 +390,10 @@ impl PeerRpcManager { #[cfg(test)] mod tests { + use std::{pin::Pin, sync::Arc}; + use futures::{SinkExt, StreamExt}; - use tokio_util::bytes::Bytes; + use tokio::sync::Mutex; use crate::{ common::{error::Error, new_peer_id, PeerId}, @@ -390,7 +401,10 @@ mod tests { peer_rpc::PeerRpcManager, tests::{connect_peer_manager, create_mock_peer_manager, wait_route_appear}, }, - tunnels::{self, ring_tunnel::create_ring_tunnel_pair}, + tunnel::{ + packet_def::ZCPacket, ring::create_ring_tunnel_pair, Tunnel, ZCPacketSink, + ZCPacketStream, + }, }; use super::PeerRpcManagerTransport; @@ -415,7 +429,8 @@ mod tests { #[tokio::test] async fn peer_rpc_basic_test() { struct MockTransport { - tunnel: Box, + sink: Arc>>>, + stream: Arc>>>, my_peer_id: PeerId, } @@ -424,22 +439,25 @@ mod tests { fn my_peer_id(&self) -> PeerId { self.my_peer_id } - async fn send(&self, msg: Bytes, _dst_peer_id: PeerId) -> Result<(), Error> { + async fn send(&self, msg: ZCPacket, _dst_peer_id: PeerId) -> Result<(), Error> { println!("rpc mgr send: {:?}", msg); - self.tunnel.pin_sink().send(msg).await.unwrap(); + self.sink.lock().await.send(msg).await.unwrap(); Ok(()) } - async fn recv(&self) -> Result { - let ret = self.tunnel.pin_stream().next().await.unwrap(); + async fn recv(&self) -> Result { + let ret = self.stream.lock().await.next().await.unwrap(); println!("rpc mgr recv: {:?}", ret); - return ret.map(|v| v.freeze()).map_err(|_| Error::Unknown); + return ret.map_err(|e| e.into()); } } let (ct, st) = create_ring_tunnel_pair(); + let (cts, ctsr) = ct.split(); + let (sts, stsr) = st.split(); let server_rpc_mgr = PeerRpcManager::new(MockTransport { - tunnel: st, + sink: Arc::new(Mutex::new(ctsr)), + stream: Arc::new(Mutex::new(cts)), my_peer_id: new_peer_id(), }); server_rpc_mgr.run(); @@ -449,7 +467,8 @@ mod tests { server_rpc_mgr.run_service(1, s.serve()); let client_rpc_mgr = PeerRpcManager::new(MockTransport { - tunnel: ct, + sink: Arc::new(Mutex::new(stsr)), + stream: Arc::new(Mutex::new(sts)), my_peer_id: new_peer_id(), }); client_rpc_mgr.run(); diff --git a/easytier/src/peers/tests.rs b/easytier/src/peers/tests.rs index 3eb0ac84..6cef861a 100644 --- a/easytier/src/peers/tests.rs +++ b/easytier/src/peers/tests.rs @@ -4,7 +4,7 @@ use futures::Future; use crate::{ common::{error::Error, global_ctx::tests::get_mock_global_ctx, PeerId}, - tunnels::ring_tunnel::create_ring_tunnel_pair, + tunnel::ring::create_ring_tunnel_pair, }; use super::peer_manager::{PeerManager, RouteAlgoType}; diff --git a/easytier/src/peers/zc_peer_conn.rs b/easytier/src/peers/zc_peer_conn.rs new file mode 100644 index 00000000..637ebd97 --- /dev/null +++ b/easytier/src/peers/zc_peer_conn.rs @@ -0,0 +1,748 @@ +use std::{ + any::Any, + fmt::Debug, + pin::Pin, + sync::{ + atomic::{AtomicU32, Ordering}, + Arc, + }, +}; + +use futures::{SinkExt, StreamExt, TryFutureExt}; + +use prost::Message; + +use tokio::{ + sync::{broadcast, mpsc}, + task::JoinSet, + time::{timeout, Duration}, +}; + +use tokio_util::sync::PollSender; +use tracing::Instrument; +use zerocopy::AsBytes; + +use crate::{ + common::{ + error::Error, + global_ctx::{ArcGlobalCtx, NetworkIdentity}, + PeerId, + }, + peers::packet::PacketType, + rpc::{HandshakeRequest, PeerConnInfo, PeerConnStats, TunnelInfo}, + tunnel::{ + filter::{StatsRecorderTunnelFilter, TunnelFilter, TunnelWithFilter}, + mpsc::{MpscTunnel, MpscTunnelSender}, + packet_def::ZCPacket, + stats::{Throughput, WindowLatency}, + Tunnel, TunnelError, ZCPacketStream, + }, +}; + +use super::{peer_conn_ping::PeerConnPinger, PacketRecvChan}; + +pub type PeerConnId = uuid::Uuid; + +const MAGIC: u32 = 0xd1e1a5e1; +const VERSION: u32 = 1; + +pub struct PeerConn { + conn_id: PeerConnId, + + my_peer_id: PeerId, + global_ctx: ArcGlobalCtx, + + tunnel: Box, + sink: MpscTunnelSender, + recv: Option>>, + tunnel_info: Option, + + tasks: JoinSet>, + + info: Option, + + close_event_sender: Option>, + + ctrl_resp_sender: broadcast::Sender, + + latency_stats: Arc, + throughput: Arc, + loss_rate_stats: Arc, +} + +impl Debug for PeerConn { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("PeerConn") + .field("conn_id", &self.conn_id) + .field("my_peer_id", &self.my_peer_id) + .field("info", &self.info) + .finish() + } +} + +impl PeerConn { + pub fn new(my_peer_id: PeerId, global_ctx: ArcGlobalCtx, tunnel: Box) -> Self { + let tunnel_info = tunnel.info(); + let (ctrl_sender, _ctrl_receiver) = broadcast::channel(100); + + let peer_conn_tunnel_filter = StatsRecorderTunnelFilter::new(); + let throughput = peer_conn_tunnel_filter.filter_output(); + let peer_conn_tunnel = TunnelWithFilter::new(tunnel, peer_conn_tunnel_filter); + let mut mpsc_tunnel = MpscTunnel::new(peer_conn_tunnel); + + let (recv, sink) = (mpsc_tunnel.get_stream(), mpsc_tunnel.get_sink()); + + PeerConn { + conn_id: PeerConnId::new_v4(), + + my_peer_id, + global_ctx, + + tunnel: Box::new(mpsc_tunnel), + sink, + recv: Some(recv), + tunnel_info, + + tasks: JoinSet::new(), + + info: None, + close_event_sender: None, + + ctrl_resp_sender: ctrl_sender, + + latency_stats: Arc::new(WindowLatency::new(15)), + throughput, + loss_rate_stats: Arc::new(AtomicU32::new(0)), + } + } + + pub fn get_conn_id(&self) -> PeerConnId { + self.conn_id + } + + async fn wait_handshake(&mut self) -> Result { + let recv = self.recv.as_mut().unwrap(); + let Some(rsp) = recv.next().await else { + return Err(Error::WaitRespError( + "conn closed during wait handshake response".to_owned(), + )); + }; + let rsp = rsp?; + let rsp = HandshakeRequest::decode(rsp.payload()) + .map_err(|e| Error::WaitRespError(format!("decode handshake response error: {:?}", e))); + + return Ok(rsp.unwrap()); + } + + async fn wait_handshake_loop(&mut self) -> Result { + Ok(timeout(Duration::from_secs(5), async move { + loop { + match self.wait_handshake().await { + Ok(rsp) => return rsp, + Err(e) => { + log::warn!("wait handshake error: {:?}", e); + } + } + } + }) + .map_err(|e| Error::WaitRespError(format!("wait handshake timeout: {:?}", e))) + .await?) + } + + async fn send_handshake(&mut self) -> Result<(), Error> { + let network = self.global_ctx.get_network_identity(); + let req = HandshakeRequest { + magic: MAGIC, + my_peer_id: self.my_peer_id, + version: VERSION, + features: Vec::new(), + network_name: network.network_name.clone(), + network_secret: network.network_secret.clone(), + }; + + let hs_req = req.encode_to_vec(); + let mut zc_packet = ZCPacket::new_with_payload(hs_req.as_bytes()); + zc_packet.fill_peer_manager_hdr( + self.my_peer_id, + PeerId::default(), + PacketType::HandShake as u8, + ); + + self.sink.send(zc_packet).await.map_err(|e| { + tracing::warn!("send handshake request error: {:?}", e); + Error::WaitRespError("send handshake request error".to_owned()) + })?; + + Ok(()) + } + + #[tracing::instrument] + pub async fn do_handshake_as_server(&mut self) -> Result<(), Error> { + let rsp = self.wait_handshake_loop().await?; + tracing::info!("handshake request: {:?}", rsp); + self.info = Some(rsp); + self.send_handshake().await?; + Ok(()) + } + + #[tracing::instrument] + pub async fn do_handshake_as_client(&mut self) -> Result<(), Error> { + self.send_handshake().await?; + tracing::info!("waiting for handshake request from server"); + let rsp = self.wait_handshake_loop().await?; + tracing::info!("handshake response: {:?}", rsp); + self.info = Some(rsp); + Ok(()) + } + + pub fn handshake_done(&self) -> bool { + self.info.is_some() + } + + pub fn start_recv_loop(&mut self, packet_recv_chan: PacketRecvChan) { + let mut stream = self.recv.take().unwrap(); + let sink = self.sink.clone(); + let mut sender = PollSender::new(packet_recv_chan.clone()); + let close_event_sender = self.close_event_sender.clone().unwrap(); + let conn_id = self.conn_id; + let ctrl_sender = self.ctrl_resp_sender.clone(); + let _conn_info = self.get_conn_info(); + let conn_info_for_instrument = self.get_conn_info(); + + self.tasks.spawn( + async move { + tracing::info!("start recving peer conn packet"); + let mut task_ret = Ok(()); + while let Some(ret) = stream.next().await { + if ret.is_err() { + tracing::error!(error = ?ret, "peer conn recv error"); + task_ret = Err(ret.err().unwrap()); + break; + } + + let mut zc_packet = ret.unwrap(); + let Some(peer_mgr_hdr) = zc_packet.mut_peer_manager_header() else { + tracing::error!( + "unexpected packet: {:?}, cannot decode peer manager hdr", + zc_packet + ); + continue; + }; + + if peer_mgr_hdr.packet_type == PacketType::Ping as u8 { + peer_mgr_hdr.packet_type = PacketType::Pong as u8; + if let Err(e) = sink.send(zc_packet).await { + tracing::error!(?e, "peer conn send req error"); + } + } else if peer_mgr_hdr.packet_type == PacketType::Pong as u8 { + if let Err(e) = ctrl_sender.send(zc_packet) { + tracing::error!(?e, "peer conn send ctrl resp error"); + } + } else { + if sender.send(zc_packet).await.is_err() { + break; + } + } + } + + tracing::info!("end recving peer conn packet"); + + drop(sink); + if let Err(e) = close_event_sender.send(conn_id).await { + tracing::error!(error = ?e, "peer conn close event send error"); + } + + task_ret + } + .instrument( + tracing::info_span!("peer conn recv loop", conn_info = ?conn_info_for_instrument), + ), + ); + } + + pub fn start_pingpong(&mut self) { + let mut pingpong = PeerConnPinger::new( + self.my_peer_id, + self.get_peer_id(), + self.sink.clone(), + self.ctrl_resp_sender.clone(), + self.latency_stats.clone(), + self.loss_rate_stats.clone(), + ); + + let close_event_sender = self.close_event_sender.clone().unwrap(); + let conn_id = self.conn_id; + + self.tasks.spawn(async move { + pingpong.pingpong().await; + + tracing::warn!(?pingpong, "pingpong task exit"); + + if let Err(e) = close_event_sender.send(conn_id).await { + log::warn!("close event sender error: {:?}", e); + } + + Ok(()) + }); + } + + pub async fn send_msg(&mut self, msg: ZCPacket) -> Result<(), Error> { + Ok(self.sink.send(msg).await?) + } + + pub fn get_peer_id(&self) -> PeerId { + self.info.as_ref().unwrap().my_peer_id + } + + pub fn get_network_identity(&self) -> NetworkIdentity { + let info = self.info.as_ref().unwrap(); + NetworkIdentity { + network_name: info.network_name.clone(), + network_secret: info.network_secret.clone(), + } + } + + pub fn set_close_event_sender(&mut self, sender: mpsc::Sender) { + self.close_event_sender = Some(sender); + } + + pub fn get_stats(&self) -> PeerConnStats { + PeerConnStats { + latency_us: self.latency_stats.get_latency_us(), + + tx_bytes: self.throughput.tx_bytes(), + rx_bytes: self.throughput.rx_bytes(), + + tx_packets: self.throughput.tx_packets(), + rx_packets: self.throughput.rx_packets(), + } + } + + pub fn get_conn_info(&self) -> PeerConnInfo { + PeerConnInfo { + conn_id: self.conn_id.to_string(), + my_peer_id: self.my_peer_id, + peer_id: self.get_peer_id(), + features: self.info.as_ref().unwrap().features.clone(), + tunnel: self.tunnel_info.clone(), + stats: Some(self.get_stats()), + loss_rate: (f64::from(self.loss_rate_stats.load(Ordering::Relaxed)) / 100.0) as f32, + } + } +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use super::*; + use crate::common::global_ctx::tests::get_mock_global_ctx; + use crate::common::new_peer_id; + use crate::tunnel::filter::tests::DropSendTunnelFilter; + use crate::tunnel::filter::PacketRecorderTunnelFilter; + use crate::tunnel::ring::create_ring_tunnel_pair; + + #[tokio::test] + async fn peer_conn_handshake() { + let (c, s) = create_ring_tunnel_pair(); + + let c_recorder = Arc::new(PacketRecorderTunnelFilter::new()); + let s_recorder = Arc::new(PacketRecorderTunnelFilter::new()); + + let c = TunnelWithFilter::new(c, c_recorder.clone()); + let s = TunnelWithFilter::new(s, s_recorder.clone()); + + let c_peer_id = new_peer_id(); + let s_peer_id = new_peer_id(); + + let mut c_peer = PeerConn::new(c_peer_id, get_mock_global_ctx(), Box::new(c)); + + let mut s_peer = PeerConn::new(s_peer_id, get_mock_global_ctx(), Box::new(s)); + + let (c_ret, s_ret) = tokio::join!( + c_peer.do_handshake_as_client(), + s_peer.do_handshake_as_server() + ); + + c_ret.unwrap(); + s_ret.unwrap(); + + assert_eq!(c_recorder.sent.lock().unwrap().len(), 1); + assert_eq!(c_recorder.received.lock().unwrap().len(), 1); + + assert_eq!(s_recorder.sent.lock().unwrap().len(), 1); + assert_eq!(s_recorder.received.lock().unwrap().len(), 1); + + assert_eq!(c_peer.get_peer_id(), s_peer_id); + assert_eq!(s_peer.get_peer_id(), c_peer_id); + assert_eq!(c_peer.get_network_identity(), s_peer.get_network_identity()); + assert_eq!(c_peer.get_network_identity(), NetworkIdentity::default()); + } + + async fn peer_conn_pingpong_test_common(drop_start: u32, drop_end: u32, conn_closed: bool) { + let (c, s) = create_ring_tunnel_pair(); + + // drop 1-3 packets should not affect pingpong + let c_recorder = Arc::new(DropSendTunnelFilter::new(drop_start, drop_end)); + let c = TunnelWithFilter::new(c, c_recorder.clone()); + + let c_peer_id = new_peer_id(); + let s_peer_id = new_peer_id(); + + let mut c_peer = PeerConn::new(c_peer_id, get_mock_global_ctx(), Box::new(c)); + let mut s_peer = PeerConn::new(s_peer_id, get_mock_global_ctx(), Box::new(s)); + + let (c_ret, s_ret) = tokio::join!( + c_peer.do_handshake_as_client(), + s_peer.do_handshake_as_server() + ); + + s_peer.set_close_event_sender(tokio::sync::mpsc::channel(1).0); + s_peer.start_recv_loop(tokio::sync::mpsc::channel(200).0); + + assert!(c_ret.is_ok()); + assert!(s_ret.is_ok()); + + let (close_send, mut close_recv) = tokio::sync::mpsc::channel(1); + c_peer.set_close_event_sender(close_send); + c_peer.start_pingpong(); + c_peer.start_recv_loop(tokio::sync::mpsc::channel(200).0); + + // wait 5s, conn should not be disconnected + tokio::time::sleep(Duration::from_secs(15)).await; + + if conn_closed { + assert!(close_recv.try_recv().is_ok()); + } else { + assert!(close_recv.try_recv().is_err()); + } + } + + #[tokio::test] + async fn peer_conn_pingpong_timeout() { + peer_conn_pingpong_test_common(3, 5, false).await; + peer_conn_pingpong_test_common(5, 12, true).await; + } +} + +/* +use std::{ + fmt::Debug, + pin::Pin, + sync::{ + atomic::{AtomicU32, Ordering}, + Arc, + }, +}; + +use futures::{SinkExt, StreamExt}; +use pnet::datalink::NetworkInterface; + +use tokio::{ + sync::{broadcast, mpsc, Mutex}, + task::JoinSet, + time::{timeout, Duration}, +}; + +use tokio_util::{bytes::Bytes, sync::PollSender}; +use tracing::Instrument; + +use crate::{ + common::{ + error::Error, + global_ctx::{ArcGlobalCtx, NetworkIdentity}, + PeerId, + }, + define_tunnel_filter_chain, + peers::packet::{ArchivedPacketType, CtrlPacketPayload, PacketType}, + rpc::{PeerConnInfo, PeerConnStats}, + tunnel::{mpsc::MpscTunnelSender, stats::WindowLatency, TunnelError}, +}; + +use super::packet::{self, HandShake, Packet}; + +pub type PacketRecvChan = mpsc::Sender; + +macro_rules! wait_response { + ($stream: ident, $out_var:ident, $pattern:pat_param => $value:expr) => { + let Ok(rsp_vec) = timeout(Duration::from_secs(1), $stream.next()).await else { + return Err(Error::WaitRespError( + "wait handshake response timeout".to_owned(), + )); + }; + let Some(rsp_vec) = rsp_vec else { + return Err(Error::WaitRespError( + "wait handshake response get none".to_owned(), + )); + }; + let Ok(rsp_vec) = rsp_vec else { + return Err(Error::WaitRespError(format!( + "wait handshake response get error {}", + rsp_vec.err().unwrap() + ))); + }; + + let $out_var; + let rsp_bytes = Packet::decode(&rsp_vec); + if rsp_bytes.packet_type != PacketType::HandShake { + tracing::error!("unexpected packet type: {:?}", rsp_bytes); + return Err(Error::WaitRespError("unexpected packet type".to_owned())); + } + let resp_payload = CtrlPacketPayload::from_packet(&rsp_bytes); + match &resp_payload { + $pattern => $out_var = $value, + _ => { + tracing::error!( + "unexpected packet: {:?}, pattern: {:?}", + rsp_bytes, + stringify!($pattern) + ); + return Err(Error::WaitRespError("unexpected packet".to_owned())); + } + } + }; +} + +impl<'a> From<&HandShake> for PeerInfo { + fn from(hs: &HandShake) -> Self { + PeerInfo { + magic: hs.magic.into(), + my_peer_id: hs.my_peer_id.into(), + version: hs.version.into(), + features: hs.features.iter().map(|x| x.to_string()).collect(), + interfaces: Vec::new(), + network_identity: hs.network_identity.clone(), + } + } +} + + +define_tunnel_filter_chain!(PeerConnTunnel, stats = StatsRecorderTunnelFilter); + +pub struct PeerConn { + conn_id: PeerConnId, + + my_peer_id: PeerId, + global_ctx: ArcGlobalCtx, + + sink: Pin>, + tunnel: Box, + + tasks: JoinSet>, + + info: Option, + + close_event_sender: Option>, + + ctrl_resp_sender: broadcast::Sender, + + latency_stats: Arc, + throughput: Arc, + loss_rate_stats: Arc, +} + +enum PeerConnPacketType { + Data(Bytes), + CtrlReq(Bytes), + CtrlResp(Bytes), +} + +impl PeerConn { + pub fn new(my_peer_id: PeerId, global_ctx: ArcGlobalCtx, tunnel: Box) -> Self { + let (ctrl_sender, _ctrl_receiver) = broadcast::channel(100); + let peer_conn_tunnel = PeerConnTunnel::new(); + let tunnel = peer_conn_tunnel.wrap_tunnel(tunnel); + + PeerConn { + conn_id: PeerConnId::new_v4(), + + my_peer_id, + global_ctx, + + sink: tunnel.pin_sink(), + tunnel: Box::new(tunnel), + + tasks: JoinSet::new(), + + info: None, + close_event_sender: None, + + ctrl_resp_sender: ctrl_sender, + + latency_stats: Arc::new(WindowLatency::new(15)), + throughput: peer_conn_tunnel.stats.get_throughput().clone(), + loss_rate_stats: Arc::new(AtomicU32::new(0)), + } + } + + pub fn get_conn_id(&self) -> PeerConnId { + self.conn_id + } + + #[tracing::instrument] + pub async fn do_handshake_as_server(&mut self) -> Result<(), TunnelError> { + let mut stream = self.tunnel.pin_stream(); + let mut sink = self.tunnel.pin_sink(); + + tracing::info!("waiting for handshake request from client"); + wait_response!(stream, hs_req, CtrlPacketPayload::HandShake(x) => x); + self.info = Some(PeerInfo::from(hs_req)); + tracing::info!("handshake request: {:?}", hs_req); + + let hs_req = self + .global_ctx + .net_ns + .run(|| packet::Packet::new_handshake(self.my_peer_id, &self.global_ctx.network)); + sink.send(hs_req.into()).await?; + + Ok(()) + } + + #[tracing::instrument] + pub async fn do_handshake_as_client(&mut self) -> Result<(), TunnelError> { + let mut stream = self.tunnel.pin_stream(); + let mut sink = self.tunnel.pin_sink(); + + let hs_req = self + .global_ctx + .net_ns + .run(|| packet::Packet::new_handshake(self.my_peer_id, &self.global_ctx.network)); + sink.send(hs_req.into()).await?; + + tracing::info!("waiting for handshake request from server"); + wait_response!(stream, hs_rsp, CtrlPacketPayload::HandShake(x) => x); + self.info = Some(PeerInfo::from(hs_rsp)); + tracing::info!("handshake response: {:?}", hs_rsp); + + Ok(()) + } + + pub fn handshake_done(&self) -> bool { + self.info.is_some() + } + + pub fn start_recv_loop(&mut self, packet_recv_chan: PacketRecvChan) { + let mut stream = self.tunnel.pin_stream(); + let mut sink = self.tunnel.pin_sink(); + let mut sender = PollSender::new(packet_recv_chan.clone()); + let close_event_sender = self.close_event_sender.clone().unwrap(); + let conn_id = self.conn_id; + let ctrl_sender = self.ctrl_resp_sender.clone(); + let conn_info = self.get_conn_info(); + let conn_info_for_instrument = self.get_conn_info(); + + self.tasks.spawn( + async move { + tracing::info!("start recving peer conn packet"); + let mut task_ret = Ok(()); + while let Some(ret) = stream.next().await { + if ret.is_err() { + tracing::error!(error = ?ret, "peer conn recv error"); + task_ret = Err(ret.err().unwrap()); + break; + } + + let buf = ret.unwrap(); + let p = Packet::decode(&buf); + match p.packet_type { + ArchivedPacketType::Ping => { + let CtrlPacketPayload::Ping(seq) = CtrlPacketPayload::from_packet(p) + else { + log::error!("unexpected packet: {:?}", p); + continue; + }; + + let pong = packet::Packet::new_pong_packet( + conn_info.my_peer_id, + conn_info.peer_id, + seq.into(), + ); + + if let Err(e) = sink.send(pong.into()).await { + tracing::error!(?e, "peer conn send req error"); + } + } + ArchivedPacketType::Pong => { + if let Err(e) = ctrl_sender.send(buf.into()) { + tracing::error!(?e, "peer conn send ctrl resp error"); + } + } + _ => { + if sender.send(buf.into()).await.is_err() { + break; + } + } + } + } + + tracing::info!("end recving peer conn packet"); + + if let Err(close_ret) = sink.close().await { + tracing::error!(error = ?close_ret, "peer conn sink close error, ignore it"); + } + if let Err(e) = close_event_sender.send(conn_id).await { + tracing::error!(error = ?e, "peer conn close event send error"); + } + + task_ret + } + .instrument( + tracing::info_span!("peer conn recv loop", conn_info = ?conn_info_for_instrument), + ), + ); + } + + pub async fn send_msg(&mut self, msg: Bytes) -> Result<(), Error> { + self.sink.send(msg).await + } + + pub fn get_peer_id(&self) -> PeerId { + self.info.as_ref().unwrap().my_peer_id + } + + pub fn get_network_identity(&self) -> NetworkIdentity { + self.info.as_ref().unwrap().network_identity.clone() + } + + pub fn set_close_event_sender(&mut self, sender: mpsc::Sender) { + self.close_event_sender = Some(sender); + } + + pub fn get_stats(&self) -> PeerConnStats { + PeerConnStats { + latency_us: self.latency_stats.get_latency_us(), + + tx_bytes: self.throughput.tx_bytes(), + rx_bytes: self.throughput.rx_bytes(), + + tx_packets: self.throughput.tx_packets(), + rx_packets: self.throughput.rx_packets(), + } + } + + pub fn get_conn_info(&self) -> PeerConnInfo { + PeerConnInfo { + conn_id: self.conn_id.to_string(), + my_peer_id: self.my_peer_id, + peer_id: self.get_peer_id(), + features: self.info.as_ref().unwrap().features.clone(), + tunnel: self.tunnel.info(), + stats: Some(self.get_stats()), + loss_rate: (f64::from(self.loss_rate_stats.load(Ordering::Relaxed)) / 100.0) as f32, + } + } +} + +impl Drop for PeerConn { + fn drop(&mut self) { + let mut sink = self.tunnel.pin_sink(); + tokio::spawn(async move { + let ret = sink.close().await; + tracing::info!(error = ?ret, "peer conn tunnel closed."); + }); + log::info!("peer conn {:?} drop", self.conn_id); + } +} + +} + */ diff --git a/easytier/src/tests/mod.rs b/easytier/src/tests/mod.rs index 41c6e8af..4c299c15 100644 --- a/easytier/src/tests/mod.rs +++ b/easytier/src/tests/mod.rs @@ -116,7 +116,7 @@ pub fn add_ns_to_bridge(br_name: &str, ns_name: &str) { pub fn enable_log() { let filter = tracing_subscriber::EnvFilter::builder() - .with_default_directive(tracing::level_filters::LevelFilter::INFO.into()) + .with_default_directive(tracing::level_filters::LevelFilter::TRACE.into()) .from_env() .unwrap() .add_directive("tarpc=error".parse().unwrap()); diff --git a/easytier/src/tests/three_node.rs b/easytier/src/tests/three_node.rs index acf91cfc..d85db8f0 100644 --- a/easytier/src/tests/three_node.rs +++ b/easytier/src/tests/three_node.rs @@ -9,18 +9,18 @@ use super::*; use crate::{ common::{ - config::{ConfigLoader, NetworkIdentity, TomlConfigLoader}, + config::{ConfigLoader, NetworkIdentity, TomlConfigLoader, VpnPortalConfig}, netns::{NetNS, ROOT_NETNS_NAME}, }, instance::instance::Instance, peers::tests::wait_for_condition, - tunnels::{ - common::tests::_tunnel_pingpong_netns, - ring_tunnel::RingTunnelConnector, - tcp_tunnel::{TcpTunnelConnector, TcpTunnelListener}, - udp_tunnel::{UdpTunnelConnector, UdpTunnelListener}, + tunnel::{ + ring::RingTunnelConnector, + tcp::TcpTunnelConnector, + udp::UdpTunnelConnector, wireguard::{WgConfig, WgTunnelConnector}, }, + vpn_portal::wireguard::get_wg_config_for_portal, }; pub fn prepare_linux_namespaces() { @@ -113,6 +113,26 @@ pub async fn init_three_node(proto: &str) -> Vec { vec![inst1, inst2, inst3] } +async fn ping_test(from_netns: &str, target_ip: &str) -> bool { + let _g = NetNS::new(Some(ROOT_NETNS_NAME.to_owned())).guard(); + let code = tokio::process::Command::new("ip") + .args(&[ + "netns", + "exec", + from_netns, + "ping", + "-c", + "1", + "-W", + "1", + target_ip.to_string().as_str(), + ]) + .status() + .await + .unwrap(); + code.code().unwrap() == 0 +} + #[rstest::rstest] #[tokio::test] #[serial_test::serial] @@ -130,12 +150,20 @@ pub async fn basic_three_node_test(#[values("tcp", "udp", "wg")] proto: &str) { insts[2].peer_id(), insts[0].get_peer_manager().list_routes().await, ); + + wait_for_condition( + || async { ping_test("net_c", "10.144.144.1").await }, + Duration::from_secs(5), + ) + .await; } #[rstest::rstest] #[tokio::test] #[serial_test::serial] pub async fn tcp_proxy_three_node_test(#[values("tcp", "udp", "wg")] proto: &str) { + use crate::tunnel::{common::tests::_tunnel_pingpong_netns, tcp::TcpTunnelListener}; + let insts = init_three_node(proto).await; insts[2] @@ -187,25 +215,19 @@ pub async fn icmp_proxy_three_node_test(#[values("tcp", "udp", "wg")] proto: &st ) .await; - // wait updater - tokio::time::sleep(tokio::time::Duration::from_secs(6)).await; - - // send ping with shell in net_a to net_d - let _g = NetNS::new(Some(ROOT_NETNS_NAME.to_owned())).guard(); - let code = tokio::process::Command::new("ip") - .args(&[ - "netns", "exec", "net_a", "ping", "-c", "1", "-W", "1", "10.1.2.4", - ]) - .status() - .await - .unwrap(); - assert_eq!(code.code().unwrap(), 0); + wait_for_condition( + || async { ping_test("net_a", "10.1.2.4").await }, + Duration::from_secs(5), + ) + .await; } #[rstest::rstest] #[tokio::test] #[serial_test::serial] pub async fn proxy_three_node_disconnect_test(#[values("tcp", "wg")] proto: &str) { + use crate::tunnel::wireguard::{WgConfig, WgTunnelConnector}; + let insts = init_three_node(proto).await; let mut inst4 = Instance::new(get_inst_config("inst4", Some("net_d"), "10.144.144.4")); if proto == "tcp" { @@ -266,6 +288,8 @@ pub async fn proxy_three_node_disconnect_test(#[values("tcp", "wg")] proto: &str #[tokio::test] #[serial_test::serial] pub async fn udp_proxy_three_node_test(#[values("tcp", "udp", "wg")] proto: &str) { + use crate::tunnel::{common::tests::_tunnel_pingpong_netns, udp::UdpTunnelListener}; + let insts = init_three_node(proto).await; insts[2] @@ -389,21 +413,108 @@ pub async fn foreign_network_forward_nic_data() { ) .await; - let _g = NetNS::new(Some(ROOT_NETNS_NAME.to_owned())).guard(); - let code = tokio::process::Command::new("ip") - .args(&[ - "netns", - "exec", - "net_b", - "ping", - "-c", - "1", - "-W", - "1", - "10.144.145.2", - ]) - .status() - .await - .unwrap(); - assert_eq!(code.code().unwrap(), 0); + wait_for_condition( + || async { ping_test("net_b", "10.144.145.2").await }, + Duration::from_secs(5), + ) + .await; +} + +use std::{net::SocketAddr, str::FromStr}; + +use defguard_wireguard_rs::{ + host::Peer, key::Key, net::IpAddrMask, InterfaceConfiguration, WGApi, WireguardInterfaceApi, +}; + +fn run_wireguard_client( + endpoint: SocketAddr, + peer_public_key: Key, + client_private_key: Key, + allowed_ips: Vec, + client_ip: String, +) -> Result<(), Box> { + // Create new API object for interface + let ifname: String = if cfg!(target_os = "linux") || cfg!(target_os = "freebsd") { + "wg0".into() + } else { + "utun3".into() + }; + let wgapi = WGApi::new(ifname.clone(), false)?; + + // create interface + wgapi.create_interface()?; + + // Peer secret key + let mut peer = Peer::new(peer_public_key.clone()); + + log::info!("endpoint"); + // Peer endpoint and interval + peer.endpoint = Some(endpoint); + peer.persistent_keepalive_interval = Some(25); + for ip in allowed_ips { + peer.allowed_ips.push(IpAddrMask::from_str(ip.as_str())?); + } + + // interface configuration + let interface_config = InterfaceConfiguration { + name: ifname.clone(), + prvkey: client_private_key.to_string(), + address: client_ip, + port: 12345, + peers: vec![peer], + }; + + #[cfg(not(windows))] + wgapi.configure_interface(&interface_config)?; + #[cfg(windows)] + wgapi.configure_interface(&interface_config, &[])?; + wgapi.configure_peer_routing(&interface_config.peers)?; + Ok(()) +} + +#[tokio::test] +#[serial_test::serial] +pub async fn wireguard_vpn_portal() { + let mut insts = init_three_node("tcp").await; + let net_ns = NetNS::new(Some("net_d".into())); + let _g = net_ns.guard(); + insts[2] + .get_global_ctx() + .config + .set_vpn_portal_config(VpnPortalConfig { + wireguard_listen: "0.0.0.0:22121".parse().unwrap(), + client_cidr: "10.14.14.0/24".parse().unwrap(), + }); + insts[2].run_vpn_portal().await.unwrap(); + + let net_ns = NetNS::new(Some("net_d".into())); + let _g = net_ns.guard(); + let wg_cfg = get_wg_config_for_portal(&insts[2].get_global_ctx().get_network_identity()); + run_wireguard_client( + "10.1.2.3:22121".parse().unwrap(), + Key::try_from(wg_cfg.my_public_key()).unwrap(), + Key::try_from(wg_cfg.peer_secret_key()).unwrap(), + vec!["10.14.14.0/24".to_string(), "10.144.144.0/24".to_string()], + "10.14.14.2".to_string(), + ) + .unwrap(); + + // ping other node in network + wait_for_condition( + || async { ping_test("net_d", "10.144.144.1").await }, + Duration::from_secs(5), + ) + .await; + wait_for_condition( + || async { ping_test("net_d", "10.144.144.2").await }, + Duration::from_secs(5), + ) + .await; + + // ping portal node + wait_for_condition( + || async { ping_test("net_d", "10.144.144.3").await }, + Duration::from_secs(500), + ) + .await; } diff --git a/easytier/src/tunnel/buf.rs b/easytier/src/tunnel/buf.rs new file mode 100644 index 00000000..01363096 --- /dev/null +++ b/easytier/src/tunnel/buf.rs @@ -0,0 +1,92 @@ +use std::collections::VecDeque; +use std::io::IoSlice; + +use bytes::{Buf, BufMut, Bytes, BytesMut}; + +pub(crate) struct BufList { + bufs: VecDeque, +} + +impl BufList { + pub(crate) fn new() -> BufList { + BufList { + bufs: VecDeque::new(), + } + } + + #[inline] + pub(crate) fn push(&mut self, buf: T) { + debug_assert!(buf.has_remaining()); + self.bufs.push_back(buf); + } + + #[inline] + pub(crate) fn bufs_cnt(&self) -> usize { + self.bufs.len() + } +} + +impl Buf for BufList { + #[inline] + fn remaining(&self) -> usize { + self.bufs.iter().map(|buf| buf.remaining()).sum() + } + + #[inline] + fn chunk(&self) -> &[u8] { + self.bufs.front().map(Buf::chunk).unwrap_or_default() + } + + #[inline] + fn advance(&mut self, mut cnt: usize) { + while cnt > 0 { + { + let front = &mut self.bufs[0]; + let rem = front.remaining(); + if rem > cnt { + front.advance(cnt); + return; + } else { + front.advance(rem); + cnt -= rem; + } + } + self.bufs.pop_front(); + } + } + + #[inline] + fn chunks_vectored<'t>(&'t self, dst: &mut [IoSlice<'t>]) -> usize { + if dst.is_empty() { + return 0; + } + let mut vecs = 0; + for buf in &self.bufs { + vecs += buf.chunks_vectored(&mut dst[vecs..]); + if vecs == dst.len() { + break; + } + } + vecs + } + + #[inline] + fn copy_to_bytes(&mut self, len: usize) -> Bytes { + // Our inner buffer may have an optimized version of copy_to_bytes, and if the whole + // request can be fulfilled by the front buffer, we can take advantage. + match self.bufs.front_mut() { + Some(front) if front.remaining() == len => { + let b = front.copy_to_bytes(len); + self.bufs.pop_front(); + b + } + Some(front) if front.remaining() > len => front.copy_to_bytes(len), + _ => { + assert!(len <= self.remaining(), "`len` greater than remaining"); + let mut bm = BytesMut::with_capacity(len); + bm.put(self.take(len)); + bm.freeze() + } + } + } +} diff --git a/easytier/src/tunnel/common.rs b/easytier/src/tunnel/common.rs new file mode 100644 index 00000000..2e56a716 --- /dev/null +++ b/easytier/src/tunnel/common.rs @@ -0,0 +1,539 @@ +use std::{ + any::Any, + net::{IpAddr, SocketAddr}, + pin::Pin, + sync::{Arc, Mutex}, + task::{ready, Poll}, +}; + +use futures::{stream::FuturesUnordered, Future, Sink, Stream}; +use network_interface::NetworkInterfaceConfig as _; +use pin_project_lite::pin_project; +use tokio::io::{AsyncRead, AsyncWrite}; + +use bytes::{Buf, Bytes, BytesMut}; +use tokio_stream::StreamExt; +use tokio_util::io::{poll_read_buf, poll_write_buf}; +use zerocopy::FromBytes as _; + +use crate::{ + rpc::TunnelInfo, + tunnel::packet_def::{ZCPacket, PEER_MANAGER_HEADER_SIZE}, +}; + +use super::{ + buf::BufList, + packet_def::{TCPTunnelHeader, ZCPacketType, TCP_TUNNEL_HEADER_SIZE}, + SinkItem, StreamItem, Tunnel, TunnelError, ZCPacketSink, ZCPacketStream, +}; + +pub struct TunnelWrapper { + reader: Arc>>, + writer: Arc>>, + info: Option, + associate_data: Option>, +} + +impl TunnelWrapper { + pub fn new(reader: R, writer: W, info: Option) -> Self { + Self::new_with_associate_data(reader, writer, info, None) + } + + pub fn new_with_associate_data( + reader: R, + writer: W, + info: Option, + associate_data: Option>, + ) -> Self { + TunnelWrapper { + reader: Arc::new(Mutex::new(Some(reader))), + writer: Arc::new(Mutex::new(Some(writer))), + info, + associate_data, + } + } +} + +impl Tunnel for TunnelWrapper +where + R: ZCPacketStream + Send + 'static, + W: ZCPacketSink + Send + 'static, +{ + fn split(&self) -> (Pin>, Pin>) { + let reader = self.reader.lock().unwrap().take().unwrap(); + let writer = self.writer.lock().unwrap().take().unwrap(); + (Box::pin(reader), Box::pin(writer)) + } + + fn info(&self) -> Option { + self.info.clone() + } +} + +// a length delimited codec for async reader +pin_project! { + pub struct FramedReader { + #[pin] + reader: R, + buf: BytesMut, + state: FrameReaderState, + max_packet_size: usize, + associate_data: Option>, + } +} + +// usize means the size remaining to read +enum FrameReaderState { + ReadingHeader(usize), + ReadingBody(usize), +} + +impl FramedReader { + pub fn new(reader: R, max_packet_size: usize) -> Self { + Self::new_with_associate_data(reader, max_packet_size, None) + } + + pub fn new_with_associate_data( + reader: R, + max_packet_size: usize, + associate_data: Option>, + ) -> Self { + FramedReader { + reader, + buf: BytesMut::with_capacity(max_packet_size), + state: FrameReaderState::ReadingHeader(4), + max_packet_size, + associate_data, + } + } + + fn extract_one_packet(buf: &mut BytesMut) -> Option { + if buf.len() < TCP_TUNNEL_HEADER_SIZE { + // header is not complete + return None; + } + + let header = TCPTunnelHeader::ref_from_prefix(&buf[..]).unwrap(); + let body_len = header.len.get() as usize; + if buf.len() < TCP_TUNNEL_HEADER_SIZE + body_len { + // body is not complete + return None; + } + + // extract one packet + let packet_buf = buf.split_to(TCP_TUNNEL_HEADER_SIZE + body_len); + Some(ZCPacket::new_from_buf(packet_buf, ZCPacketType::TCP)) + } +} + +impl Stream for FramedReader +where + R: AsyncRead + Send + 'static + Unpin, +{ + type Item = StreamItem; + + fn poll_next( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + let mut self_mut = self.project(); + + loop { + while let Some(packet) = Self::extract_one_packet(self_mut.buf) { + return Poll::Ready(Some(Ok(packet))); + } + + reserve_buf( + &mut self_mut.buf, + *self_mut.max_packet_size, + *self_mut.max_packet_size * 64, + ); + + match ready!(poll_read_buf( + self_mut.reader.as_mut(), + cx, + &mut self_mut.buf + )) { + Ok(size) => { + if size == 0 { + return Poll::Ready(None); + } + } + Err(e) => { + return Poll::Ready(Some(Err(TunnelError::IOError(e)))); + } + } + } + } +} + +pin_project! { + pub struct FramedWriter { + #[pin] + writer: W, + sending_bufs: BufList, + associate_data: Option>, + } +} + +impl FramedWriter { + pub fn new(writer: W) -> Self { + Self::new_with_associate_data(writer, None) + } + + pub fn new_with_associate_data( + writer: W, + associate_data: Option>, + ) -> Self { + FramedWriter { + writer, + sending_bufs: BufList::new(), + associate_data: associate_data, + } + } + + fn max_buffer_count(&self) -> usize { + 64 + } +} + +impl Sink for FramedWriter +where + W: AsyncWrite + Send + 'static, +{ + type Error = TunnelError; + + fn poll_ready( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + let max_buffer_count = self.max_buffer_count(); + if self.sending_bufs.bufs_cnt() >= max_buffer_count { + self.as_mut().poll_flush(cx) + } else { + tracing::trace!(bufs_cnt = self.sending_bufs.bufs_cnt(), "ready to send"); + Poll::Ready(Ok(())) + } + } + + fn start_send(self: Pin<&mut Self>, mut item: ZCPacket) -> Result<(), Self::Error> { + let tcp_len = PEER_MANAGER_HEADER_SIZE + item.payload_len(); + let Some(header) = item.mut_tcp_tunnel_header() else { + return Err(TunnelError::InvalidPacket("packet too short".to_string())); + }; + header.len.set(tcp_len.try_into().unwrap()); + + let item = item.into_bytes(ZCPacketType::TCP); + self.project().sending_bufs.push(item); + + Ok(()) + } + + fn poll_flush( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + let mut pinned = self.project(); + let mut remaining = pinned.sending_bufs.remaining(); + while remaining != 0 { + let n = ready!(poll_write_buf( + pinned.writer.as_mut(), + cx, + pinned.sending_bufs + ))?; + if n == 0 { + return Poll::Ready(Err(TunnelError::IOError(std::io::Error::new( + std::io::ErrorKind::WriteZero, + "failed to \ + write frame to transport", + )))); + } + remaining -= n; + } + + tracing::trace!(?remaining, "flushed"); + + // Try flushing the underlying IO + ready!(pinned.writer.poll_flush(cx))?; + + Poll::Ready(Ok(())) + } + + fn poll_close( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + ready!(self.as_mut().poll_flush(cx))?; + ready!(self.project().writer.poll_shutdown(cx))?; + + Poll::Ready(Ok(())) + } +} + +pub(crate) fn get_interface_name_by_ip(local_ip: &IpAddr) -> Option { + if local_ip.is_unspecified() || local_ip.is_multicast() { + return None; + } + let ifaces = network_interface::NetworkInterface::show().ok()?; + for iface in ifaces { + for addr in iface.addr { + if addr.ip() == *local_ip { + return Some(iface.name); + } + } + } + + tracing::error!(?local_ip, "can not find interface name by ip"); + None +} + +pub(crate) fn setup_sokcet2_ext( + socket2_socket: &socket2::Socket, + bind_addr: &SocketAddr, + bind_dev: Option, +) -> Result<(), TunnelError> { + #[cfg(target_os = "windows")] + { + let is_udp = matches!(socket2_socket.r#type()?, socket2::Type::DGRAM); + crate::arch::windows::setup_socket_for_win(socket2_socket, bind_addr, bind_dev, is_udp)?; + } + + socket2_socket.set_nonblocking(true)?; + socket2_socket.set_reuse_address(true)?; + socket2_socket.bind(&socket2::SockAddr::from(*bind_addr))?; + + // #[cfg(all(unix, not(target_os = "solaris"), not(target_os = "illumos")))] + // socket2_socket.set_reuse_port(true)?; + + if bind_addr.ip().is_unspecified() { + return Ok(()); + } + + // linux/mac does not use interface of bind_addr to send packet, so we need to bind device + // win can handle this with bind correctly + #[cfg(any(target_os = "ios", target_os = "macos"))] + if let Some(dev_name) = bind_dev { + // use IP_BOUND_IF to bind device + unsafe { + let dev_idx = nix::libc::if_nametoindex(dev_name.as_str().as_ptr() as *const i8); + tracing::warn!(?dev_idx, ?dev_name, "bind device"); + socket2_socket.bind_device_by_index_v4(std::num::NonZeroU32::new(dev_idx))?; + tracing::warn!(?dev_idx, ?dev_name, "bind device doen"); + } + } + + #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))] + if let Some(dev_name) = bind_dev { + tracing::trace!(dev_name = ?dev_name, "bind device"); + socket2_socket.bind_device(Some(dev_name.as_bytes()))?; + } + + Ok(()) +} + +pub(crate) async fn wait_for_connect_futures( + mut futures: FuturesUnordered, +) -> Result +where + Fut: Future> + Send + Sync, + E: std::error::Error + Into + Send + Sync + 'static, +{ + // return last error + let mut last_err = None; + + while let Some(ret) = futures.next().await { + if let Err(e) = ret { + last_err = Some(e.into()); + } else { + return ret.map_err(|e| e.into()); + } + } + + Err(last_err.unwrap_or(TunnelError::Shutdown)) +} + +pub(crate) fn setup_sokcet2( + socket2_socket: &socket2::Socket, + bind_addr: &SocketAddr, +) -> Result<(), TunnelError> { + setup_sokcet2_ext( + socket2_socket, + bind_addr, + super::common::get_interface_name_by_ip(&bind_addr.ip()), + ) +} + +pub fn reserve_buf(buf: &mut BytesMut, min_size: usize, max_size: usize) { + if buf.capacity() < min_size { + buf.reserve(max_size); + } +} + +pub mod tests { + use std::time::Instant; + + use futures::{SinkExt, StreamExt, TryStreamExt}; + use tokio_util::bytes::{BufMut, Bytes, BytesMut}; + + use crate::{ + common::netns::NetNS, + tunnel::{packet_def::ZCPacket, TunnelConnector, TunnelListener}, + }; + + pub async fn _tunnel_echo_server(tunnel: Box, once: bool) { + let (mut recv, mut send) = tunnel.split(); + + if !once { + recv.forward(send).await.unwrap(); + } else { + let Some(ret) = recv.next().await else { + assert!(false, "recv error"); + return; + }; + + if ret.is_err() { + tracing::debug!(?ret, "recv error"); + return; + } + + let res = ret.unwrap(); + tracing::debug!(?res, "recv a msg, try echo back"); + send.send(res).await.unwrap(); + } + + tracing::warn!("echo server exit..."); + } + + pub(crate) async fn _tunnel_pingpong(listener: L, connector: C) + where + L: TunnelListener + Send + Sync + 'static, + C: TunnelConnector + Send + Sync + 'static, + { + _tunnel_pingpong_netns(listener, connector, NetNS::new(None), NetNS::new(None)).await + } + + pub(crate) async fn _tunnel_pingpong_netns( + mut listener: L, + mut connector: C, + l_netns: NetNS, + c_netns: NetNS, + ) where + L: TunnelListener + Send + Sync + 'static, + C: TunnelConnector + Send + Sync + 'static, + { + l_netns + .run_async(|| async { + listener.listen().await.unwrap(); + }) + .await; + + let lis = tokio::spawn(async move { + let ret = listener.accept().await.unwrap(); + assert_eq!( + ret.info().unwrap().local_addr, + listener.local_url().to_string() + ); + _tunnel_echo_server(ret, false).await + }); + + let tunnel = c_netns.run_async(|| connector.connect()).await.unwrap(); + + assert_eq!( + tunnel.info().unwrap().remote_addr, + connector.remote_url().to_string() + ); + + let (mut recv, mut send) = tunnel.split(); + + send.send(ZCPacket::new_with_payload("12345678abcdefg".as_bytes())) + .await + .unwrap(); + + let ret = tokio::time::timeout(tokio::time::Duration::from_secs(1), recv.next()) + .await + .unwrap() + .unwrap() + .unwrap(); + println!("echo back: {:?}", ret); + assert_eq!(ret.payload(), Bytes::from("12345678abcdefg")); + + drop(send); + + if ["udp", "wg"].contains(&connector.remote_url().scheme()) { + lis.abort(); + } else { + // lis should finish in 1 second + let ret = tokio::time::timeout(tokio::time::Duration::from_secs(1), lis).await; + assert!(ret.is_ok()); + } + } + + pub(crate) async fn _tunnel_bench(mut listener: L, mut connector: C) + where + L: TunnelListener + Send + Sync + 'static, + C: TunnelConnector + Send + Sync + 'static, + { + listener.listen().await.unwrap(); + + let lis = tokio::spawn(async move { + let ret = listener.accept().await.unwrap(); + _tunnel_echo_server(ret, false).await + }); + + let tunnel = connector.connect().await.unwrap(); + + let (recv, mut send) = tunnel.split(); + + // prepare a 4k buffer with random data + let mut send_buf = BytesMut::new(); + for _ in 0..64 { + send_buf.put_i128(rand::random::()); + } + + let r = tokio::spawn(async move { + let now = Instant::now(); + let count = recv + .try_fold(0usize, |mut ret, _| async move { + ret += 1; + Ok(ret) + }) + .await + .unwrap(); + + println!( + "bps: {}", + (count / 1024) * 4 / now.elapsed().as_secs() as usize + ); + }); + + let now = Instant::now(); + while now.elapsed().as_secs() < 10 { + // send.feed(item) + let item = ZCPacket::new_with_payload(send_buf.as_ref()); + let _ = send.feed(item).await.unwrap(); + } + + drop(send); + drop(connector); + drop(tunnel); + + tracing::warn!("wait for recv to finish..."); + + let _ = tokio::join!(r); + + lis.abort(); + let _ = tokio::join!(lis); + } + + pub fn enable_log() { + let filter = tracing_subscriber::EnvFilter::builder() + .with_default_directive(tracing::level_filters::LevelFilter::TRACE.into()) + .from_env() + .unwrap() + .add_directive("tarpc=error".parse().unwrap()); + tracing_subscriber::fmt::fmt() + .pretty() + .with_env_filter(filter) + .init(); + } +} diff --git a/easytier/src/tunnel/filter.rs b/easytier/src/tunnel/filter.rs new file mode 100644 index 00000000..20c21033 --- /dev/null +++ b/easytier/src/tunnel/filter.rs @@ -0,0 +1,362 @@ +use std::{ + sync::Arc, + task::{Context, Poll}, +}; + +use crate::rpc::TunnelInfo; +use auto_impl::auto_impl; +use futures::{Sink, SinkExt, Stream, StreamExt}; + +use self::stats::Throughput; + +use super::*; + +#[auto_impl(Arc, Box)] +pub trait TunnelFilter: Send + Sync { + type FilterOutput; + + fn before_send(&self, data: SinkItem) -> Option { + Some(data) + } + + fn after_received(&self, data: StreamItem) -> Option { + match data { + Ok(v) => Some(Ok(v)), + Err(e) => Some(Err(e)), + } + } + + fn filter_output(&self) -> Self::FilterOutput; +} + +pub struct TunnelFilterChain { + a: A, + b: B, +} + +impl TunnelFilter for TunnelFilterChain +where + A: TunnelFilter, + B: TunnelFilter, +{ + type FilterOutput = (OA, OB); + fn before_send(&self, data: SinkItem) -> Option { + let data = self.a.before_send(data)?; + self.b.before_send(data) + } + fn after_received(&self, data: StreamItem) -> Option { + let data = self.b.after_received(data)?; + self.a.after_received(data) + } + fn filter_output(&self) -> Self::FilterOutput { + (self.a.filter_output(), self.b.filter_output()) + } +} + +impl TunnelFilterChain { + pub fn new(a: A, b: B) -> Self { + Self { a, b } + } + + pub fn chain(self, c: T) -> TunnelFilterChain { + TunnelFilterChain::new(self, c) + } +} + +pub struct EmptyFilter; +impl TunnelFilter for EmptyFilter { + type FilterOutput = (); + fn filter_output(&self) {} +} + +pub trait ToTunnelChain { + fn to_chain(self) -> TunnelFilterChain + where + Self: Sized, + { + TunnelFilterChain::new(EmptyFilter, self) + } +} + +impl> ToTunnelChain for T {} + +pub struct TunnelWithFilter { + inner: T, + filter: Arc, +} + +impl TunnelWithFilter +where + T: Tunnel + Send + 'static, + F: TunnelFilter + Send + 'static, +{ + pub fn new(inner: T, filter: F) -> Self { + Self { + inner, + filter: Arc::new(filter), + } + } + + fn wrap_sink(&self, sink: S) -> impl ZCPacketSink { + struct SinkWrapper { + sink: S, + filter: Arc, + } + + impl Sink for SinkWrapper + where + F: TunnelFilter + 'static, + S: ZCPacketSink + 'static + Unpin, + { + type Error = SinkError; + + fn poll_ready( + self: std::pin::Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + self.get_mut().sink.poll_ready_unpin(cx) + } + + fn start_send( + self: std::pin::Pin<&mut Self>, + item: ZCPacket, + ) -> Result<(), Self::Error> { + let Some(item) = self.filter.before_send(item) else { + return Ok(()); + }; + self.get_mut().sink.start_send_unpin(item) + } + + fn poll_flush( + self: std::pin::Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + self.get_mut().sink.poll_flush_unpin(cx) + } + + fn poll_close( + self: std::pin::Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + self.get_mut().sink.poll_close_unpin(cx) + } + } + + SinkWrapper { + sink, + filter: self.filter.clone(), + } + } + + fn wrap_stream(&self, stream: S) -> impl ZCPacketStream { + struct StreamWrapper { + stream: S, + filter: Arc, + } + + impl Stream for StreamWrapper + where + F: TunnelFilter + 'static, + S: ZCPacketStream + 'static + Unpin, + { + type Item = StreamItem; + + fn poll_next( + self: std::pin::Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + let self_mut = self.get_mut(); + loop { + match self_mut.stream.poll_next_unpin(cx) { + Poll::Ready(Some(ret)) => { + let Some(ret) = self_mut.filter.after_received(ret) else { + continue; + }; + return Poll::Ready(Some(ret)); + } + Poll::Ready(None) => { + return Poll::Ready(None); + } + Poll::Pending => { + return Poll::Pending; + } + } + } + } + } + + StreamWrapper { + stream, + filter: self.filter.clone(), + } + } +} + +impl Tunnel for TunnelWithFilter +where + T: Tunnel + Send + 'static, + F: TunnelFilter + Send + 'static, +{ + fn info(&self) -> Option { + self.inner.info() + } + + fn split(&self) -> (Pin>, Pin>) { + let (stream, sink) = self.inner.split(); + ( + Box::pin(self.wrap_stream(stream)), + Box::pin(self.wrap_sink(sink)), + ) + } +} + +pub struct PacketRecorderTunnelFilter { + pub received: Arc>>, + pub sent: Arc>>, +} + +impl TunnelFilter for PacketRecorderTunnelFilter { + type FilterOutput = (Vec, Vec); + + fn before_send(&self, data: SinkItem) -> Option { + self.received.lock().unwrap().push(data.clone()); + Some(data) + } + + fn after_received(&self, data: StreamItem) -> Option { + match data { + Ok(v) => { + self.sent.lock().unwrap().push(v.clone().into()); + Some(Ok(v)) + } + Err(e) => Some(Err(e)), + } + } + + fn filter_output(&self) -> Self::FilterOutput { + ( + self.received.lock().unwrap().clone(), + self.sent.lock().unwrap().clone(), + ) + } +} + +impl PacketRecorderTunnelFilter { + pub fn new() -> Self { + Self { + received: Arc::new(std::sync::Mutex::new(Vec::new())), + sent: Arc::new(std::sync::Mutex::new(Vec::new())), + } + } +} + +pub struct StatsRecorderTunnelFilter { + throughput: Arc, +} + +impl TunnelFilter for StatsRecorderTunnelFilter { + type FilterOutput = Arc; + + fn before_send(&self, data: SinkItem) -> Option { + self.throughput.record_tx_bytes(data.buf_len() as u64); + Some(data) + } + + fn after_received(&self, data: StreamItem) -> Option { + match data { + Ok(v) => { + self.throughput.record_rx_bytes(v.buf_len() as u64); + Some(Ok(v)) + } + Err(e) => Some(Err(e)), + } + } + + fn filter_output(&self) -> Self::FilterOutput { + self.throughput.clone() + } +} + +impl StatsRecorderTunnelFilter { + pub fn new() -> Self { + Self { + throughput: Arc::new(Throughput::new()), + } + } + + pub fn get_throughput(&self) -> Arc { + self.throughput.clone() + } +} + +#[cfg(test)] +pub mod tests { + use std::sync::atomic::{AtomicU32, Ordering}; + + use filter::ring::create_ring_tunnel_pair; + + use super::*; + + pub struct DropSendTunnelFilter { + start: AtomicU32, + end: AtomicU32, + cur: AtomicU32, + } + + impl TunnelFilter for DropSendTunnelFilter { + type FilterOutput = (); + + fn before_send(&self, data: SinkItem) -> Option { + self.cur.fetch_add(1, Ordering::SeqCst); + if self.cur.load(Ordering::SeqCst) >= self.start.load(Ordering::SeqCst) + && self.cur.load(std::sync::atomic::Ordering::SeqCst) + < self.end.load(Ordering::SeqCst) + { + tracing::trace!("drop packet: {:?}", data); + return None; + } + Some(data) + } + + fn filter_output(&self) {} + } + + impl DropSendTunnelFilter { + pub fn new(start: u32, end: u32) -> Self { + Self { + start: AtomicU32::new(start), + end: AtomicU32::new(end), + cur: AtomicU32::new(0), + } + } + } + + #[tokio::test] + async fn test_nested_filter() { + let filter = Arc::new( + PacketRecorderTunnelFilter::new() + .to_chain() + .chain(PacketRecorderTunnelFilter::new()) + .chain(PacketRecorderTunnelFilter::new()) + .chain(PacketRecorderTunnelFilter::new()), + ); + let (s, _b) = create_ring_tunnel_pair(); + let tunnel = TunnelWithFilter::new(s, filter.clone()); + + let (_r, mut s) = tunnel.split(); + s.send(ZCPacket::new_with_payload("ab".as_bytes())) + .await + .unwrap(); + + let out = filter.filter_output(); + + let a = out.0 .0 .0 .1; + let b = out.0 .0 .1; + let c = out.0 .1; + let _d = out.1; + + assert_eq!(1, a.0.len()); + assert_eq!(1, b.0.len()); + assert_eq!(1, c.0.len()); + } +} diff --git a/easytier/src/tunnel/mod.rs b/easytier/src/tunnel/mod.rs new file mode 100644 index 00000000..51c27ef7 --- /dev/null +++ b/easytier/src/tunnel/mod.rs @@ -0,0 +1,196 @@ +use std::{net::SocketAddr, pin::Pin, sync::Arc}; + +use async_trait::async_trait; +use futures::{Sink, Stream}; +use std::fmt::Debug; + +use tokio::time::error::Elapsed; + +use crate::rpc::TunnelInfo; + +use self::packet_def::ZCPacket; + +pub mod buf; +pub mod common; +pub mod filter; +pub mod mpsc; +pub mod packet_def; +pub mod quic; +pub mod ring; +pub mod stats; +pub mod tcp; +pub mod udp; +pub mod wireguard; + +#[derive(thiserror::Error, Debug)] +pub enum TunnelError { + #[error("io error")] + IOError(#[from] std::io::Error), + #[error("invalid packet. msg: {0}")] + InvalidPacket(String), + #[error("exceed max packet size. max: {0}, input: {1}")] + ExceedMaxPacketSize(usize, usize), + + #[error("invalid protocol: {0}")] + InvalidProtocol(String), + #[error("invalid addr: {0}")] + InvalidAddr(String), + + #[error("internal error {0}")] + InternalError(String), + + #[error("conn id not match, expect: {0}, actual: {1}")] + ConnIdNotMatch(u32, u32), + #[error("buffer full")] + BufferFull, + + #[error("timeout")] + Timeout(#[from] Elapsed), + + #[error("anyhow error: {0}")] + Anyhow(#[from] anyhow::Error), + + #[error("shutdown")] + Shutdown, +} + +pub type StreamT = packet_def::ZCPacket; +pub type StreamItem = Result; +pub type SinkItem = packet_def::ZCPacket; +pub type SinkError = TunnelError; + +pub trait ZCPacketStream: Stream + Send {} +impl ZCPacketStream for T where T: Stream + Send {} +pub trait ZCPacketSink: Sink + Send {} +impl ZCPacketSink for T where T: Sink + Send {} + +#[auto_impl::auto_impl(Box, Arc)] +pub trait Tunnel: Send { + fn split(&self) -> (Pin>, Pin>); + fn info(&self) -> Option; +} + +#[auto_impl::auto_impl(Arc)] +pub trait TunnelConnCounter: 'static + Send + Sync + Debug { + fn get(&self) -> u32; +} + +#[async_trait] +#[auto_impl::auto_impl(Box)] +pub trait TunnelListener: Send { + async fn listen(&mut self) -> Result<(), TunnelError>; + async fn accept(&mut self) -> Result, TunnelError>; + fn local_url(&self) -> url::Url; + fn get_conn_counter(&self) -> Arc> { + #[derive(Debug)] + struct FakeTunnelConnCounter {} + impl TunnelConnCounter for FakeTunnelConnCounter { + fn get(&self) -> u32 { + 0 + } + } + Arc::new(Box::new(FakeTunnelConnCounter {})) + } +} + +#[async_trait] +#[auto_impl::auto_impl(Box)] +pub trait TunnelConnector: Send { + async fn connect(&mut self) -> Result, TunnelError>; + fn remote_url(&self) -> url::Url; + fn set_bind_addrs(&mut self, _addrs: Vec) {} +} + +pub fn build_url_from_socket_addr(addr: &String, scheme: &str) -> url::Url { + url::Url::parse(format!("{}://{}", scheme, addr).as_str()).unwrap() +} + +impl std::fmt::Debug for dyn Tunnel { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Tunnel") + .field("info", &self.info()) + .finish() + } +} + +impl std::fmt::Debug for dyn TunnelConnector { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("TunnelConnector") + .field("remote_url", &self.remote_url()) + .finish() + } +} + +impl std::fmt::Debug for dyn TunnelListener { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("TunnelListener") + .field("local_url", &self.local_url()) + .finish() + } +} + +pub(crate) trait FromUrl { + fn from_url(url: url::Url) -> Result + where + Self: Sized; +} + +pub(crate) fn check_scheme_and_get_socket_addr( + url: &url::Url, + scheme: &str, +) -> Result +where + T: FromUrl, +{ + if url.scheme() != scheme { + return Err(TunnelError::InvalidProtocol(url.scheme().to_string())); + } + + Ok(T::from_url(url.clone())?) +} + +impl FromUrl for SocketAddr { + fn from_url(url: url::Url) -> Result { + Ok(url.socket_addrs(|| None)?.pop().unwrap()) + } +} + +impl FromUrl for uuid::Uuid { + fn from_url(url: url::Url) -> Result { + let o = url.host_str().unwrap(); + let o = uuid::Uuid::parse_str(o).map_err(|e| TunnelError::InvalidAddr(e.to_string()))?; + Ok(o) + } +} + +pub struct TunnelUrl { + inner: url::Url, +} + +impl From for TunnelUrl { + fn from(url: url::Url) -> Self { + TunnelUrl { inner: url } + } +} + +impl From for url::Url { + fn from(url: TunnelUrl) -> Self { + url.into_inner() + } +} + +impl TunnelUrl { + pub fn into_inner(self) -> url::Url { + self.inner + } + + pub fn bind_dev(&self) -> Option { + self.inner.path().strip_prefix("/").and_then(|s| { + if s.is_empty() { + None + } else { + Some(String::from_utf8(percent_encoding::percent_decode_str(&s).collect()).unwrap()) + } + }) + } +} diff --git a/easytier/src/tunnel/mpsc.rs b/easytier/src/tunnel/mpsc.rs new file mode 100644 index 00000000..e587b729 --- /dev/null +++ b/easytier/src/tunnel/mpsc.rs @@ -0,0 +1,180 @@ +// this mod wrap tunnel to a mpsc tunnel, based on crossbeam_channel + +use std::pin::Pin; + +use anyhow::Context; +use tokio::task::JoinHandle; + +use super::{packet_def::ZCPacket, Tunnel, TunnelError, ZCPacketSink, ZCPacketStream}; + +use tachyonix::{channel, Receiver, Sender}; + +use futures::SinkExt; + +#[derive(Clone)] +pub struct MpscTunnelSender(Sender); + +impl MpscTunnelSender { + pub async fn send(&self, item: ZCPacket) -> Result<(), TunnelError> { + self.0.send(item).await.with_context(|| "send error")?; + Ok(()) + } +} + +pub struct MpscTunnel { + tx: Sender, + + tunnel: T, + stream: Option>>, + + task: Option>, +} + +impl MpscTunnel { + pub fn new(tunnel: T) -> Self { + let (tx, mut rx) = channel(32); + let (stream, mut sink) = tunnel.split(); + + let task = tokio::spawn(async move { + loop { + if let Err(e) = Self::forward_one_round(&mut rx, &mut sink).await { + tracing::error!(?e, "forward error"); + break; + } + } + }); + + Self { + tx, + tunnel, + stream: Some(stream), + task: Some(task), + } + } + + async fn forward_one_round( + rx: &mut Receiver, + sink: &mut Pin>, + ) -> Result<(), TunnelError> { + let item = rx.recv().await.with_context(|| "recv error")?; + sink.feed(item).await?; + while let Ok(item) = rx.try_recv() { + if let Err(e) = sink.feed(item).await { + tracing::error!(?e, "feed error"); + break; + } + } + sink.flush().await + } + + pub fn get_stream(&mut self) -> Pin> { + self.stream.take().unwrap() + } + + pub fn get_sink(&self) -> MpscTunnelSender { + MpscTunnelSender(self.tx.clone()) + } +} + +impl From for MpscTunnel { + fn from(tunnel: T) -> Self { + Self::new(tunnel) + } +} + +#[cfg(test)] +mod tests { + use futures::StreamExt; + + use crate::tunnel::{ + tcp::{TcpTunnelConnector, TcpTunnelListener}, + TunnelConnector, TunnelListener, + }; + + use super::*; + // test slow send lock in framed tunnel + #[tokio::test] + async fn mpsc_slow_receiver() { + let mut listener = TcpTunnelListener::new("tcp://127.0.0.1:11014".parse().unwrap()); + let mut connector = TcpTunnelConnector::new("tcp://127.0.0.1:11014".parse().unwrap()); + + listener.listen().await.unwrap(); + let t1 = tokio::spawn(async move { + let t = listener.accept().await.unwrap(); + let (mut stream, _sink) = t.split(); + let now = tokio::time::Instant::now(); + + let mut a_counter = 0; + let mut b_counter = 0; + + while let Some(Ok(msg)) = stream.next().await { + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + if now.elapsed().as_secs() > 5 { + break; + } + + if msg.payload() == "hello".as_bytes() { + a_counter += 1; + } else if msg.payload() == "hello2".as_bytes() { + b_counter += 1; + } + } + + tracing::info!("t1 exit"); + assert_ne!(a_counter, 0); + assert_ne!(b_counter, 0); + }); + + let tunnel = connector.connect().await.unwrap(); + let mpsc_tunnel = MpscTunnel::from(tunnel); + + let sink1 = mpsc_tunnel.get_sink(); + let t2 = tokio::spawn(async move { + for i in 0..1000000 { + tokio::time::sleep(tokio::time::Duration::from_millis(50)).await; + let a = sink1 + .send(ZCPacket::new_with_payload("hello".as_bytes())) + .await; + if a.is_err() { + tracing::info!(?a, "t2 exit with err"); + break; + } + + if i % 5000 == 0 { + tracing::info!(i, "send2 1000"); + } + } + + tracing::info!("t2 exit"); + }); + + let sink2 = mpsc_tunnel.get_sink(); + let t3 = tokio::spawn(async move { + for i in 0..1000000 { + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + let a = sink2 + .send(ZCPacket::new_with_payload("hello2".as_bytes())) + .await; + if a.is_err() { + tracing::info!(?a, "t3 exit with err"); + break; + } + + if i % 5000 == 0 { + tracing::info!(i, "send2 1000"); + } + } + + tracing::info!("t3 exit"); + }); + + let t4 = tokio::spawn(async move { + tokio::time::sleep(tokio::time::Duration::from_secs(5)).await; + tracing::info!("closing"); + drop(mpsc_tunnel); + tracing::info!("closed"); + }); + + let _ = tokio::join!(t1, t2, t3, t4); + } +} diff --git a/easytier/src/tunnel/packet_def.rs b/easytier/src/tunnel/packet_def.rs new file mode 100644 index 00000000..8084a053 --- /dev/null +++ b/easytier/src/tunnel/packet_def.rs @@ -0,0 +1,340 @@ +use bytes::Bytes; +use bytes::BytesMut; +use zerocopy::byteorder::*; +use zerocopy::AsBytes; +use zerocopy::FromBytes; +use zerocopy::FromZeroes; + +type DefaultEndian = LittleEndian; + +// TCP TunnelHeader +#[repr(C, packed)] +#[derive(AsBytes, FromBytes, FromZeroes, Clone, Debug, Default)] +pub struct TCPTunnelHeader { + pub len: U32, +} +pub const TCP_TUNNEL_HEADER_SIZE: usize = std::mem::size_of::(); + +#[derive(AsBytes, FromZeroes, Clone, Debug)] +#[repr(u8)] +pub enum UdpPacketType { + Invalid = 0, + Syn = 1, + Sack = 2, + Data = 3, + Fin = 4, + HolePunch = 5, +} + +#[repr(C, packed)] +#[derive(AsBytes, FromBytes, FromZeroes, Clone, Debug, Default)] +pub struct UDPTunnelHeader { + pub conn_id: U32, + pub msg_type: u8, + pub padding: u8, + pub len: U16, +} +pub const UDP_TUNNEL_HEADER_SIZE: usize = std::mem::size_of::(); + +#[repr(C, packed)] +#[derive(AsBytes, FromBytes, FromZeroes, Clone, Debug, Default)] +pub struct WGTunnelHeader { + pub ipv4_header: [u8; 20], +} +pub const WG_TUNNEL_HEADER_SIZE: usize = std::mem::size_of::(); + +#[derive(AsBytes, FromZeroes, Clone, Debug)] +#[repr(u8)] +pub enum PacketType { + Invalid = 0, + Data = 1, + HandShake = 2, + RoutePacket = 3, + Ping = 4, + Pong = 5, + TaRpc = 6, + Route = 7, +} + +#[repr(C, packed)] +#[derive(AsBytes, FromBytes, FromZeroes, Clone, Debug, Default)] +pub struct PeerManagerHeader { + pub from_peer_id: U32, + pub to_peer_id: U32, + pub packet_type: u8, + pub len: U32, +} +pub const PEER_MANAGER_HEADER_SIZE: usize = std::mem::size_of::(); + +const fn max(a: usize, b: usize) -> usize { + [a, b][(a < b) as usize] +} + +#[derive(Default, Debug)] +pub struct ZCPacketOffsets { + pub payload_offset: usize, + pub peer_manager_header_offset: usize, + pub tcp_tunnel_header_offset: usize, + pub udp_tunnel_header_offset: usize, + pub wg_tunnel_header_offset: usize, +} + +#[derive(Debug, Clone, Copy, PartialEq)] +pub enum ZCPacketType { + // received from peer tcp connection + TCP, + // received from peer udp connection + UDP, + // received from peer wireguard connection + WG, + // received from local tun device, should reserve header space for tcp or udp tunnel + NIC, +} + +const PAYLOAD_OFFSET_FOR_NIC_PACKET: usize = max( + max(TCP_TUNNEL_HEADER_SIZE, UDP_TUNNEL_HEADER_SIZE), + WG_TUNNEL_HEADER_SIZE, +) + PEER_MANAGER_HEADER_SIZE; + +impl ZCPacketType { + pub fn get_packet_offsets(&self) -> ZCPacketOffsets { + match self { + ZCPacketType::TCP => ZCPacketOffsets { + payload_offset: TCP_TUNNEL_HEADER_SIZE + PEER_MANAGER_HEADER_SIZE, + peer_manager_header_offset: TCP_TUNNEL_HEADER_SIZE, + ..Default::default() + }, + ZCPacketType::UDP => ZCPacketOffsets { + payload_offset: UDP_TUNNEL_HEADER_SIZE + PEER_MANAGER_HEADER_SIZE, + peer_manager_header_offset: UDP_TUNNEL_HEADER_SIZE, + ..Default::default() + }, + ZCPacketType::WG => ZCPacketOffsets { + payload_offset: WG_TUNNEL_HEADER_SIZE + PEER_MANAGER_HEADER_SIZE, + peer_manager_header_offset: WG_TUNNEL_HEADER_SIZE, + ..Default::default() + }, + ZCPacketType::NIC => ZCPacketOffsets { + payload_offset: PAYLOAD_OFFSET_FOR_NIC_PACKET, + peer_manager_header_offset: PAYLOAD_OFFSET_FOR_NIC_PACKET + - PEER_MANAGER_HEADER_SIZE, + tcp_tunnel_header_offset: PAYLOAD_OFFSET_FOR_NIC_PACKET + - PEER_MANAGER_HEADER_SIZE + - TCP_TUNNEL_HEADER_SIZE, + udp_tunnel_header_offset: PAYLOAD_OFFSET_FOR_NIC_PACKET + - PEER_MANAGER_HEADER_SIZE + - UDP_TUNNEL_HEADER_SIZE, + wg_tunnel_header_offset: PAYLOAD_OFFSET_FOR_NIC_PACKET + - PEER_MANAGER_HEADER_SIZE + - WG_TUNNEL_HEADER_SIZE, + }, + } + } +} + +#[derive(Debug, Clone)] +pub struct ZCPacket { + inner: BytesMut, + packet_type: ZCPacketType, +} + +impl ZCPacket { + pub fn new_nic_packet() -> Self { + Self { + inner: BytesMut::new(), + packet_type: ZCPacketType::NIC, + } + } + + pub fn new_from_buf(buf: BytesMut, packet_type: ZCPacketType) -> Self { + Self { + inner: buf, + packet_type, + } + } + + pub fn new_with_payload(payload: &[u8]) -> Self { + let mut ret = Self::new_nic_packet(); + let total_len = ret.packet_type.get_packet_offsets().payload_offset + payload.len(); + ret.inner.resize(total_len, 0); + ret.mut_payload()[..payload.len()].copy_from_slice(&payload); + ret + } + + pub fn packet_type(&self) -> ZCPacketType { + self.packet_type + } + + pub fn mut_payload(&mut self) -> &mut [u8] { + &mut self.inner[self.packet_type.get_packet_offsets().payload_offset..] + } + + pub fn mut_peer_manager_header(&mut self) -> Option<&mut PeerManagerHeader> { + PeerManagerHeader::mut_from_prefix( + &mut self.inner[self + .packet_type + .get_packet_offsets() + .peer_manager_header_offset..], + ) + } + + pub fn mut_tcp_tunnel_header(&mut self) -> Option<&mut TCPTunnelHeader> { + TCPTunnelHeader::mut_from_prefix( + &mut self.inner[self + .packet_type + .get_packet_offsets() + .tcp_tunnel_header_offset..], + ) + } + + pub fn mut_udp_tunnel_header(&mut self) -> Option<&mut UDPTunnelHeader> { + UDPTunnelHeader::mut_from_prefix( + &mut self.inner[self + .packet_type + .get_packet_offsets() + .udp_tunnel_header_offset..], + ) + } + + pub fn mut_wg_tunnel_header(&mut self) -> Option<&mut WGTunnelHeader> { + WGTunnelHeader::mut_from_prefix( + &mut self.inner[self + .packet_type + .get_packet_offsets() + .wg_tunnel_header_offset..], + ) + } + + // ref versions + pub fn payload(&self) -> &[u8] { + &self.inner[self.packet_type.get_packet_offsets().payload_offset..] + } + + pub fn peer_manager_header(&self) -> Option<&PeerManagerHeader> { + PeerManagerHeader::ref_from_prefix( + &self.inner[self + .packet_type + .get_packet_offsets() + .peer_manager_header_offset..], + ) + } + + pub fn tcp_tunnel_header(&self) -> Option<&TCPTunnelHeader> { + TCPTunnelHeader::ref_from_prefix( + &self.inner[self + .packet_type + .get_packet_offsets() + .tcp_tunnel_header_offset..], + ) + } + + pub fn udp_tunnel_header(&self) -> Option<&UDPTunnelHeader> { + UDPTunnelHeader::ref_from_prefix( + &self.inner[self + .packet_type + .get_packet_offsets() + .udp_tunnel_header_offset..], + ) + } + + pub fn udp_payload(&self) -> &[u8] { + &self.inner[self + .packet_type + .get_packet_offsets() + .udp_tunnel_header_offset + + UDP_TUNNEL_HEADER_SIZE..] + } + + pub fn payload_len(&self) -> usize { + let payload_offset = self.packet_type.get_packet_offsets().payload_offset; + self.inner.len() - payload_offset + } + + pub fn buf_len(&self) -> usize { + self.inner.len() + } + + pub fn fill_peer_manager_hdr(&mut self, from_peer_id: u32, to_peer_id: u32, packet_type: u8) { + let payload_len = self.payload_len(); + let hdr = self.mut_peer_manager_header().unwrap(); + hdr.from_peer_id.set(from_peer_id); + hdr.to_peer_id.set(to_peer_id); + hdr.packet_type = packet_type; + hdr.len.set(payload_len as u32); + } + + pub fn into_bytes(mut self, target_packet_type: ZCPacketType) -> Bytes { + if target_packet_type == self.packet_type { + return self.inner.freeze(); + } else { + assert_eq!( + self.packet_type, + ZCPacketType::NIC, + "only support NIC, got {:?}", + self + ); + } + + match target_packet_type { + ZCPacketType::TCP => self + .inner + .split_off( + self.packet_type + .get_packet_offsets() + .tcp_tunnel_header_offset, + ) + .freeze(), + ZCPacketType::UDP => self + .inner + .split_off( + self.packet_type + .get_packet_offsets() + .udp_tunnel_header_offset, + ) + .freeze(), + ZCPacketType::WG => self + .inner + .split_off( + self.packet_type + .get_packet_offsets() + .wg_tunnel_header_offset, + ) + .freeze(), + ZCPacketType::NIC => unreachable!(), + } + } + + pub fn inner(self) -> BytesMut { + self.inner + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_zc_packet() { + let payload = b"hello world"; + let mut packet = ZCPacket::new_with_payload(payload); + let peer_manager_header = packet.mut_peer_manager_header().unwrap(); + peer_manager_header.packet_type = PacketType::Data as u8; + peer_manager_header.len.set(payload.len() as u32); + + let tcp_tunnel_header = packet.mut_tcp_tunnel_header().unwrap(); + tcp_tunnel_header.len.set(payload.len() as u32); + + // let udp_tunnel_header = packet.mut_udp_tunnel_header().unwrap(); + // udp_tunnel_header.conn_id = 1; + // udp_tunnel_header.msg_type = 2; + // udp_tunnel_header.len = payload.len() as u32; + + assert_eq!(packet.payload(), b"hello world"); + assert_eq!(packet.payload_len(), 11); + println!("{:?}", packet.inner); + + let tcp_packet = packet.into_bytes(ZCPacketType::TCP); + assert_eq!(&tcp_packet[..1], b"\x0b"); + println!("{:?}", tcp_packet); + } +} diff --git a/easytier/src/tunnel/quic.rs b/easytier/src/tunnel/quic.rs new file mode 100644 index 00000000..a06bbf82 --- /dev/null +++ b/easytier/src/tunnel/quic.rs @@ -0,0 +1,226 @@ +//! This example demonstrates how to make a QUIC connection that ignores the server certificate. +//! +//! Checkout the `README.md` for guidance. + +use std::{error::Error, net::SocketAddr, sync::Arc}; + +use crate::{ + rpc::TunnelInfo, + tunnel::common::{FramedReader, FramedWriter, TunnelWrapper}, +}; +use anyhow::Context; +use quinn::{ClientConfig, Connection, Endpoint, ServerConfig}; + +use super::{ + check_scheme_and_get_socket_addr, Tunnel, TunnelConnector, TunnelError, TunnelListener, +}; + +/// Dummy certificate verifier that treats any certificate as valid. +/// NOTE, such verification is vulnerable to MITM attacks, but convenient for testing. +struct SkipServerVerification; + +impl SkipServerVerification { + fn new() -> Arc { + Arc::new(Self) + } +} + +impl rustls::client::ServerCertVerifier for SkipServerVerification { + fn verify_server_cert( + &self, + _end_entity: &rustls::Certificate, + _intermediates: &[rustls::Certificate], + _server_name: &rustls::ServerName, + _scts: &mut dyn Iterator, + _ocsp_response: &[u8], + _now: std::time::SystemTime, + ) -> Result { + Ok(rustls::client::ServerCertVerified::assertion()) + } +} + +fn configure_client() -> ClientConfig { + let crypto = rustls::ClientConfig::builder() + .with_safe_defaults() + .with_custom_certificate_verifier(SkipServerVerification::new()) + .with_no_client_auth(); + + ClientConfig::new(Arc::new(crypto)) +} + +/// Constructs a QUIC endpoint configured to listen for incoming connections on a certain address +/// and port. +/// +/// ## Returns +/// +/// - a stream of incoming QUIC connections +/// - server certificate serialized into DER format +#[allow(unused)] +pub fn make_server_endpoint(bind_addr: SocketAddr) -> Result<(Endpoint, Vec), Box> { + let (server_config, server_cert) = configure_server()?; + let endpoint = Endpoint::server(server_config, bind_addr)?; + Ok((endpoint, server_cert)) +} + +/// Returns default server configuration along with its certificate. +fn configure_server() -> Result<(ServerConfig, Vec), Box> { + let cert = rcgen::generate_simple_self_signed(vec!["localhost".into()]).unwrap(); + let cert_der = cert.serialize_der().unwrap(); + let priv_key = cert.serialize_private_key_der(); + let priv_key = rustls::PrivateKey(priv_key); + let cert_chain = vec![rustls::Certificate(cert_der.clone())]; + + let mut server_config = ServerConfig::with_single_cert(cert_chain, priv_key)?; + let transport_config = Arc::get_mut(&mut server_config.transport).unwrap(); + transport_config.max_concurrent_uni_streams(10_u8.into()); + transport_config.max_concurrent_bidi_streams(10_u8.into()); + + Ok((server_config, cert_der)) +} + +#[allow(unused)] +pub const ALPN_QUIC_HTTP: &[&[u8]] = &[b"hq-29"]; + +/// Runs a QUIC server bound to given address. + +struct ConnWrapper { + conn: Connection, +} + +impl Drop for ConnWrapper { + fn drop(&mut self) { + self.conn.close(0u32.into(), b"done"); + } +} + +pub struct QUICTunnelListener { + addr: url::Url, + endpoint: Option, + server_cert: Option>, +} + +impl QUICTunnelListener { + pub fn new(addr: url::Url) -> Self { + QUICTunnelListener { + addr, + endpoint: None, + server_cert: None, + } + } +} + +#[async_trait::async_trait] +impl TunnelListener for QUICTunnelListener { + async fn listen(&mut self) -> Result<(), TunnelError> { + let addr = check_scheme_and_get_socket_addr::(&self.addr, "quic")?; + let (endpoint, server_cert) = make_server_endpoint(addr).unwrap(); + self.endpoint = Some(endpoint); + self.server_cert = Some(server_cert); + Ok(()) + } + + async fn accept(&mut self) -> Result, super::TunnelError> { + // accept a single connection + let incoming_conn = self.endpoint.as_ref().unwrap().accept().await.unwrap(); + let conn = incoming_conn.await.unwrap(); + println!( + "[server] connection accepted: addr={}", + conn.remote_address() + ); + let remote_addr = conn.remote_address(); + let (w, r) = conn.accept_bi().await.with_context(|| "accept_bi failed")?; + + let arc_conn = Arc::new(ConnWrapper { conn }); + + let info = TunnelInfo { + tunnel_type: "quic".to_owned(), + local_addr: self.local_url().into(), + remote_addr: super::build_url_from_socket_addr(&remote_addr.to_string(), "quic").into(), + }; + + Ok(Box::new(TunnelWrapper::new( + FramedReader::new_with_associate_data(r, 4500, Some(Box::new(arc_conn.clone()))), + FramedWriter::new_with_associate_data(w, Some(Box::new(arc_conn))), + Some(info), + ))) + } + + fn local_url(&self) -> url::Url { + self.addr.clone() + } +} + +pub struct QUICTunnelConnector { + addr: url::Url, + endpoint: Option, +} + +impl QUICTunnelConnector { + pub fn new(addr: url::Url) -> Self { + QUICTunnelConnector { + addr, + endpoint: None, + } + } +} + +#[async_trait::async_trait] +impl TunnelConnector for QUICTunnelConnector { + async fn connect(&mut self) -> Result, super::TunnelError> { + let addr = check_scheme_and_get_socket_addr::(&self.addr, "quic")?; + + let mut endpoint = Endpoint::client("127.0.0.1:0".parse().unwrap())?; + endpoint.set_default_client_config(configure_client()); + + // connect to server + let connection = endpoint.connect(addr, "localhost").unwrap().await.unwrap(); + println!("[client] connected: addr={}", connection.remote_address()); + + let local_addr = endpoint.local_addr().unwrap(); + + self.endpoint = Some(endpoint); + + let (w, r) = connection + .open_bi() + .await + .with_context(|| "open_bi failed")?; + + let info = TunnelInfo { + tunnel_type: "quic".to_owned(), + local_addr: super::build_url_from_socket_addr(&local_addr.to_string(), "quic").into(), + remote_addr: self.addr.to_string(), + }; + + let arc_conn = Arc::new(ConnWrapper { conn: connection }); + Ok(Box::new(TunnelWrapper::new( + FramedReader::new_with_associate_data(r, 4500, Some(Box::new(arc_conn.clone()))), + FramedWriter::new_with_associate_data(w, Some(Box::new(arc_conn))), + Some(info), + ))) + } + + fn remote_url(&self) -> url::Url { + self.addr.clone() + } +} + +#[cfg(test)] +mod tests { + use crate::tunnel::common::tests::{_tunnel_bench, _tunnel_pingpong}; + + use super::*; + + #[tokio::test] + async fn quic_pingpong() { + let listener = QUICTunnelListener::new("quic://0.0.0.0:21011".parse().unwrap()); + let connector = QUICTunnelConnector::new("quic://127.0.0.1:21011".parse().unwrap()); + _tunnel_pingpong(listener, connector).await + } + + #[tokio::test] + async fn quic_bench() { + let listener = QUICTunnelListener::new("quic://0.0.0.0:21012".parse().unwrap()); + let connector = QUICTunnelConnector::new("quic://127.0.0.1:21012".parse().unwrap()); + _tunnel_bench(listener, connector).await + } +} diff --git a/easytier/src/tunnel/ring.rs b/easytier/src/tunnel/ring.rs new file mode 100644 index 00000000..a70ac22e --- /dev/null +++ b/easytier/src/tunnel/ring.rs @@ -0,0 +1,427 @@ +use std::{ + collections::HashMap, + sync::{ + atomic::{AtomicBool, Ordering}, + Arc, + }, + task::{Poll, Waker}, +}; + +use atomicbox::AtomicOptionBox; +use crossbeam_queue::ArrayQueue; + +use async_trait::async_trait; +use futures::{Sink, Stream}; +use once_cell::sync::Lazy; + +use tokio::sync::{ + mpsc::{UnboundedReceiver, UnboundedSender}, + Mutex, +}; + +use uuid::Uuid; + +use crate::tunnel::{SinkError, SinkItem}; + +use super::{ + build_url_from_socket_addr, check_scheme_and_get_socket_addr, common::TunnelWrapper, + StreamItem, Tunnel, TunnelConnector, TunnelError, TunnelInfo, TunnelListener, +}; + +static RING_TUNNEL_CAP: usize = 128; + +#[derive(Debug)] +pub struct RingTunnel { + id: Uuid, + ring: ArrayQueue, + closed: AtomicBool, + + wait_for_new_item: AtomicOptionBox, + wait_for_empty_slot: AtomicOptionBox, +} + +impl RingTunnel { + fn wait_for_new_item(&self, cx: &mut std::task::Context<'_>) -> Poll { + let ret = self + .wait_for_new_item + .swap(Some(Box::new(cx.waker().clone())), Ordering::AcqRel); + if let Some(old_waker) = ret { + assert!(old_waker.will_wake(cx.waker())); + } + Poll::Pending + } + + fn wait_for_empty_slot(&self, cx: &mut std::task::Context<'_>) -> Poll { + let ret = self + .wait_for_empty_slot + .swap(Some(Box::new(cx.waker().clone())), Ordering::AcqRel); + if let Some(old_waker) = ret { + assert!(old_waker.will_wake(cx.waker())); + } + Poll::Pending + } + + fn notify_new_item(&self) { + if let Some(w) = self.wait_for_new_item.take(Ordering::AcqRel) { + tracing::trace!(?self.id, "notify new item"); + w.wake(); + } + } + + fn notify_empty_slot(&self) { + if let Some(w) = self.wait_for_empty_slot.take(Ordering::AcqRel) { + tracing::trace!(?self.id, "notify empty slot"); + w.wake(); + } + } + + fn id(&self) -> &Uuid { + &self.id + } + + pub fn len(&self) -> usize { + self.ring.len() + } + + pub fn capacity(&self) -> usize { + self.ring.capacity() + } + + fn close(&self) { + tracing::info!("close ring tunnel {:?}", self.id); + self.closed + .store(true, std::sync::atomic::Ordering::Relaxed); + self.notify_new_item(); + } + + fn closed(&self) -> bool { + self.closed.load(std::sync::atomic::Ordering::Relaxed) + } + + pub fn new(cap: usize) -> Self { + let id = Uuid::new_v4(); + Self { + id: id.clone(), + ring: ArrayQueue::new(cap), + closed: AtomicBool::new(false), + + wait_for_new_item: AtomicOptionBox::new(None), + wait_for_empty_slot: AtomicOptionBox::new(None), + } + } + + pub fn new_with_id(id: Uuid, cap: usize) -> Self { + let mut ret = Self::new(cap); + ret.id = id; + ret + } +} + +#[derive(Debug)] +pub struct RingStream { + tunnel: Arc, +} + +impl RingStream { + pub fn new(tunnel: Arc) -> Self { + Self { tunnel } + } +} + +impl Stream for RingStream { + type Item = StreamItem; + + fn poll_next( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + let s = self.get_mut(); + let ret = s.tunnel.ring.pop(); + match ret { + Some(v) => { + s.tunnel.notify_empty_slot(); + return Poll::Ready(Some(Ok(v))); + } + None => { + if s.tunnel.closed() { + tracing::warn!("ring recv tunnel {:?} closed", s.tunnel.id()); + return Poll::Ready(None); + } else { + tracing::trace!("waiting recv buffer, id: {}", s.tunnel.id()); + } + s.tunnel.wait_for_new_item(cx) + } + } + } +} + +#[derive(Debug)] +pub struct RingSink { + tunnel: Arc, +} + +impl Drop for RingSink { + fn drop(&mut self) { + self.tunnel.close(); + } +} + +impl RingSink { + pub fn new(tunnel: Arc) -> Self { + Self { tunnel } + } + + pub fn push_no_check(&self, item: SinkItem) -> Result<(), TunnelError> { + if self.tunnel.closed() { + return Err(TunnelError::Shutdown); + } + + log::trace!("id: {}, send buffer, buf: {:?}", self.tunnel.id(), &item); + self.tunnel.ring.push(item).unwrap(); + self.tunnel.notify_new_item(); + + Ok(()) + } + + pub fn has_empty_slot(&self) -> bool { + self.tunnel.len() < self.tunnel.capacity() + } +} + +impl Sink for RingSink { + type Error = SinkError; + + fn poll_ready( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + let self_mut = self.get_mut(); + if !self_mut.has_empty_slot() { + if self_mut.tunnel.closed() { + return Poll::Ready(Err(TunnelError::Shutdown)); + } + self_mut.tunnel.wait_for_empty_slot(cx) + } else { + Poll::Ready(Ok(())) + } + } + + fn start_send(self: std::pin::Pin<&mut Self>, item: SinkItem) -> Result<(), Self::Error> { + self.push_no_check(item) + } + + fn poll_flush( + self: std::pin::Pin<&mut Self>, + _cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + if self.tunnel.closed() { + return Poll::Ready(Err(TunnelError::Shutdown)); + } + Poll::Ready(Ok(())) + } + + fn poll_close( + self: std::pin::Pin<&mut Self>, + _cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + self.tunnel.close(); + Poll::Ready(Ok(())) + } +} + +struct Connection { + client: Arc, + server: Arc, +} + +static CONNECTION_MAP: Lazy>>>>> = + Lazy::new(|| Arc::new(Mutex::new(HashMap::new()))); + +#[derive(Debug)] +pub struct RingTunnelListener { + listerner_addr: url::Url, + conn_sender: UnboundedSender>, + conn_receiver: UnboundedReceiver>, +} + +impl RingTunnelListener { + pub fn new(key: url::Url) -> Self { + let (conn_sender, conn_receiver) = tokio::sync::mpsc::unbounded_channel(); + RingTunnelListener { + listerner_addr: key, + conn_sender, + conn_receiver, + } + } +} + +fn get_tunnel_for_client(conn: Arc) -> impl Tunnel { + TunnelWrapper::new( + RingStream::new(conn.client.clone()), + RingSink::new(conn.server.clone()), + Some(TunnelInfo { + tunnel_type: "ring".to_owned(), + local_addr: build_url_from_socket_addr(&conn.client.id.into(), "ring").into(), + remote_addr: build_url_from_socket_addr(&conn.server.id.into(), "ring").into(), + }), + ) +} + +fn get_tunnel_for_server(conn: Arc) -> impl Tunnel { + TunnelWrapper::new( + RingStream::new(conn.server.clone()), + RingSink::new(conn.client.clone()), + Some(TunnelInfo { + tunnel_type: "ring".to_owned(), + local_addr: build_url_from_socket_addr(&conn.server.id.into(), "ring").into(), + remote_addr: build_url_from_socket_addr(&conn.client.id.into(), "ring").into(), + }), + ) +} + +impl RingTunnelListener { + fn get_addr(&self) -> Result { + check_scheme_and_get_socket_addr::(&self.listerner_addr, "ring") + } +} + +#[async_trait] +impl TunnelListener for RingTunnelListener { + async fn listen(&mut self) -> Result<(), TunnelError> { + log::info!("listen new conn of key: {}", self.listerner_addr); + CONNECTION_MAP + .lock() + .await + .insert(self.get_addr()?, self.conn_sender.clone()); + Ok(()) + } + + async fn accept(&mut self) -> Result, TunnelError> { + log::info!("waiting accept new conn of key: {}", self.listerner_addr); + let my_addr = self.get_addr()?; + if let Some(conn) = self.conn_receiver.recv().await { + if conn.server.id == my_addr { + log::info!("accept new conn of key: {}", self.listerner_addr); + return Ok(Box::new(get_tunnel_for_server(conn))); + } else { + tracing::error!(?conn.server.id, ?my_addr, "got new conn with wrong id"); + return Err(TunnelError::InternalError( + "accept got wrong ring server id".to_owned(), + )); + } + } + + return Err(TunnelError::InternalError( + "conn receiver stopped".to_owned(), + )); + } + + fn local_url(&self) -> url::Url { + self.listerner_addr.clone() + } +} + +pub struct RingTunnelConnector { + remote_addr: url::Url, +} + +impl RingTunnelConnector { + pub fn new(remote_addr: url::Url) -> Self { + RingTunnelConnector { remote_addr } + } +} + +#[async_trait] +impl TunnelConnector for RingTunnelConnector { + async fn connect(&mut self) -> Result, super::TunnelError> { + let remote_addr = check_scheme_and_get_socket_addr::(&self.remote_addr, "ring")?; + let entry = CONNECTION_MAP + .lock() + .await + .get(&remote_addr) + .unwrap() + .clone(); + log::info!("connecting"); + let conn = Arc::new(Connection { + client: Arc::new(RingTunnel::new(RING_TUNNEL_CAP)), + server: Arc::new(RingTunnel::new_with_id( + remote_addr.clone(), + RING_TUNNEL_CAP, + )), + }); + entry + .send(conn.clone()) + .map_err(|_| TunnelError::InternalError("send conn to listner failed".to_owned()))?; + Ok(Box::new(get_tunnel_for_client(conn))) + } + + fn remote_url(&self) -> url::Url { + self.remote_addr.clone() + } +} + +pub fn create_ring_tunnel_pair() -> (Box, Box) { + let conn = Arc::new(Connection { + client: Arc::new(RingTunnel::new(RING_TUNNEL_CAP)), + server: Arc::new(RingTunnel::new(RING_TUNNEL_CAP)), + }); + ( + Box::new(get_tunnel_for_server(conn.clone())), + Box::new(get_tunnel_for_client(conn)), + ) +} + +#[cfg(test)] +mod tests { + use futures::StreamExt; + use tokio::time::timeout; + + use crate::tunnel::common::tests::{_tunnel_bench, _tunnel_pingpong}; + + use super::*; + + #[tokio::test] + async fn ring_pingpong() { + let id: url::Url = format!("ring://{}", Uuid::new_v4()).parse().unwrap(); + let listener = RingTunnelListener::new(id.clone()); + let connector = RingTunnelConnector::new(id.clone()); + _tunnel_pingpong(listener, connector).await + } + + #[tokio::test] + async fn ring_bench() { + let id: url::Url = format!("ring://{}", Uuid::new_v4()).parse().unwrap(); + let listener = RingTunnelListener::new(id.clone()); + let connector = RingTunnelConnector::new(id); + _tunnel_bench(listener, connector).await + } + + #[tokio::test] + async fn ring_close() { + let (stunnel, ctunnel) = create_ring_tunnel_pair(); + drop(stunnel); + + let mut stream = ctunnel.split().0; + let ret = stream.next().await; + assert!(ret.as_ref().is_none(), "expect none, got {:?}", ret); + } + + #[tokio::test] + async fn abort_ring_stream() { + let (_stunnel, ctunnel) = create_ring_tunnel_pair(); + let mut stream = ctunnel.split().0; + let task = tokio::spawn(async move { + let _ = stream.next().await; + }); + tokio::time::sleep(tokio::time::Duration::from_secs(1)).await; + task.abort(); + let _ = tokio::join!(task); + } + + #[tokio::test] + async fn ring_stream_recv_timeout() { + let (_stunnel, ctunnel) = create_ring_tunnel_pair(); + let mut stream = ctunnel.split().0; + let _ = timeout(tokio::time::Duration::from_millis(10), stream.next()).await; + } +} diff --git a/easytier/src/tunnel/stats.rs b/easytier/src/tunnel/stats.rs new file mode 100644 index 00000000..8e8d7a4b --- /dev/null +++ b/easytier/src/tunnel/stats.rs @@ -0,0 +1,95 @@ +use std::sync::atomic::{AtomicU32, AtomicU64, Ordering::Relaxed}; + +pub struct WindowLatency { + latency_us_window: Vec, + latency_us_window_index: AtomicU32, + latency_us_window_size: u32, + + sum: AtomicU32, + count: AtomicU32, +} + +impl WindowLatency { + pub fn new(window_size: u32) -> Self { + Self { + latency_us_window: (0..window_size).map(|_| AtomicU32::new(0)).collect(), + latency_us_window_index: AtomicU32::new(0), + latency_us_window_size: window_size, + + sum: AtomicU32::new(0), + count: AtomicU32::new(0), + } + } + + pub fn record_latency(&self, latency_us: u32) { + let index = self.latency_us_window_index.fetch_add(1, Relaxed); + if self.count.load(Relaxed) < self.latency_us_window_size { + self.count.fetch_add(1, Relaxed); + } + + let index = index % self.latency_us_window_size; + let old_lat = self.latency_us_window[index as usize].swap(latency_us, Relaxed); + + if old_lat < latency_us { + self.sum.fetch_add(latency_us - old_lat, Relaxed); + } else { + self.sum.fetch_sub(old_lat - latency_us, Relaxed); + } + } + + pub fn get_latency_us + std::ops::Div>(&self) -> T { + let count = self.count.load(Relaxed); + let sum = self.sum.load(Relaxed); + if count == 0 { + 0.into() + } else { + (T::from(sum)) / T::from(count) + } + } +} + +pub struct Throughput { + tx_bytes: AtomicU64, + rx_bytes: AtomicU64, + + tx_packets: AtomicU64, + rx_packets: AtomicU64, +} + +impl Throughput { + pub fn new() -> Self { + Self { + tx_bytes: AtomicU64::new(0), + rx_bytes: AtomicU64::new(0), + + tx_packets: AtomicU64::new(0), + rx_packets: AtomicU64::new(0), + } + } + + pub fn tx_bytes(&self) -> u64 { + self.tx_bytes.load(Relaxed) + } + + pub fn rx_bytes(&self) -> u64 { + self.rx_bytes.load(Relaxed) + } + + pub fn tx_packets(&self) -> u64 { + self.tx_packets.load(Relaxed) + } + + pub fn rx_packets(&self) -> u64 { + self.rx_packets.load(Relaxed) + } + + pub fn record_tx_bytes(&self, bytes: u64) { + self.tx_bytes.fetch_add(bytes, Relaxed); + self.tx_packets.fetch_add(1, Relaxed); + } + + pub fn record_rx_bytes(&self, bytes: u64) { + self.rx_bytes.fetch_add(bytes, Relaxed); + self.rx_packets.fetch_add(1, Relaxed); + } +} diff --git a/easytier/src/tunnel/tcp.rs b/easytier/src/tunnel/tcp.rs new file mode 100644 index 00000000..55f1224a --- /dev/null +++ b/easytier/src/tunnel/tcp.rs @@ -0,0 +1,200 @@ +use std::net::SocketAddr; + +use async_trait::async_trait; +use futures::stream::FuturesUnordered; +use tokio::net::{TcpListener, TcpSocket, TcpStream}; + +use crate::{rpc::TunnelInfo, tunnel::common::setup_sokcet2}; + +use super::{ + check_scheme_and_get_socket_addr, + common::{wait_for_connect_futures, FramedReader, FramedWriter, TunnelWrapper}, + Tunnel, TunnelError, TunnelListener, +}; + +const TCP_MTU_BYTES: usize = 64 * 1024; + +#[derive(Debug)] +pub struct TcpTunnelListener { + addr: url::Url, + listener: Option, +} + +impl TcpTunnelListener { + pub fn new(addr: url::Url) -> Self { + TcpTunnelListener { + addr, + listener: None, + } + } +} + +#[async_trait] +impl TunnelListener for TcpTunnelListener { + async fn listen(&mut self) -> Result<(), TunnelError> { + let addr = check_scheme_and_get_socket_addr::(&self.addr, "tcp")?; + + let socket = if addr.is_ipv4() { + TcpSocket::new_v4()? + } else { + TcpSocket::new_v6()? + }; + + socket.set_reuseaddr(true)?; + // #[cfg(all(unix, not(target_os = "solaris"), not(target_os = "illumos")))] + // socket.set_reuseport(true)?; + socket.bind(addr)?; + + self.listener = Some(socket.listen(1024)?); + Ok(()) + } + + async fn accept(&mut self) -> Result, super::TunnelError> { + let listener = self.listener.as_ref().unwrap(); + let (stream, _) = listener.accept().await?; + stream.set_nodelay(true).unwrap(); + let info = TunnelInfo { + tunnel_type: "tcp".to_owned(), + local_addr: self.local_url().into(), + remote_addr: super::build_url_from_socket_addr(&stream.peer_addr()?.to_string(), "tcp") + .into(), + }; + + let (r, w) = stream.into_split(); + Ok(Box::new(TunnelWrapper::new( + FramedReader::new(r, TCP_MTU_BYTES), + FramedWriter::new(w), + Some(info), + ))) + } + + fn local_url(&self) -> url::Url { + self.addr.clone() + } +} + +fn get_tunnel_with_tcp_stream( + stream: TcpStream, + remote_url: url::Url, +) -> Result, super::TunnelError> { + stream.set_nodelay(true).unwrap(); + + let info = TunnelInfo { + tunnel_type: "tcp".to_owned(), + local_addr: super::build_url_from_socket_addr(&stream.local_addr()?.to_string(), "tcp") + .into(), + remote_addr: remote_url.into(), + }; + + let (r, w) = stream.into_split(); + Ok(Box::new(TunnelWrapper::new( + FramedReader::new(r, TCP_MTU_BYTES), + FramedWriter::new(w), + Some(info), + ))) +} + +#[derive(Debug)] +pub struct TcpTunnelConnector { + addr: url::Url, + + bind_addrs: Vec, +} + +impl TcpTunnelConnector { + pub fn new(addr: url::Url) -> Self { + TcpTunnelConnector { + addr, + bind_addrs: vec![], + } + } + + async fn connect_with_default_bind(&mut self) -> Result, super::TunnelError> { + tracing::info!(addr = ?self.addr, "connect tcp start"); + let addr = check_scheme_and_get_socket_addr::(&self.addr, "tcp")?; + let stream = TcpStream::connect(addr).await?; + tracing::info!(addr = ?self.addr, "connect tcp succ"); + return get_tunnel_with_tcp_stream(stream, self.addr.clone().into()); + } + + async fn connect_with_custom_bind(&mut self) -> Result, super::TunnelError> { + let futures = FuturesUnordered::new(); + let dst_addr = check_scheme_and_get_socket_addr::(&self.addr, "tcp")?; + + for bind_addr in self.bind_addrs.iter() { + tracing::info!(bind_addr = ?bind_addr, ?dst_addr, "bind addr"); + + let socket2_socket = socket2::Socket::new( + socket2::Domain::for_address(dst_addr), + socket2::Type::STREAM, + Some(socket2::Protocol::TCP), + )?; + setup_sokcet2(&socket2_socket, bind_addr)?; + + let socket = TcpSocket::from_std_stream(socket2_socket.into()); + futures.push(socket.connect(dst_addr.clone())); + } + + let ret = wait_for_connect_futures(futures).await; + return get_tunnel_with_tcp_stream(ret?, self.addr.clone().into()); + } +} + +#[async_trait] +impl super::TunnelConnector for TcpTunnelConnector { + async fn connect(&mut self) -> Result, super::TunnelError> { + if self.bind_addrs.is_empty() { + self.connect_with_default_bind().await + } else { + self.connect_with_custom_bind().await + } + } + + fn remote_url(&self) -> url::Url { + self.addr.clone() + } + fn set_bind_addrs(&mut self, addrs: Vec) { + self.bind_addrs = addrs; + } +} + +#[cfg(test)] +mod tests { + use crate::tunnel::{ + common::tests::{_tunnel_bench, _tunnel_pingpong}, + TunnelConnector, + }; + + use super::*; + + #[tokio::test] + async fn tcp_pingpong() { + let listener = TcpTunnelListener::new("tcp://0.0.0.0:31011".parse().unwrap()); + let connector = TcpTunnelConnector::new("tcp://127.0.0.1:31011".parse().unwrap()); + _tunnel_pingpong(listener, connector).await + } + + #[tokio::test] + async fn tcp_bench() { + let listener = TcpTunnelListener::new("tcp://0.0.0.0:31012".parse().unwrap()); + let connector = TcpTunnelConnector::new("tcp://127.0.0.1:31012".parse().unwrap()); + _tunnel_bench(listener, connector).await + } + + #[tokio::test] + async fn tcp_bench_with_bind() { + let listener = TcpTunnelListener::new("tcp://127.0.0.1:11013".parse().unwrap()); + let mut connector = TcpTunnelConnector::new("tcp://127.0.0.1:11013".parse().unwrap()); + connector.set_bind_addrs(vec!["127.0.0.1:0".parse().unwrap()]); + _tunnel_pingpong(listener, connector).await + } + + #[tokio::test] + #[should_panic] + async fn tcp_bench_with_bind_fail() { + let listener = TcpTunnelListener::new("tcp://127.0.0.1:11014".parse().unwrap()); + let mut connector = TcpTunnelConnector::new("tcp://127.0.0.1:11014".parse().unwrap()); + connector.set_bind_addrs(vec!["10.0.0.1:0".parse().unwrap()]); + _tunnel_pingpong(listener, connector).await + } +} diff --git a/easytier/src/tunnel/udp.rs b/easytier/src/tunnel/udp.rs new file mode 100644 index 00000000..81ceab4b --- /dev/null +++ b/easytier/src/tunnel/udp.rs @@ -0,0 +1,838 @@ +use std::{fmt::Debug, sync::Arc}; + +use async_trait::async_trait; +use bytes::BytesMut; +use dashmap::DashMap; +use futures::{stream::FuturesUnordered, StreamExt}; +use rand::{Rng, SeedableRng}; + +use std::net::SocketAddr; +use tokio::{ + net::UdpSocket, + sync::mpsc::{Receiver, Sender, UnboundedReceiver, UnboundedSender}, + task::{JoinHandle, JoinSet}, +}; + +use tracing::{instrument, Instrument}; + +use crate::{ + common::join_joinset_background, + rpc::TunnelInfo, + tunnel::{ + common::{reserve_buf, TunnelWrapper}, + packet_def::{UdpPacketType, ZCPacket, ZCPacketType}, + ring::RingTunnel, + }, +}; + +use super::{ + common::{setup_sokcet2, setup_sokcet2_ext, wait_for_connect_futures}, + packet_def::{UDPTunnelHeader, UDP_TUNNEL_HEADER_SIZE}, + ring::{RingSink, RingStream}, + Tunnel, TunnelConnCounter, TunnelError, TunnelListener, TunnelUrl, +}; + +pub const UDP_DATA_MTU: usize = 65000; + +type UdpCloseEventSender = UnboundedSender>; +type UdpCloseEventReceiver = UnboundedReceiver>; + +fn new_udp_packet(f: F, udp_body: Option<&mut [u8]>) -> ZCPacket +where + F: FnOnce(&mut UDPTunnelHeader), +{ + let mut buf = BytesMut::new(); + buf.resize( + UDP_TUNNEL_HEADER_SIZE + udp_body.as_ref().map(|v| v.len()).unwrap_or(0), + 0, + ); + buf[UDP_TUNNEL_HEADER_SIZE..].copy_from_slice(udp_body.unwrap()); + + let mut ret = ZCPacket::new_from_buf(buf, ZCPacketType::UDP); + let header = ret.mut_udp_tunnel_header().unwrap(); + f(header); + ret +} + +fn new_syn_packet(conn_id: u32, magic: u64) -> ZCPacket { + new_udp_packet( + |header| { + header.msg_type = UdpPacketType::Syn as u8; + header.conn_id.set(conn_id); + header.len.set(8); + }, + Some(&mut magic.to_le_bytes()), + ) +} + +fn new_sack_packet(conn_id: u32, magic: u64) -> ZCPacket { + new_udp_packet( + |header| { + header.msg_type = UdpPacketType::Sack as u8; + header.conn_id.set(conn_id); + header.len.set(8); + }, + Some(&mut magic.to_le_bytes()), + ) +} + +pub fn new_hole_punch_packet() -> ZCPacket { + // generate a 128 bytes vec with random data + let mut rng = rand::rngs::StdRng::from_entropy(); + let mut buf = vec![0u8; 128]; + rng.fill(&mut buf[..]); + new_udp_packet( + |header| { + header.msg_type = UdpPacketType::HolePunch as u8; + header.conn_id.set(0); + header.len.set(0); + }, + Some(&mut buf), + ) +} + +fn get_zcpacket_from_buf(buf: BytesMut) -> Result { + let dg_size = buf.len(); + if dg_size < UDP_TUNNEL_HEADER_SIZE { + return Err(TunnelError::InvalidPacket(format!( + "udp packet size too small: {:?}, packet: {:?}", + dg_size, buf + ))); + } + + let zc_packet = ZCPacket::new_from_buf(buf, ZCPacketType::UDP); + let header = zc_packet.udp_tunnel_header().unwrap(); + let payload_len = header.len.get() as usize; + if payload_len != dg_size - UDP_TUNNEL_HEADER_SIZE { + return Err(TunnelError::InvalidPacket(format!( + "udp packet payload len not match: header len: {:?}, real len: {:?}", + payload_len, dg_size + ))); + } + + Ok(zc_packet) +} + +#[instrument] +async fn forward_from_ring_to_udp( + mut ring_recv: RingStream, + socket: &Arc, + addr: &SocketAddr, + conn_id: u32, +) -> Option { + tracing::debug!("udp forward from ring to udp"); + loop { + let Some(buf) = ring_recv.next().await else { + return None; + }; + let mut packet = match buf { + Ok(v) => v, + Err(e) => { + return Some(e); + } + }; + + let udp_payload_len = packet.udp_payload().len(); + let header = packet.mut_udp_tunnel_header().unwrap(); + header.conn_id.set(conn_id); + header.len.set(udp_payload_len as u16); + header.msg_type = UdpPacketType::Data as u8; + + let buf = packet.into_bytes(ZCPacketType::UDP); + tracing::trace!(?udp_payload_len, ?buf, "udp forward from ring to udp"); + let ret = socket.send_to(&buf, &addr).await; + if ret.is_err() { + return Some(TunnelError::IOError(ret.unwrap_err())); + } else if ret.unwrap() == 0 { + return None; + } + } +} + +struct UdpConnection { + socket: Arc, + conn_id: u32, + dst_addr: SocketAddr, + + ring_sender: RingSink, + forward_task: JoinHandle<()>, +} + +impl UdpConnection { + pub fn new( + socket: Arc, + conn_id: u32, + dst_addr: SocketAddr, + ring_sender: RingSink, + ring_recv: RingStream, + close_event_sender: UdpCloseEventSender, + ) -> Self { + let s = socket.clone(); + let forward_task = tokio::spawn(async move { + let close_event_sender = close_event_sender; + let err = forward_from_ring_to_udp(ring_recv, &s, &dst_addr, conn_id).await; + if let Err(e) = close_event_sender.send(err) { + tracing::error!(?e, "udp send close event error"); + } + }); + + Self { + socket, + conn_id, + dst_addr, + ring_sender, + forward_task, + } + } +} + +impl Drop for UdpConnection { + fn drop(&mut self) { + self.forward_task.abort(); + } +} + +#[derive(Clone)] +struct UdpTunnelListenerData { + local_url: url::Url, + socket: Option>, + sock_map: Arc>, + conn_send: Sender>, + close_event_sender: UdpCloseEventSender, +} + +impl UdpTunnelListenerData { + pub fn new( + local_url: url::Url, + conn_send: Sender>, + close_event_sender: UdpCloseEventSender, + ) -> Self { + Self { + local_url, + socket: None, + sock_map: Arc::new(DashMap::new()), + conn_send, + close_event_sender, + } + } + + async fn handle_new_connect(self: Self, remote_addr: SocketAddr, zc_packet: ZCPacket) { + let udp_payload = zc_packet.udp_payload(); + if udp_payload.len() != 8 { + tracing::warn!( + "udp syn packet payload len not match: {:?}, packet: {:?}", + udp_payload.len(), + zc_packet, + ); + return; + } + let magic = u64::from_le_bytes(udp_payload[..8].try_into().unwrap()); + let conn_id = zc_packet.udp_tunnel_header().unwrap().conn_id.get(); + + tracing::info!(?conn_id, ?remote_addr, "udp connection accept handling",); + let socket = self.socket.as_ref().unwrap().clone(); + + let sack_buf = new_sack_packet(conn_id, magic).into_bytes(ZCPacketType::UDP); + if let Err(e) = socket.send_to(&sack_buf, remote_addr).await { + tracing::error!(?e, "udp send sack packet error"); + return; + } + + let ring_for_send_udp = Arc::new(RingTunnel::new(128)); + let ring_for_recv_udp = Arc::new(RingTunnel::new(128)); + tracing::debug!( + ?ring_for_send_udp, + ?ring_for_recv_udp, + "udp build tunnel for listener" + ); + + let internal_conn = UdpConnection::new( + socket.clone(), + conn_id, + remote_addr, + RingSink::new(ring_for_recv_udp.clone()), + RingStream::new(ring_for_send_udp.clone()), + self.close_event_sender.clone(), + ); + self.sock_map.insert(remote_addr, internal_conn); + + let conn = Box::new(TunnelWrapper::new( + Box::new(RingStream::new(ring_for_recv_udp)), + Box::new(RingSink::new(ring_for_send_udp)), + Some(TunnelInfo { + tunnel_type: "udp".to_owned(), + local_addr: self.local_url.clone().into(), + remote_addr: url::Url::parse(&format!("udp://{}", remote_addr)) + .unwrap() + .into(), + }), + )); + + if let Err(e) = self.conn_send.send(conn).await { + tracing::warn!(?e, "udp send conn to accept channel error"); + } + } + + async fn try_forward_packet( + self: &Self, + remote_addr: &SocketAddr, + conn_id: u32, + p: ZCPacket, + ) -> Result<(), TunnelError> { + let Some(conn) = self.sock_map.get(remote_addr) else { + return Err(TunnelError::InternalError( + "udp connection not found".to_owned(), + )); + }; + + if conn.conn_id != conn_id { + return Err(TunnelError::ConnIdNotMatch(conn.conn_id, conn_id)); + } + + if !conn.ring_sender.has_empty_slot() { + return Err(TunnelError::BufferFull); + } + + conn.ring_sender.push_no_check(p)?; + + Ok(()) + } + + async fn process_forward_packet(&self, zc_packet: ZCPacket, addr: &SocketAddr) { + let header = zc_packet.udp_tunnel_header().unwrap(); + if header.msg_type == UdpPacketType::Syn as u8 { + tokio::spawn(Self::handle_new_connect(self.clone(), *addr, zc_packet)); + } else { + if let Err(e) = self + .try_forward_packet(addr, header.conn_id.get(), zc_packet) + .await + { + tracing::trace!(?e, "udp forward packet error"); + } + } + } + + async fn do_forward_task(self: Self) { + let socket = self.socket.as_ref().unwrap().clone(); + let mut buf = BytesMut::new(); + loop { + reserve_buf(&mut buf, UDP_DATA_MTU, UDP_DATA_MTU * 128); + let (dg_size, addr) = socket.recv_buf_from(&mut buf).await.unwrap(); + tracing::trace!( + "udp recv packet: {:?}, buf: {:?}, size: {}", + addr, + buf, + dg_size + ); + + let zc_packet = match get_zcpacket_from_buf(buf.split()) { + Ok(v) => v, + Err(e) => { + tracing::warn!(?e, "udp get zc packet from buf error"); + continue; + } + }; + self.process_forward_packet(zc_packet, &addr).await; + } + } +} + +pub struct UdpTunnelListener { + addr: url::Url, + socket: Option>, + + conn_recv: Receiver>, + data: UdpTunnelListenerData, + forward_tasks: Arc>>, + close_event_recv: UdpCloseEventReceiver, +} + +impl UdpTunnelListener { + pub fn new(addr: url::Url) -> Self { + let (close_event_send, close_event_recv) = tokio::sync::mpsc::unbounded_channel(); + let (conn_send, conn_recv) = tokio::sync::mpsc::channel(100); + Self { + addr: addr.clone(), + socket: None, + conn_recv, + data: UdpTunnelListenerData::new(addr, conn_send, close_event_send), + forward_tasks: Arc::new(std::sync::Mutex::new(JoinSet::new())), + close_event_recv, + } + } + + pub fn get_socket(&self) -> Option> { + self.socket.clone() + } +} + +#[async_trait] +impl TunnelListener for UdpTunnelListener { + async fn listen(&mut self) -> Result<(), super::TunnelError> { + let addr = super::check_scheme_and_get_socket_addr::(&self.addr, "udp")?; + + let socket2_socket = socket2::Socket::new( + socket2::Domain::for_address(addr), + socket2::Type::DGRAM, + Some(socket2::Protocol::UDP), + )?; + + let tunnel_url: TunnelUrl = self.addr.clone().into(); + if let Some(bind_dev) = tunnel_url.bind_dev() { + setup_sokcet2_ext(&socket2_socket, &addr, Some(bind_dev))?; + } else { + setup_sokcet2(&socket2_socket, &addr)?; + } + + self.socket = Some(Arc::new(UdpSocket::from_std(socket2_socket.into())?)); + self.data.socket = self.socket.clone(); + + self.forward_tasks + .lock() + .unwrap() + .spawn(self.data.clone().do_forward_task()); + + join_joinset_background(self.forward_tasks.clone(), "UdpTunnelListener".to_owned()); + + Ok(()) + } + + async fn accept(&mut self) -> Result, super::TunnelError> { + log::info!("start udp accept: {:?}", self.addr); + while let Some(conn) = self.conn_recv.recv().await { + return Ok(conn); + } + return Err(super::TunnelError::InternalError( + "udp accept error".to_owned(), + )); + } + + fn local_url(&self) -> url::Url { + self.addr.clone() + } + + fn get_conn_counter(&self) -> Arc> { + struct UdpTunnelConnCounter { + sock_map: Arc>, + } + + impl Debug for UdpTunnelConnCounter { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("UdpTunnelConnCounter") + .field("sock_map_len", &self.sock_map.len()) + .finish() + } + } + + impl TunnelConnCounter for UdpTunnelConnCounter { + fn get(&self) -> u32 { + self.sock_map.len() as u32 + } + } + + Arc::new(Box::new(UdpTunnelConnCounter { + sock_map: self.data.sock_map.clone(), + })) + } +} + +pub struct UdpTunnelConnector { + addr: url::Url, + bind_addrs: Vec, +} + +impl UdpTunnelConnector { + pub fn new(addr: url::Url) -> Self { + Self { + addr, + bind_addrs: vec![], + } + } + + async fn wait_sack( + socket: &UdpSocket, + addr: SocketAddr, + conn_id: u32, + magic: u64, + ) -> Result { + let mut buf = BytesMut::new(); + buf.reserve(UDP_DATA_MTU); + + let (usize, recv_addr) = tokio::time::timeout( + tokio::time::Duration::from_secs(3), + socket.recv_buf_from(&mut buf), + ) + .await??; + let zc_packet = get_zcpacket_from_buf(buf.split())?; + if recv_addr != addr { + tracing::warn!(?recv_addr, ?addr, ?usize, "udp wait sack addr not match"); + } + + let header = zc_packet.udp_tunnel_header().unwrap(); + + if header.conn_id.get() != conn_id { + return Err(super::TunnelError::ConnIdNotMatch( + header.conn_id.get(), + conn_id, + )); + } + + if header.msg_type != UdpPacketType::Sack as u8 { + return Err(TunnelError::InvalidPacket("not sack packet".to_owned())); + } + + let payload = zc_packet.udp_payload(); + if payload.len() != 8 { + return Err(TunnelError::InvalidPacket( + "udp sack packet payload len not match".to_owned(), + )); + } + + let sack_magic = u64::from_le_bytes(payload[..8].try_into().unwrap()); + if sack_magic != magic { + return Err(TunnelError::InvalidPacket( + "udp sack magic not match".to_owned(), + )); + } + + Ok(recv_addr) + } + + async fn wait_sack_loop( + socket: &UdpSocket, + addr: SocketAddr, + conn_id: u32, + magic: u64, + ) -> Result { + loop { + let ret = Self::wait_sack(socket, addr, conn_id, magic).await; + if ret.is_err() { + tracing::debug!(?ret, "udp wait sack error"); + continue; + } else { + return ret; + } + } + } + + async fn build_tunnel( + &self, + socket: UdpSocket, + dst_addr: SocketAddr, + conn_id: u32, + ) -> Result, super::TunnelError> { + let socket = Arc::new(socket); + let ring_for_send_udp = Arc::new(RingTunnel::new(128)); + let ring_for_recv_udp = Arc::new(RingTunnel::new(128)); + tracing::debug!( + ?ring_for_send_udp, + ?ring_for_recv_udp, + "udp build tunnel for connector" + ); + + let (close_event_send, mut close_event_recv) = tokio::sync::mpsc::unbounded_channel(); + + // forward from ring to udp + let socket_sender = socket.clone(); + let ring_recv = RingStream::new(ring_for_send_udp.clone()); + tokio::spawn(async move { + let err = forward_from_ring_to_udp(ring_recv, &socket_sender, &dst_addr, conn_id).await; + tracing::debug!(?err, "udp forward from ring to udp done"); + close_event_send.send(err).unwrap(); + }); + + let socket_recv = socket.clone(); + let ring_sender = RingSink::new(ring_for_recv_udp.clone()); + tokio::spawn(async move { + let mut buf = BytesMut::new(); + loop { + reserve_buf(&mut buf, UDP_DATA_MTU, UDP_DATA_MTU * 128); + let ret; + tokio::select! { + _ = close_event_recv.recv() => { + tracing::debug!("connector udp close event"); + break; + } + recv_res = socket_recv.recv_buf_from(&mut buf) => ret = Some(recv_res.unwrap()), + } + let (dg_size, addr) = ret.unwrap(); + tracing::trace!( + "connector udp recv packet: {:?}, buf: {:?}, size: {}", + addr, + buf, + dg_size + ); + + let zc_packet = match get_zcpacket_from_buf(buf.split()) { + Ok(v) => v, + Err(e) => { + tracing::warn!(?e, "connector udp get zc packet from buf error"); + continue; + } + }; + let header = zc_packet.udp_tunnel_header().unwrap(); + if header.conn_id.get() != conn_id { + tracing::trace!( + "connector udp conn id not match: {:?}, {:?}", + header.conn_id.get(), + conn_id + ); + } + + if header.msg_type == UdpPacketType::Data as u8 { + if let Err(e) = ring_sender.push_no_check(zc_packet) { + tracing::trace!(?e, "udp forward packet error"); + } + } + } + }.instrument(tracing::info_span!("udp connector forward from udp to ring", ?ring_for_recv_udp))); + + Ok(Box::new(TunnelWrapper::new( + Box::new(RingStream::new(ring_for_recv_udp)), + Box::new(RingSink::new(ring_for_send_udp)), + Some(TunnelInfo { + tunnel_type: "udp".to_owned(), + local_addr: url::Url::parse(&format!("udp://{}", socket.local_addr()?)) + .unwrap() + .into(), + remote_addr: self.addr.clone().into(), + }), + ))) + } + + pub async fn try_connect_with_socket( + &self, + socket: UdpSocket, + ) -> Result, super::TunnelError> { + let addr = super::check_scheme_and_get_socket_addr::(&self.addr, "udp")?; + log::warn!("udp connect: {:?}", self.addr); + + #[cfg(target_os = "windows")] + crate::arch::windows::disable_connection_reset(&socket)?; + + // send syn + let conn_id = rand::random(); + let magic = rand::random(); + let udp_packet = new_syn_packet(conn_id, magic).into_bytes(ZCPacketType::UDP); + let ret = socket.send_to(&udp_packet, &addr).await?; + tracing::warn!(?udp_packet, ?ret, "udp send syn"); + + // wait sack + let recv_addr = tokio::time::timeout( + tokio::time::Duration::from_secs(3), + Self::wait_sack_loop(&socket, addr, conn_id, magic), + ) + .await??; + + socket.connect(recv_addr).await?; + self.build_tunnel(socket, addr, conn_id).await + } + + async fn connect_with_default_bind(&mut self) -> Result, super::TunnelError> { + let socket = UdpSocket::bind("0.0.0.0:0").await?; + return self.try_connect_with_socket(socket).await; + } + + async fn connect_with_custom_bind(&mut self) -> Result, super::TunnelError> { + let futures = FuturesUnordered::new(); + + for bind_addr in self.bind_addrs.iter() { + let socket2_socket = socket2::Socket::new( + socket2::Domain::for_address(*bind_addr), + socket2::Type::DGRAM, + Some(socket2::Protocol::UDP), + )?; + setup_sokcet2(&socket2_socket, &bind_addr)?; + let socket = UdpSocket::from_std(socket2_socket.into())?; + futures.push(self.try_connect_with_socket(socket)); + } + wait_for_connect_futures(futures).await + } +} + +#[async_trait] +impl super::TunnelConnector for UdpTunnelConnector { + async fn connect(&mut self) -> Result, super::TunnelError> { + if self.bind_addrs.is_empty() { + self.connect_with_default_bind().await + } else { + self.connect_with_custom_bind().await + } + } + + fn remote_url(&self) -> url::Url { + self.addr.clone() + } + + fn set_bind_addrs(&mut self, addrs: Vec) { + self.bind_addrs = addrs; + } +} + +#[cfg(test)] +mod tests { + use std::time::Duration; + + use futures::SinkExt; + use tokio::time::timeout; + + use super::*; + use crate::{ + common::global_ctx::tests::get_mock_global_ctx, + tunnel::{ + check_scheme_and_get_socket_addr, + common::{ + get_interface_name_by_ip, + tests::{_tunnel_bench, _tunnel_echo_server, _tunnel_pingpong}, + }, + TunnelConnector, + }, + }; + + #[tokio::test] + async fn udp_pingpong() { + let listener = UdpTunnelListener::new("udp://0.0.0.0:5556".parse().unwrap()); + let connector = UdpTunnelConnector::new("udp://127.0.0.1:5556".parse().unwrap()); + _tunnel_pingpong(listener, connector).await + } + + #[tokio::test] + async fn udp_bench() { + let listener = UdpTunnelListener::new("udp://0.0.0.0:5555".parse().unwrap()); + let connector = UdpTunnelConnector::new("udp://127.0.0.1:5555".parse().unwrap()); + _tunnel_bench(listener, connector).await + } + + #[tokio::test] + async fn udp_bench_with_bind() { + let listener = UdpTunnelListener::new("udp://127.0.0.1:5554".parse().unwrap()); + let mut connector = UdpTunnelConnector::new("udp://127.0.0.1:5554".parse().unwrap()); + connector.set_bind_addrs(vec!["127.0.0.1:0".parse().unwrap()]); + _tunnel_pingpong(listener, connector).await + } + + #[tokio::test] + #[should_panic] + async fn udp_bench_with_bind_fail() { + let listener = UdpTunnelListener::new("udp://127.0.0.1:5553".parse().unwrap()); + let mut connector = UdpTunnelConnector::new("udp://127.0.0.1:5553".parse().unwrap()); + connector.set_bind_addrs(vec!["10.0.0.1:0".parse().unwrap()]); + _tunnel_pingpong(listener, connector).await + } + + async fn send_random_data_to_socket(remote_url: url::Url) { + let socket = UdpSocket::bind("0.0.0.0:0").await.unwrap(); + socket + .connect(format!( + "{}:{}", + remote_url.host().unwrap(), + remote_url.port().unwrap() + )) + .await + .unwrap(); + + // get a random 100-len buf + loop { + let mut buf = vec![0u8; 100]; + rand::thread_rng().fill(&mut buf[..]); + socket.send(&buf).await.unwrap(); + tokio::time::sleep(tokio::time::Duration::from_millis(50)).await; + } + } + + #[tokio::test] + async fn udp_multiple_conns() { + let mut listener = UdpTunnelListener::new("udp://0.0.0.0:5557".parse().unwrap()); + listener.listen().await.unwrap(); + + let _lis = tokio::spawn(async move { + loop { + let ret = listener.accept().await.unwrap(); + assert_eq!( + ret.info().unwrap().local_addr, + listener.local_url().to_string() + ); + tokio::spawn(async move { _tunnel_echo_server(ret, false).await }); + } + }); + + let mut connector1 = UdpTunnelConnector::new("udp://127.0.0.1:5557".parse().unwrap()); + let mut connector2 = UdpTunnelConnector::new("udp://127.0.0.1:5557".parse().unwrap()); + + let t1 = connector1.connect().await.unwrap(); + let t2 = connector2.connect().await.unwrap(); + + tokio::spawn(timeout( + Duration::from_secs(2), + send_random_data_to_socket(t1.info().unwrap().local_addr.parse().unwrap()), + )); + tokio::spawn(timeout( + Duration::from_secs(2), + send_random_data_to_socket(t1.info().unwrap().remote_addr.parse().unwrap()), + )); + tokio::spawn(timeout( + Duration::from_secs(2), + send_random_data_to_socket(t2.info().unwrap().remote_addr.parse().unwrap()), + )); + + let sender1 = tokio::spawn(async move { + let (mut stream, mut sink) = t1.split(); + + for i in 0..10 { + sink.send(ZCPacket::new_with_payload("hello1".as_bytes())) + .await + .unwrap(); + let recv = stream.next().await.unwrap().unwrap(); + println!("t1 recv: {:?}, {:?}", recv, i); + assert_eq!(recv.payload(), "hello1".as_bytes()); + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + } + }); + + let sender2 = tokio::spawn(async move { + let (mut stream, mut sink) = t2.split(); + + for i in 0..10 { + sink.send(ZCPacket::new_with_payload("hello2".as_bytes())) + .await + .unwrap(); + let recv = stream.next().await.unwrap().unwrap(); + println!("t2 recv: {:?}, {:?}", recv, i); + assert_eq!(recv.payload(), "hello2".as_bytes()); + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + } + }); + + let _ = tokio::join!(sender1, sender2); + } + + #[tokio::test] + async fn bind_multi_ip_to_same_dev() { + let global_ctx = get_mock_global_ctx(); + let ips = global_ctx + .get_ip_collector() + .collect_ip_addrs() + .await + .interface_ipv4s; + if ips.is_empty() { + return; + } + let bind_dev = get_interface_name_by_ip(&ips[0].parse().unwrap()); + + for ip in ips { + println!("bind to ip: {:?}, {:?}", ip, bind_dev); + let addr = check_scheme_and_get_socket_addr::( + &format!("udp://{}:11111", ip).parse().unwrap(), + "udp", + ) + .unwrap(); + let socket2_socket = socket2::Socket::new( + socket2::Domain::for_address(addr), + socket2::Type::DGRAM, + Some(socket2::Protocol::UDP), + ) + .unwrap(); + setup_sokcet2_ext(&socket2_socket, &addr, bind_dev.clone()).unwrap(); + } + } +} diff --git a/easytier/src/tunnel/wireguard.rs b/easytier/src/tunnel/wireguard.rs new file mode 100644 index 00000000..02091ae4 --- /dev/null +++ b/easytier/src/tunnel/wireguard.rs @@ -0,0 +1,827 @@ +use std::{ + collections::hash_map::DefaultHasher, + fmt::{Debug, Formatter}, + hash::Hasher, + net::SocketAddr, + pin::Pin, + sync::{atomic::AtomicBool, Arc}, + time::Duration, +}; + +use anyhow::Context; +use async_recursion::async_recursion; +use async_trait::async_trait; +use boringtun::{ + noise::{errors::WireGuardError, Tunn, TunnResult}, + x25519::{PublicKey, StaticSecret}, +}; +use bytes::BytesMut; +use dashmap::DashMap; +use futures::{stream::FuturesUnordered, SinkExt, StreamExt}; +use rand::RngCore; +use tokio::{net::UdpSocket, sync::Mutex, task::JoinSet}; + +use crate::{ + rpc::TunnelInfo, + tunnel::{ + build_url_from_socket_addr, + common::TunnelWrapper, + packet_def::{ZCPacket, WG_TUNNEL_HEADER_SIZE}, + }, +}; + +use super::{ + check_scheme_and_get_socket_addr, + common::{setup_sokcet2, setup_sokcet2_ext, wait_for_connect_futures}, + packet_def::{ZCPacketType, PEER_MANAGER_HEADER_SIZE}, + ring::create_ring_tunnel_pair, + Tunnel, TunnelError, TunnelListener, TunnelUrl, ZCPacketSink, ZCPacketStream, +}; + +const MAX_PACKET: usize = 65500; + +#[derive(Debug, Clone)] +enum WgType { + // used by easytier peer, need remove/add ip header for in/out wg msg + InternalUse, + // used by wireguard peer, keep original ip header + ExternalUse, +} + +#[derive(Clone)] +pub struct WgConfig { + my_secret_key: StaticSecret, + my_public_key: PublicKey, + + peer_secret_key: StaticSecret, + peer_public_key: PublicKey, + + wg_type: WgType, +} + +impl WgConfig { + pub fn new_from_network_identity(network_name: &str, network_secret: &str) -> Self { + let mut my_sec = [0u8; 32]; + let mut hasher = DefaultHasher::new(); + hasher.write(network_name.as_bytes()); + hasher.write(network_secret.as_bytes()); + my_sec[0..8].copy_from_slice(&hasher.finish().to_be_bytes()); + hasher.write(&my_sec[0..8]); + my_sec[8..16].copy_from_slice(&hasher.finish().to_be_bytes()); + hasher.write(&my_sec[0..16]); + my_sec[16..24].copy_from_slice(&hasher.finish().to_be_bytes()); + hasher.write(&my_sec[0..24]); + my_sec[24..32].copy_from_slice(&hasher.finish().to_be_bytes()); + + let my_secret_key = StaticSecret::from(my_sec); + let my_public_key = PublicKey::from(&my_secret_key); + let peer_secret_key = StaticSecret::from(my_sec); + let peer_public_key = my_public_key.clone(); + + WgConfig { + my_secret_key, + my_public_key, + peer_secret_key, + peer_public_key, + + wg_type: WgType::InternalUse, + } + } + + pub fn new_for_portal(server_key_seed: &str, client_key_seed: &str) -> Self { + let server_cfg = Self::new_from_network_identity("server", server_key_seed); + let client_cfg = Self::new_from_network_identity("client", client_key_seed); + Self { + my_secret_key: server_cfg.my_secret_key, + my_public_key: server_cfg.my_public_key, + peer_secret_key: client_cfg.my_secret_key, + peer_public_key: client_cfg.my_public_key, + + wg_type: WgType::ExternalUse, + } + } + + pub fn my_secret_key(&self) -> &[u8] { + self.my_secret_key.as_bytes() + } + + pub fn peer_secret_key(&self) -> &[u8] { + self.peer_secret_key.as_bytes() + } + + pub fn my_public_key(&self) -> &[u8] { + self.my_public_key.as_bytes() + } + + pub fn peer_public_key(&self) -> &[u8] { + self.peer_public_key.as_bytes() + } +} + +#[derive(Clone)] +struct WgPeerData { + udp: Arc, // only for send + endpoint: SocketAddr, + tunn: Arc>, + wg_type: WgType, + stopped: Arc, +} + +impl Debug for WgPeerData { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.debug_struct("WgPeerData") + .field("endpoint", &self.endpoint) + .field("local", &self.udp.local_addr()) + .finish() + } +} + +impl WgPeerData { + #[tracing::instrument] + async fn handle_one_packet_from_me( + &self, + mut zc_packet: ZCPacket, + ) -> Result<(), anyhow::Error> { + let mut send_buf = vec![0u8; MAX_PACKET]; + + let packet = if matches!(self.wg_type, WgType::InternalUse) { + Self::fill_ip_header(&mut zc_packet); + zc_packet.into_bytes(ZCPacketType::WG) + } else { + zc_packet.into_bytes(ZCPacketType::WG) + }; + tracing::trace!(?packet, "Sending packet to peer"); + + let encapsulate_result = { + let mut peer = self.tunn.lock().await; + peer.encapsulate(&packet, &mut send_buf) + }; + + tracing::trace!( + ?encapsulate_result, + "Received {} bytes from me", + packet.len() + ); + + match encapsulate_result { + TunnResult::WriteToNetwork(packet) => { + self.udp + .send_to(packet, self.endpoint) + .await + .context("Failed to send encrypted IP packet to WireGuard endpoint.")?; + tracing::debug!( + "Sent {} bytes to WireGuard endpoint (encrypted IP packet)", + packet.len() + ); + } + TunnResult::Err(e) => { + tracing::error!("Failed to encapsulate IP packet: {:?}", e); + } + TunnResult::Done => { + // Ignored + } + other => { + tracing::error!( + "Unexpected WireGuard state during encapsulation: {:?}", + other + ); + } + }; + Ok(()) + } + + /// WireGuard consumption task. Receives encrypted packets from the WireGuard endpoint, + /// decapsulates them, and dispatches newly received IP packets. + #[tracing::instrument(skip(sink))] + pub async fn handle_one_packet_from_peer( + &self, + mut sink: S, + recv_buf: &[u8], + ) { + let mut send_buf = vec![0u8; MAX_PACKET]; + let data = &recv_buf[..]; + let decapsulate_result = { + let mut peer = self.tunn.lock().await; + peer.decapsulate(None, data, &mut send_buf) + }; + + tracing::debug!("Decapsulation result: {:?}", decapsulate_result); + + match decapsulate_result { + TunnResult::WriteToNetwork(packet) => { + match self.udp.send_to(packet, self.endpoint).await { + Ok(_) => {} + Err(e) => { + tracing::error!("Failed to send decapsulation-instructed packet to WireGuard endpoint: {:?}", e); + return; + } + }; + let mut peer = self.tunn.lock().await; + loop { + let mut send_buf = vec![0u8; MAX_PACKET]; + match peer.decapsulate(None, &[], &mut send_buf) { + TunnResult::WriteToNetwork(packet) => { + match self.udp.send_to(packet, self.endpoint).await { + Ok(_) => {} + Err(e) => { + tracing::error!("Failed to send decapsulation-instructed packet to WireGuard endpoint: {:?}", e); + break; + } + }; + } + _ => { + break; + } + } + } + } + TunnResult::WriteToTunnelV4(packet, _) | TunnResult::WriteToTunnelV6(packet, _) => { + tracing::debug!( + ?packet, + "receive IP packet from peer: {} bytes", + packet.len() + ); + let mut b = BytesMut::new(); + if matches!(self.wg_type, WgType::InternalUse) { + b.resize(WG_TUNNEL_HEADER_SIZE, 0); + b.extend_from_slice(self.remove_ip_header(packet, packet[0] >> 4 == 4)); + } else { + b.extend_from_slice(packet); + }; + let zc_packet = ZCPacket::new_from_buf(b, ZCPacketType::WG); + tracing::trace!(?zc_packet, "forward zc_packet to sink"); + let ret = sink.send(zc_packet).await; + if ret.is_err() { + tracing::error!("Failed to send packet to tunnel: {:?}", ret); + } + } + _ => { + tracing::warn!( + "Unexpected WireGuard state during decapsulation: {:?}", + decapsulate_result + ); + } + } + } + + #[tracing::instrument] + #[async_recursion] + async fn handle_routine_tun_result<'a: 'async_recursion>(&self, result: TunnResult<'a>) -> () { + match result { + TunnResult::WriteToNetwork(packet) => { + tracing::debug!( + "Sending routine packet of {} bytes to WireGuard endpoint", + packet.len() + ); + match self.udp.send_to(packet, self.endpoint).await { + Ok(_) => {} + Err(e) => { + tracing::error!( + "Failed to send routine packet to WireGuard endpoint: {:?}", + e + ); + } + }; + } + TunnResult::Err(WireGuardError::ConnectionExpired) => { + tracing::warn!("Wireguard handshake has expired!"); + + let mut buf = vec![0u8; MAX_PACKET]; + let result = self + .tunn + .lock() + .await + .format_handshake_initiation(&mut buf[..], false); + + self.handle_routine_tun_result(result).await + } + TunnResult::Err(e) => { + tracing::error!( + "Failed to prepare routine packet for WireGuard endpoint: {:?}", + e + ); + } + TunnResult::Done => { + // Sleep for a bit + tokio::time::sleep(Duration::from_millis(250)).await; + } + other => { + tracing::warn!("Unexpected WireGuard routine task state: {:?}", other); + tokio::time::sleep(Duration::from_millis(250)).await; + } + }; + } + + /// WireGuard Routine task. Handles Handshake, keep-alive, etc. + pub async fn routine_task(self) { + loop { + let mut send_buf = vec![0u8; MAX_PACKET]; + let tun_result = { self.tunn.lock().await.update_timers(&mut send_buf) }; + self.handle_routine_tun_result(tun_result).await; + } + } + + fn fill_ip_header(zc_packet: &mut ZCPacket) { + let len = zc_packet.payload_len() + PEER_MANAGER_HEADER_SIZE; + let ip_header = &mut zc_packet.mut_wg_tunnel_header().unwrap().ipv4_header; + ip_header[0] = 0x45; + ip_header[1] = 0; + ip_header[2..4].copy_from_slice(&((len + 20) as u16).to_be_bytes()); + ip_header[4..6].copy_from_slice(&0u16.to_be_bytes()); + ip_header[6..8].copy_from_slice(&0u16.to_be_bytes()); + ip_header[8] = 64; + ip_header[9] = 0; + ip_header[10..12].copy_from_slice(&0u16.to_be_bytes()); + ip_header[12..16].copy_from_slice(&0u32.to_be_bytes()); + ip_header[16..20].copy_from_slice(&0u32.to_be_bytes()); + } + + fn remove_ip_header<'a>(&self, packet: &'a [u8], is_v4: bool) -> &'a [u8] { + if is_v4 { + return &packet[20..]; + } else { + return &packet[40..]; + } + } +} + +struct WgPeer { + udp: Arc, // only for send + config: WgConfig, + endpoint: SocketAddr, + + sink: std::sync::Mutex>>>, + + data: Option, + tasks: JoinSet<()>, + + access_time: std::time::Instant, +} + +impl WgPeer { + fn new(udp: Arc, config: WgConfig, endpoint: SocketAddr) -> Self { + WgPeer { + udp, + config, + endpoint, + + sink: std::sync::Mutex::new(None), + + data: None, + tasks: JoinSet::new(), + + access_time: std::time::Instant::now(), + } + } + + async fn handle_packet_from_me(mut stream: S, data: WgPeerData) { + while let Some(Ok(packet)) = stream.next().await { + let ret = data.handle_one_packet_from_me(packet).await; + if let Err(e) = ret { + tracing::error!("Failed to handle packet from me: {}", e); + } + } + data.stopped + .store(true, std::sync::atomic::Ordering::Relaxed); + } + + async fn handle_packet_from_peer(&mut self, packet: &[u8]) { + self.access_time = std::time::Instant::now(); + tracing::trace!("Received {} bytes from peer", packet.len()); + let data = self.data.as_ref().unwrap(); + // TODO: improve this + let mut sink = self.sink.lock().unwrap().take().unwrap(); + data.handle_one_packet_from_peer(&mut sink, packet).await; + self.sink.lock().unwrap().replace(sink); + } + + fn start_and_get_tunnel(&mut self) -> Box { + let (stunnel, ctunnel) = create_ring_tunnel_pair(); + + let (stream, sink) = stunnel.split(); + + let data = WgPeerData { + udp: self.udp.clone(), + endpoint: self.endpoint, + tunn: Arc::new(Mutex::new( + Tunn::new( + self.config.my_secret_key.clone(), + self.config.peer_public_key.clone(), + None, + None, + rand::thread_rng().next_u32(), + None, + ) + .unwrap(), + )), + wg_type: self.config.wg_type.clone(), + stopped: Arc::new(AtomicBool::new(false)), + }; + + self.data = Some(data.clone()); + self.sink.lock().unwrap().replace(sink); + + self.tasks + .spawn(Self::handle_packet_from_me(stream, data.clone())); + self.tasks.spawn(data.routine_task()); + + ctunnel + } + + fn stopped(&self) -> bool { + self.data + .as_ref() + .unwrap() + .stopped + .load(std::sync::atomic::Ordering::Relaxed) + } +} + +type ConnSender = tokio::sync::mpsc::UnboundedSender>; +type ConnReceiver = tokio::sync::mpsc::UnboundedReceiver>; + +pub struct WgTunnelListener { + addr: url::Url, + config: WgConfig, + + udp: Option>, + conn_recv: ConnReceiver, + conn_send: Option, + + wg_peer_map: Arc>, + + tasks: JoinSet<()>, +} + +impl WgTunnelListener { + pub fn new(addr: url::Url, config: WgConfig) -> Self { + let (conn_send, conn_recv) = tokio::sync::mpsc::unbounded_channel(); + WgTunnelListener { + addr, + config, + + udp: None, + conn_recv, + conn_send: Some(conn_send), + + wg_peer_map: Arc::new(DashMap::new()), + + tasks: JoinSet::new(), + } + } + + fn get_udp_socket(&self) -> Arc { + self.udp.as_ref().unwrap().clone() + } + + async fn handle_udp_incoming( + socket: Arc, + config: WgConfig, + conn_sender: ConnSender, + peer_map: Arc>, + ) { + let mut tasks = JoinSet::new(); + + let peer_map_clone = peer_map.clone(); + tasks.spawn(async move { + loop { + peer_map_clone + .retain(|_, peer| peer.access_time.elapsed().as_secs() < 61 && !peer.stopped()); + tokio::time::sleep(Duration::from_secs(1)).await; + } + }); + + let mut buf = vec![0u8; MAX_PACKET]; + loop { + let Ok((n, addr)) = socket.recv_from(&mut buf).await else { + tracing::error!("Failed to receive from UDP socket"); + break; + }; + + let data = &buf[..n]; + tracing::trace!(?n, ?addr, "Received bytes from peer"); + + if !peer_map.contains_key(&addr) { + tracing::info!("New peer: {}", addr); + let mut wg = WgPeer::new(socket.clone(), config.clone(), addr.clone()); + let (stream, sink) = wg.start_and_get_tunnel().split(); + let tunnel = Box::new(TunnelWrapper::new( + stream, + sink, + Some(TunnelInfo { + tunnel_type: "wg".to_owned(), + local_addr: build_url_from_socket_addr( + &socket.local_addr().unwrap().to_string(), + "wg", + ) + .into(), + remote_addr: build_url_from_socket_addr(&addr.to_string(), "wg").into(), + }), + )); + if let Err(e) = conn_sender.send(tunnel) { + tracing::error!("Failed to send tunnel to conn_sender: {}", e); + } + peer_map.insert(addr, wg); + } + + let mut peer = peer_map.get_mut(&addr).unwrap(); + peer.handle_packet_from_peer(data).await; + } + } +} + +#[async_trait] +impl TunnelListener for WgTunnelListener { + async fn listen(&mut self) -> Result<(), super::TunnelError> { + let addr = check_scheme_and_get_socket_addr::(&self.addr, "wg")?; + let socket2_socket = socket2::Socket::new( + socket2::Domain::for_address(addr), + socket2::Type::DGRAM, + Some(socket2::Protocol::UDP), + )?; + + let tunnel_url: TunnelUrl = self.addr.clone().into(); + if let Some(bind_dev) = tunnel_url.bind_dev() { + setup_sokcet2_ext(&socket2_socket, &addr, Some(bind_dev))?; + } else { + setup_sokcet2(&socket2_socket, &addr)?; + } + + self.udp = Some(Arc::new(UdpSocket::from_std(socket2_socket.into())?)); + self.tasks.spawn(Self::handle_udp_incoming( + self.get_udp_socket(), + self.config.clone(), + self.conn_send.take().unwrap(), + self.wg_peer_map.clone(), + )); + + Ok(()) + } + + async fn accept(&mut self) -> Result, super::TunnelError> { + while let Some(tunnel) = self.conn_recv.recv().await { + tracing::info!(?tunnel, "Accepted tunnel"); + return Ok(tunnel); + } + Err(TunnelError::Shutdown) + } + + fn local_url(&self) -> url::Url { + self.addr.clone() + } +} + +#[derive(Clone)] +pub struct WgTunnelConnector { + addr: url::Url, + config: WgConfig, + udp: Option>, + + bind_addrs: Vec, +} + +impl Debug for WgTunnelConnector { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.debug_struct("WgTunnelConnector") + .field("addr", &self.addr) + .field("udp", &self.udp) + .finish() + } +} + +impl WgTunnelConnector { + pub fn new(addr: url::Url, config: WgConfig) -> Self { + WgTunnelConnector { + addr, + config, + udp: None, + bind_addrs: vec![], + } + } + + fn create_handshake_init(tun: &mut Tunn) -> Vec { + let mut dst = vec![0u8; 2048]; + let handshake_init = tun.format_handshake_initiation(&mut dst, false); + assert!(matches!(handshake_init, TunnResult::WriteToNetwork(_))); + let handshake_init = if let TunnResult::WriteToNetwork(sent) = handshake_init { + sent + } else { + unreachable!(); + }; + + handshake_init.into() + } + + fn parse_handshake_resp(tun: &mut Tunn, handshake_resp: &[u8]) -> Vec { + let mut dst = vec![0u8; 2048]; + let keepalive = tun.decapsulate(None, handshake_resp, &mut dst); + assert!( + matches!(keepalive, TunnResult::WriteToNetwork(_)), + "Failed to parse handshake response, {:?}", + keepalive + ); + + let keepalive = if let TunnResult::WriteToNetwork(sent) = keepalive { + sent + } else { + unreachable!(); + }; + + keepalive.into() + } + + #[tracing::instrument(skip(config))] + async fn connect_with_socket( + addr_url: url::Url, + config: WgConfig, + udp: UdpSocket, + ) -> Result, super::TunnelError> { + let addr = super::check_scheme_and_get_socket_addr::(&addr_url, "wg")?; + tracing::warn!("wg connect: {:?}", addr); + let local_addr = udp.local_addr().unwrap().to_string(); + + let mut wg_peer = WgPeer::new(Arc::new(udp), config.clone(), addr); + let tunnel = wg_peer.start_and_get_tunnel(); + + let data = wg_peer.data.as_ref().unwrap().clone(); + let mut sink = wg_peer.sink.lock().unwrap().take().unwrap(); + wg_peer.tasks.spawn(async move { + loop { + let mut buf = vec![0u8; MAX_PACKET]; + let (n, recv_addr) = data.udp.recv_from(&mut buf).await.unwrap(); + if recv_addr != addr { + continue; + } + data.handle_one_packet_from_peer(&mut sink, &buf[..n]).await; + } + }); + + let (stream, sink) = tunnel.split(); + let ret = Box::new(TunnelWrapper::new_with_associate_data( + stream, + sink, + Some(TunnelInfo { + tunnel_type: "wg".to_owned(), + local_addr: super::build_url_from_socket_addr(&local_addr, "wg").into(), + remote_addr: addr_url.to_string(), + }), + Some(Box::new(wg_peer)), + )); + + Ok(ret) + } +} + +#[async_trait] +impl super::TunnelConnector for WgTunnelConnector { + #[tracing::instrument] + async fn connect(&mut self) -> Result, super::TunnelError> { + let bind_addrs = if self.bind_addrs.is_empty() { + vec!["0.0.0.0:0".parse().unwrap()] + } else { + self.bind_addrs.clone() + }; + let futures = FuturesUnordered::new(); + + for bind_addr in bind_addrs.into_iter() { + let socket2_socket = socket2::Socket::new( + socket2::Domain::for_address(bind_addr), + socket2::Type::DGRAM, + Some(socket2::Protocol::UDP), + )?; + setup_sokcet2(&socket2_socket, &bind_addr)?; + let socket = UdpSocket::from_std(socket2_socket.into())?; + tracing::info!(?bind_addr, ?self.addr, "prepare wg connect task"); + futures.push(Self::connect_with_socket( + self.addr.clone(), + self.config.clone(), + socket, + )); + } + + wait_for_connect_futures(futures).await + } + + fn remote_url(&self) -> url::Url { + self.addr.clone() + } + + fn set_bind_addrs(&mut self, addrs: Vec) { + self.bind_addrs = addrs; + } +} + +#[cfg(test)] +pub mod tests { + use super::*; + use crate::tunnel::{ + common::tests::{_tunnel_bench, _tunnel_pingpong}, + TunnelConnector, + }; + use boringtun::*; + + pub fn create_wg_config() -> (WgConfig, WgConfig) { + let my_secret_key = x25519::StaticSecret::random_from_rng(rand::thread_rng()); + let my_public_key = x25519::PublicKey::from(&my_secret_key); + + let their_secret_key = x25519::StaticSecret::random_from_rng(rand::thread_rng()); + let their_public_key = x25519::PublicKey::from(&their_secret_key); + + let server_cfg = WgConfig { + my_secret_key: my_secret_key.clone(), + my_public_key, + peer_secret_key: their_secret_key.clone(), + peer_public_key: their_public_key.clone(), + wg_type: WgType::InternalUse, + }; + + let client_cfg = WgConfig { + my_secret_key: their_secret_key, + my_public_key: their_public_key, + peer_secret_key: my_secret_key, + peer_public_key: my_public_key, + wg_type: WgType::InternalUse, + }; + + (server_cfg, client_cfg) + } + + #[tokio::test] + async fn wg_pingpong() { + let (server_cfg, client_cfg) = create_wg_config(); + let listener = WgTunnelListener::new("wg://0.0.0.0:5599".parse().unwrap(), server_cfg); + let connector = WgTunnelConnector::new("wg://127.0.0.1:5599".parse().unwrap(), client_cfg); + _tunnel_pingpong(listener, connector).await + } + + #[tokio::test] + async fn wg_bench() { + let (server_cfg, client_cfg) = create_wg_config(); + let listener = WgTunnelListener::new("wg://0.0.0.0:5598".parse().unwrap(), server_cfg); + let connector = WgTunnelConnector::new("wg://127.0.0.1:5598".parse().unwrap(), client_cfg); + _tunnel_bench(listener, connector).await + } + + #[tokio::test] + async fn wg_bench_with_bind() { + let (server_cfg, client_cfg) = create_wg_config(); + let listener = WgTunnelListener::new("wg://127.0.0.1:5597".parse().unwrap(), server_cfg); + let mut connector = + WgTunnelConnector::new("wg://127.0.0.1:5597".parse().unwrap(), client_cfg); + connector.set_bind_addrs(vec!["127.0.0.1:0".parse().unwrap()]); + _tunnel_pingpong(listener, connector).await + } + + #[tokio::test] + #[should_panic] + async fn wg_bench_with_bind_fail() { + let (server_cfg, client_cfg) = create_wg_config(); + let listener = WgTunnelListener::new("wg://127.0.0.1:5596".parse().unwrap(), server_cfg); + let mut connector = + WgTunnelConnector::new("wg://127.0.0.1:5596".parse().unwrap(), client_cfg); + connector.set_bind_addrs(vec!["10.0.0.1:0".parse().unwrap()]); + _tunnel_pingpong(listener, connector).await + } + + #[tokio::test] + async fn wg_server_erase_from_map_after_close() { + let (server_cfg, client_cfg) = create_wg_config(); + let mut listener = + WgTunnelListener::new("wg://127.0.0.1:5595".parse().unwrap(), server_cfg); + listener.listen().await.unwrap(); + + const CONN_COUNT: usize = 10; + + tokio::spawn(async move { + let mut tunnels = vec![]; + for _ in 0..CONN_COUNT { + let mut connector = WgTunnelConnector::new( + "wg://127.0.0.1:5595".parse().unwrap(), + client_cfg.clone(), + ); + let ret = connector.connect().await; + assert!(ret.is_ok()); + let t = ret.unwrap(); + let (_stream, mut sink) = t.split(); + sink.send(ZCPacket::new_with_payload("payload".as_bytes())) + .await + .unwrap(); + tunnels.push(t); + } + tokio::time::sleep(tokio::time::Duration::from_secs(1)).await; + }); + + for _ in 0..CONN_COUNT { + println!("accepting"); + let conn = listener.accept().await; + let (mut stream, _sink) = conn.unwrap().split(); + let packet = stream.next().await.unwrap().unwrap(); + assert_eq!("payload".as_bytes(), packet.payload()); + println!("accepting drop"); + } + + tokio::time::sleep(tokio::time::Duration::from_secs(2)).await; + + assert_eq!(0, listener.wg_peer_map.len()); + } +} diff --git a/easytier/src/tunnels/mod.rs b/easytier/src/tunnels/mod.rs index aa2860c7..ed511737 100644 --- a/easytier/src/tunnels/mod.rs +++ b/easytier/src/tunnels/mod.rs @@ -1,11 +1,11 @@ pub mod codec; pub mod common; -pub mod ring_tunnel; -pub mod stats; -pub mod tcp_tunnel; -pub mod tunnel_filter; -pub mod udp_tunnel; -pub mod wireguard; +// pub mod ring_tunnel; +// pub mod stats; +// pub mod tcp_tunnel; +// pub mod tunnel_filter; +// pub mod udp_tunnel; +// pub mod wireguard; use std::{fmt::Debug, net::SocketAddr, pin::Pin, sync::Arc}; diff --git a/easytier/src/vpn_portal/wireguard.rs b/easytier/src/vpn_portal/wireguard.rs index 0d7ce344..118fcb59 100644 --- a/easytier/src/vpn_portal/wireguard.rs +++ b/easytier/src/vpn_portal/wireguard.rs @@ -1,6 +1,5 @@ use std::{ net::{Ipv4Addr, SocketAddr}, - pin::Pin, sync::Arc, }; @@ -8,24 +7,22 @@ use anyhow::Context; use base64::{prelude::BASE64_STANDARD, Engine}; use cidr::Ipv4Inet; use dashmap::DashMap; -use futures::{SinkExt, StreamExt}; +use futures::StreamExt; use pnet::packet::ipv4::Ipv4Packet; -use tokio::{sync::Mutex, task::JoinSet}; -use tokio_util::bytes::Bytes; +use tokio::task::JoinSet; use crate::{ common::{ + config::NetworkIdentity, global_ctx::{ArcGlobalCtx, GlobalCtxEvent}, join_joinset_background, }, - peers::{ - packet::{self, ArchivedPacket}, - peer_manager::PeerManager, - PeerPacketFilter, - }, - tunnels::{ + peers::{peer_manager::PeerManager, PeerPacketFilter}, + tunnel::{ + mpsc::{MpscTunnel, MpscTunnelSender}, + packet_def::{PacketType, ZCPacket, ZCPacketType}, wireguard::{WgConfig, WgTunnelListener}, - DatagramSink, Tunnel, TunnelListener, + Tunnel, TunnelListener, }, }; @@ -33,9 +30,14 @@ use super::VpnPortal; type WgPeerIpTable = Arc>>; +pub(crate) fn get_wg_config_for_portal(nid: &NetworkIdentity) -> WgConfig { + let key_seed = format!("{}{}", nid.network_name, nid.network_secret); + WgConfig::new_for_portal(&key_seed, &key_seed) +} + struct ClientEntry { endpoint_addr: Option, - sink: Mutex>>, + sink: MpscTunnelSender, } struct WireGuardImpl { @@ -52,8 +54,7 @@ struct WireGuardImpl { impl WireGuardImpl { fn new(global_ctx: ArcGlobalCtx, peer_mgr: Arc) -> Self { let nid = global_ctx.get_network_identity(); - let key_seed = format!("{}{}", nid.network_name, nid.network_secret); - let wg_config = WgConfig::new_for_portal(&key_seed, &key_seed); + let wg_config = get_wg_config_for_portal(&nid); let vpn_cfg = global_ctx.config.get_vpn_portal_config().unwrap(); let listenr_addr = vpn_cfg.wireguard_listen; @@ -73,38 +74,41 @@ impl WireGuardImpl { peer_mgr: Arc, wg_peer_ip_table: WgPeerIpTable, ) { - let mut s = t.pin_stream(); + let info = t.info().unwrap_or_default(); + let mut mpsc_tunnel = MpscTunnel::new(t); + let mut stream = mpsc_tunnel.get_stream(); let mut ip_registered = false; - let info = t.info().unwrap_or_default(); let remote_addr = info.remote_addr.clone(); peer_mgr .get_global_ctx() .issue_event(GlobalCtxEvent::VpnPortalClientConnected( - info.local_addr, - info.remote_addr, + info.local_addr.clone(), + info.remote_addr.clone(), )); - while let Some(Ok(msg)) = s.next().await { - let Some(i) = Ipv4Packet::new(&msg) else { - tracing::error!(?msg, "Failed to parse ipv4 packet"); + while let Some(Ok(msg)) = stream.next().await { + assert_eq!(msg.packet_type(), ZCPacketType::WG); + let inner = msg.inner(); + let Some(i) = Ipv4Packet::new(&inner) else { + tracing::error!(?inner, "Failed to parse ipv4 packet"); continue; }; if !ip_registered { let client_entry = Arc::new(ClientEntry { endpoint_addr: remote_addr.parse().ok(), - sink: Mutex::new(t.pin_sink()), + sink: mpsc_tunnel.get_sink(), }); wg_peer_ip_table.insert(i.get_source(), client_entry.clone()); ip_registered = true; } tracing::trace!(?i, "Received from wg client"); + let dst = i.get_destination(); let _ = peer_mgr - .send_msg_ipv4(msg.clone(), i.get_destination()) + .send_msg_ipv4(ZCPacket::new_with_payload(inner.as_ref()), dst) .await; } - let info = t.info().unwrap_or_default(); peer_mgr .get_global_ctx() .issue_event(GlobalCtxEvent::VpnPortalClientDisconnected( @@ -120,34 +124,38 @@ impl WireGuardImpl { #[async_trait::async_trait] impl PeerPacketFilter for PeerPacketFilterForVpnPortal { - async fn try_process_packet_from_peer( - &self, - packet: &ArchivedPacket, - _: &Bytes, - ) -> Option<()> { - if packet.packet_type != packet::PacketType::Data { - return None; + async fn try_process_packet_from_peer(&self, packet: ZCPacket) -> Option { + let hdr = packet.peer_manager_header().unwrap(); + if hdr.packet_type != PacketType::Data as u8 { + return Some(packet); }; - let payload_bytes = packet.payload.as_bytes(); - + let payload_bytes = packet.payload(); let ipv4 = Ipv4Packet::new(payload_bytes)?; if ipv4.get_version() != 4 { - return None; + return Some(packet); } - let entry = self.wg_peer_ip_table.get(&ipv4.get_destination())?.clone(); + let Some(entry) = self + .wg_peer_ip_table + .get(&ipv4.get_destination()) + .map(|f| f.clone()) + else { + return Some(packet); + }; tracing::trace!(?ipv4, "Packet filter for vpn portal"); - let ret = entry - .sink - .lock() - .await - .send(Bytes::copy_from_slice(payload_bytes)) - .await; + let payload_offset = packet.packet_type().get_packet_offsets().payload_offset; + let packet = ZCPacket::new_from_buf( + packet.inner().split_off(payload_offset), + ZCPacketType::WG, + ); - ret.ok() + if let Err(ret) = entry.sink.send(packet).await { + tracing::debug!(?ret, "Failed to send packet to wg client"); + } + None } } @@ -164,9 +172,14 @@ impl WireGuardImpl { self.wg_config.clone(), ); - l.listen() - .await - .with_context(|| "Failed to start wireguard listener for vpn portal")?; + tracing::info!("Wireguard VPN Portal Starting"); + + { + let _g = self.global_ctx.net_ns.guard(); + l.listen() + .await + .with_context(|| "Failed to start wireguard listener for vpn portal")?; + } join_joinset_background(self.tasks.clone(), "wireguard".to_string()); @@ -296,62 +309,3 @@ Endpoint = {listenr_addr} # should be the public ip of the vpn server .collect() } } - -#[cfg(test)] -mod tests { - use std::sync::Arc; - - use super::*; - - use crate::{ - common::{ - config::{NetworkIdentity, VpnPortalConfig}, - global_ctx::tests::get_mock_global_ctx_with_network, - }, - connector::udp_hole_punch::tests::replace_stun_info_collector, - peers::{ - peer_manager::{PeerManager, RouteAlgoType}, - tests::wait_for_condition, - }, - rpc::NatType, - tunnels::{tcp_tunnel::TcpTunnelConnector, TunnelConnector}, - }; - - async fn portal_test() { - let (s, _r) = tokio::sync::mpsc::channel(1000); - let peer_mgr = Arc::new(PeerManager::new( - RouteAlgoType::Ospf, - get_mock_global_ctx_with_network(Some(NetworkIdentity { - network_name: "sijie".to_string(), - network_secret: "1919119".to_string(), - })), - s, - )); - replace_stun_info_collector(peer_mgr.clone(), NatType::Unknown); - peer_mgr - .get_global_ctx() - .config - .set_vpn_portal_config(VpnPortalConfig { - wireguard_listen: "0.0.0.0:11021".parse().unwrap(), - client_cidr: "10.14.14.0/24".parse().unwrap(), - }); - peer_mgr.run().await.unwrap(); - let mut pmgr_conn = TcpTunnelConnector::new("tcp://127.0.0.1:11010".parse().unwrap()); - let tunnel = pmgr_conn.connect().await; - peer_mgr.add_client_tunnel(tunnel.unwrap()).await.unwrap(); - wait_for_condition( - || async { - let routes = peer_mgr.list_routes().await; - println!("Routes: {:?}", routes); - routes.len() != 0 - }, - std::time::Duration::from_secs(10), - ) - .await; - - let mut wg = WireGuard::default(); - wg.start(peer_mgr.get_global_ctx(), peer_mgr.clone()) - .await - .unwrap(); - } -}