zero copy tunnel (#55)
make tunnel zero copy, for better performance. remove most of the locks in io path. introduce quic tunnel prepare for encryption
This commit is contained in:
Generated
+319
-1
@@ -142,6 +142,15 @@ version = "1.0.79"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "080e9890a082662b09c1ad45f567faeeb47f22b5fb23895fbe1e651e718e25ca"
|
checksum = "080e9890a082662b09c1ad45f567faeeb47f22b5fb23895fbe1e651e718e25ca"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "async-event"
|
||||||
|
version = "0.1.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "4172595da7ffb68640606be5723e35a353555f2829e9209437627a003725bbdb"
|
||||||
|
dependencies = [
|
||||||
|
"loom",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "async-recursion"
|
name = "async-recursion"
|
||||||
version = "1.1.0"
|
version = "1.1.0"
|
||||||
@@ -219,6 +228,12 @@ dependencies = [
|
|||||||
"critical-section",
|
"critical-section",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "atomicbox"
|
||||||
|
version = "0.4.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "8a9a3820bc9e9aaf60c8389c2a4808548599f4ff254ce6bdb608ac3631d4ad76"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "auto_impl"
|
name = "auto_impl"
|
||||||
version = "1.1.0"
|
version = "1.1.0"
|
||||||
@@ -309,6 +324,12 @@ version = "0.21.7"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567"
|
checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "base64"
|
||||||
|
version = "0.22.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "9475866fec1451be56a3c2400fd081ff546538961565ccb5b7142cbd22bc7a51"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "base64ct"
|
name = "base64ct"
|
||||||
version = "1.6.0"
|
version = "1.6.0"
|
||||||
@@ -381,7 +402,7 @@ dependencies = [
|
|||||||
"nix 0.25.1",
|
"nix 0.25.1",
|
||||||
"parking_lot",
|
"parking_lot",
|
||||||
"rand_core 0.6.4",
|
"rand_core 0.6.4",
|
||||||
"ring",
|
"ring 0.16.20",
|
||||||
"tracing",
|
"tracing",
|
||||||
"untrusted 0.9.0",
|
"untrusted 0.9.0",
|
||||||
"x25519-dalek",
|
"x25519-dalek",
|
||||||
@@ -1078,6 +1099,26 @@ version = "2.5.0"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "7e962a19be5cfc3f3bf6dd8f61eb50107f356ad6270fbb3ed41476571db78be5"
|
checksum = "7e962a19be5cfc3f3bf6dd8f61eb50107f356ad6270fbb3ed41476571db78be5"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "defguard_wireguard_rs"
|
||||||
|
version = "0.4.2"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "6ba16f17698d4b389907310af018b0c3a80b025bba9c38d947cbc6dd70921743"
|
||||||
|
dependencies = [
|
||||||
|
"base64 0.21.7",
|
||||||
|
"libc",
|
||||||
|
"log",
|
||||||
|
"netlink-packet-core",
|
||||||
|
"netlink-packet-generic",
|
||||||
|
"netlink-packet-route",
|
||||||
|
"netlink-packet-utils",
|
||||||
|
"netlink-packet-wireguard",
|
||||||
|
"netlink-sys",
|
||||||
|
"nix 0.27.1",
|
||||||
|
"serde",
|
||||||
|
"thiserror",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "deprecate-until"
|
name = "deprecate-until"
|
||||||
version = "0.1.1"
|
version = "0.1.1"
|
||||||
@@ -1149,6 +1190,16 @@ dependencies = [
|
|||||||
"syn 1.0.109",
|
"syn 1.0.109",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "diatomic-waker"
|
||||||
|
version = "0.1.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "28025fb55a9d815acf7b0877555f437254f373036eec6ed265116c7a5c0825e9"
|
||||||
|
dependencies = [
|
||||||
|
"loom",
|
||||||
|
"waker-fn",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "digest"
|
name = "digest"
|
||||||
version = "0.10.7"
|
version = "0.10.7"
|
||||||
@@ -1228,17 +1279,20 @@ dependencies = [
|
|||||||
"async-recursion",
|
"async-recursion",
|
||||||
"async-stream",
|
"async-stream",
|
||||||
"async-trait",
|
"async-trait",
|
||||||
|
"atomicbox",
|
||||||
"auto_impl",
|
"auto_impl",
|
||||||
"base64 0.21.7",
|
"base64 0.21.7",
|
||||||
"boringtun",
|
"boringtun",
|
||||||
"bytecodec",
|
"bytecodec",
|
||||||
"byteorder",
|
"byteorder",
|
||||||
|
"bytes",
|
||||||
"chrono",
|
"chrono",
|
||||||
"cidr",
|
"cidr",
|
||||||
"clap",
|
"clap",
|
||||||
"crossbeam",
|
"crossbeam",
|
||||||
"crossbeam-queue",
|
"crossbeam-queue",
|
||||||
"dashmap",
|
"dashmap",
|
||||||
|
"defguard_wireguard_rs",
|
||||||
"derivative",
|
"derivative",
|
||||||
"futures",
|
"futures",
|
||||||
"gethostname",
|
"gethostname",
|
||||||
@@ -1249,19 +1303,24 @@ dependencies = [
|
|||||||
"once_cell",
|
"once_cell",
|
||||||
"pathfinding",
|
"pathfinding",
|
||||||
"percent-encoding",
|
"percent-encoding",
|
||||||
|
"pin-project-lite",
|
||||||
"pnet",
|
"pnet",
|
||||||
"postcard",
|
"postcard",
|
||||||
"prost",
|
"prost",
|
||||||
"public-ip",
|
"public-ip",
|
||||||
|
"quinn",
|
||||||
"rand 0.8.5",
|
"rand 0.8.5",
|
||||||
|
"rcgen",
|
||||||
"reqwest",
|
"reqwest",
|
||||||
"rkyv",
|
"rkyv",
|
||||||
"rstest",
|
"rstest",
|
||||||
|
"rustls",
|
||||||
"serde",
|
"serde",
|
||||||
"serial_test",
|
"serial_test",
|
||||||
"socket2 0.5.5",
|
"socket2 0.5.5",
|
||||||
"stun_codec",
|
"stun_codec",
|
||||||
"tabled",
|
"tabled",
|
||||||
|
"tachyonix",
|
||||||
"tarpc",
|
"tarpc",
|
||||||
"thiserror",
|
"thiserror",
|
||||||
"time",
|
"time",
|
||||||
@@ -1279,6 +1338,7 @@ dependencies = [
|
|||||||
"url",
|
"url",
|
||||||
"uuid",
|
"uuid",
|
||||||
"windows-sys 0.52.0",
|
"windows-sys 0.52.0",
|
||||||
|
"zerocopy",
|
||||||
"zip",
|
"zip",
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -2691,6 +2751,80 @@ dependencies = [
|
|||||||
"jni-sys",
|
"jni-sys",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "netlink-packet-core"
|
||||||
|
version = "0.7.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "72724faf704479d67b388da142b186f916188505e7e0b26719019c525882eda4"
|
||||||
|
dependencies = [
|
||||||
|
"anyhow",
|
||||||
|
"byteorder",
|
||||||
|
"netlink-packet-utils",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "netlink-packet-generic"
|
||||||
|
version = "0.3.3"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "1cd7eb8ad331c84c6b8cb7f685b448133e5ad82e1ffd5acafac374af4a5a308b"
|
||||||
|
dependencies = [
|
||||||
|
"anyhow",
|
||||||
|
"byteorder",
|
||||||
|
"netlink-packet-core",
|
||||||
|
"netlink-packet-utils",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "netlink-packet-route"
|
||||||
|
version = "0.17.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "053998cea5a306971f88580d0829e90f270f940befd7cf928da179d4187a5a66"
|
||||||
|
dependencies = [
|
||||||
|
"anyhow",
|
||||||
|
"bitflags 1.3.2",
|
||||||
|
"byteorder",
|
||||||
|
"libc",
|
||||||
|
"netlink-packet-core",
|
||||||
|
"netlink-packet-utils",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "netlink-packet-utils"
|
||||||
|
version = "0.5.2"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "0ede8a08c71ad5a95cdd0e4e52facd37190977039a4704eb82a283f713747d34"
|
||||||
|
dependencies = [
|
||||||
|
"anyhow",
|
||||||
|
"byteorder",
|
||||||
|
"paste",
|
||||||
|
"thiserror",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "netlink-packet-wireguard"
|
||||||
|
version = "0.2.3"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "60b25b050ff1f6a1e23c6777b72db22790fe5b6b5ccfd3858672587a79876c8f"
|
||||||
|
dependencies = [
|
||||||
|
"anyhow",
|
||||||
|
"byteorder",
|
||||||
|
"libc",
|
||||||
|
"log",
|
||||||
|
"netlink-packet-generic",
|
||||||
|
"netlink-packet-utils",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "netlink-sys"
|
||||||
|
version = "0.8.6"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "416060d346fbaf1f23f9512963e3e878f1a78e707cb699ba9215761754244307"
|
||||||
|
dependencies = [
|
||||||
|
"bytes",
|
||||||
|
"libc",
|
||||||
|
"log",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "network-interface"
|
name = "network-interface"
|
||||||
version = "1.1.1"
|
version = "1.1.1"
|
||||||
@@ -3011,6 +3145,12 @@ dependencies = [
|
|||||||
"subtle",
|
"subtle",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "paste"
|
||||||
|
version = "1.0.14"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "de3145af08024dea9fa9914f381a17b8fc6034dfb00f3a84013f7ff43f29ed4c"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "pathdiff"
|
name = "pathdiff"
|
||||||
version = "0.2.1"
|
version = "0.2.1"
|
||||||
@@ -3044,6 +3184,16 @@ dependencies = [
|
|||||||
"sha2",
|
"sha2",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "pem"
|
||||||
|
version = "3.0.4"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "8e459365e590736a54c3fa561947c84837534b8e9af6fc5bf781307e82658fae"
|
||||||
|
dependencies = [
|
||||||
|
"base64 0.22.0",
|
||||||
|
"serde",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "percent-encoding"
|
name = "percent-encoding"
|
||||||
version = "2.3.1"
|
version = "2.3.1"
|
||||||
@@ -3574,6 +3724,54 @@ dependencies = [
|
|||||||
"memchr",
|
"memchr",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "quinn"
|
||||||
|
version = "0.10.2"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "8cc2c5017e4b43d5995dcea317bc46c1e09404c0a9664d2908f7f02dfe943d75"
|
||||||
|
dependencies = [
|
||||||
|
"bytes",
|
||||||
|
"pin-project-lite",
|
||||||
|
"quinn-proto",
|
||||||
|
"quinn-udp",
|
||||||
|
"rustc-hash",
|
||||||
|
"rustls",
|
||||||
|
"thiserror",
|
||||||
|
"tokio",
|
||||||
|
"tracing",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "quinn-proto"
|
||||||
|
version = "0.10.6"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "141bf7dfde2fbc246bfd3fe12f2455aa24b0fbd9af535d8c86c7bd1381ff2b1a"
|
||||||
|
dependencies = [
|
||||||
|
"bytes",
|
||||||
|
"rand 0.8.5",
|
||||||
|
"ring 0.16.20",
|
||||||
|
"rustc-hash",
|
||||||
|
"rustls",
|
||||||
|
"rustls-native-certs",
|
||||||
|
"slab",
|
||||||
|
"thiserror",
|
||||||
|
"tinyvec",
|
||||||
|
"tracing",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "quinn-udp"
|
||||||
|
version = "0.4.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "055b4e778e8feb9f93c4e439f71dc2156ef13360b432b799e179a8c4cdf0b1d7"
|
||||||
|
dependencies = [
|
||||||
|
"bytes",
|
||||||
|
"libc",
|
||||||
|
"socket2 0.5.5",
|
||||||
|
"tracing",
|
||||||
|
"windows-sys 0.48.0",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "quote"
|
name = "quote"
|
||||||
version = "1.0.35"
|
version = "1.0.35"
|
||||||
@@ -3686,6 +3884,18 @@ version = "0.5.2"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "f2ff9a1f06a88b01621b7ae906ef0211290d1c8a168a15542486a8f61c0833b9"
|
checksum = "f2ff9a1f06a88b01621b7ae906ef0211290d1c8a168a15542486a8f61c0833b9"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "rcgen"
|
||||||
|
version = "0.11.3"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "52c4f3084aa3bc7dfbba4eff4fab2a54db4324965d8872ab933565e6fbd83bc6"
|
||||||
|
dependencies = [
|
||||||
|
"pem",
|
||||||
|
"ring 0.16.20",
|
||||||
|
"time",
|
||||||
|
"yasna",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "redox_syscall"
|
name = "redox_syscall"
|
||||||
version = "0.4.1"
|
version = "0.4.1"
|
||||||
@@ -3820,6 +4030,21 @@ dependencies = [
|
|||||||
"winapi",
|
"winapi",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "ring"
|
||||||
|
version = "0.17.8"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "c17fa4cb658e3583423e915b9f3acc01cceaee1860e33d59ebae66adc3a2dc0d"
|
||||||
|
dependencies = [
|
||||||
|
"cc",
|
||||||
|
"cfg-if",
|
||||||
|
"getrandom 0.2.12",
|
||||||
|
"libc",
|
||||||
|
"spin 0.9.8",
|
||||||
|
"untrusted 0.9.0",
|
||||||
|
"windows-sys 0.52.0",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "rkyv"
|
name = "rkyv"
|
||||||
version = "0.7.43"
|
version = "0.7.43"
|
||||||
@@ -3912,6 +4137,30 @@ dependencies = [
|
|||||||
"windows-sys 0.52.0",
|
"windows-sys 0.52.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "rustls"
|
||||||
|
version = "0.21.11"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "7fecbfb7b1444f477b345853b1fce097a2c6fb637b2bfb87e6bc5db0f043fae4"
|
||||||
|
dependencies = [
|
||||||
|
"log",
|
||||||
|
"ring 0.17.8",
|
||||||
|
"rustls-webpki",
|
||||||
|
"sct",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "rustls-native-certs"
|
||||||
|
version = "0.6.3"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "a9aace74cb666635c918e9c12bc0d348266037aa8eb599b5cba565709a8dff00"
|
||||||
|
dependencies = [
|
||||||
|
"openssl-probe",
|
||||||
|
"rustls-pemfile",
|
||||||
|
"schannel",
|
||||||
|
"security-framework",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "rustls-pemfile"
|
name = "rustls-pemfile"
|
||||||
version = "1.0.4"
|
version = "1.0.4"
|
||||||
@@ -3921,6 +4170,16 @@ dependencies = [
|
|||||||
"base64 0.21.7",
|
"base64 0.21.7",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "rustls-webpki"
|
||||||
|
version = "0.101.7"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "8b6275d1ee7a1cd780b64aca7726599a1dbc893b1e64144529e55c3c2f745765"
|
||||||
|
dependencies = [
|
||||||
|
"ring 0.17.8",
|
||||||
|
"untrusted 0.9.0",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "rustversion"
|
name = "rustversion"
|
||||||
version = "1.0.14"
|
version = "1.0.14"
|
||||||
@@ -3963,6 +4222,16 @@ version = "1.2.0"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49"
|
checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "sct"
|
||||||
|
version = "0.7.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "da046153aa2352493d6cb7da4b6e5c0c057d8a1d0a9aa8560baffdd945acd414"
|
||||||
|
dependencies = [
|
||||||
|
"ring 0.17.8",
|
||||||
|
"untrusted 0.9.0",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "seahash"
|
name = "seahash"
|
||||||
version = "4.1.0"
|
version = "4.1.0"
|
||||||
@@ -4487,6 +4756,19 @@ dependencies = [
|
|||||||
"syn 1.0.109",
|
"syn 1.0.109",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "tachyonix"
|
||||||
|
version = "0.2.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "64e0bf82be3359dbefbfea621d6365db00e1d7846561daad2ea74cc4cb4c9604"
|
||||||
|
dependencies = [
|
||||||
|
"async-event",
|
||||||
|
"crossbeam-utils",
|
||||||
|
"diatomic-waker",
|
||||||
|
"futures-core",
|
||||||
|
"loom",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "tao"
|
name = "tao"
|
||||||
version = "0.16.8"
|
version = "0.16.8"
|
||||||
@@ -5504,6 +5786,12 @@ dependencies = [
|
|||||||
"libc",
|
"libc",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "waker-fn"
|
||||||
|
version = "1.1.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "f3c4517f54858c779bbcbf228f4fca63d121bf85fbecb2dc578cdf4a39395690"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "walkdir"
|
name = "walkdir"
|
||||||
version = "2.5.0"
|
version = "2.5.0"
|
||||||
@@ -6192,6 +6480,36 @@ dependencies = [
|
|||||||
"rustix",
|
"rustix",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "yasna"
|
||||||
|
version = "0.5.2"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "e17bb3549cc1321ae1296b9cdc2698e2b6cb1992adfa19a8c72e5b7a738f44cd"
|
||||||
|
dependencies = [
|
||||||
|
"time",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "zerocopy"
|
||||||
|
version = "0.7.32"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "74d4d3961e53fa4c9a25a8637fc2bfaf2595b3d3ae34875568a5cf64787716be"
|
||||||
|
dependencies = [
|
||||||
|
"byteorder",
|
||||||
|
"zerocopy-derive",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "zerocopy-derive"
|
||||||
|
version = "0.7.32"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "9ce1b18ccd8e73a9321186f97e46f9f04b778851177567b1975109d26a08d2a6"
|
||||||
|
dependencies = [
|
||||||
|
"proc-macro2",
|
||||||
|
"quote",
|
||||||
|
"syn 2.0.48",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "zeroize"
|
name = "zeroize"
|
||||||
version = "1.7.0"
|
version = "1.7.0"
|
||||||
|
|||||||
@@ -58,6 +58,17 @@ async-trait = "0.1.74"
|
|||||||
dashmap = "5.5.3"
|
dashmap = "5.5.3"
|
||||||
timedmap = "=1.0.1"
|
timedmap = "=1.0.1"
|
||||||
|
|
||||||
|
# for full-path zero-copy
|
||||||
|
zerocopy = { version = "0.7.32", features = ["derive", "simd"] }
|
||||||
|
bytes = "1.5.0"
|
||||||
|
pin-project-lite = "0.2.13"
|
||||||
|
atomicbox = "0.4.0"
|
||||||
|
tachyonix = "0.2.1"
|
||||||
|
|
||||||
|
quinn = { version = "0.10.2" }
|
||||||
|
rustls = { version = "0.21.0", features = ["dangerous_configuration"] }
|
||||||
|
rcgen = "0.11.1"
|
||||||
|
|
||||||
# for tap device
|
# for tap device
|
||||||
tun = { version = "0.6.1", features = ["async"] }
|
tun = { version = "0.6.1", features = ["async"] }
|
||||||
# for net ns
|
# for net ns
|
||||||
@@ -148,3 +159,4 @@ zip = "0.6.6"
|
|||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
serial_test = "3.0.0"
|
serial_test = "3.0.0"
|
||||||
rstest = "0.18.2"
|
rstest = "0.18.2"
|
||||||
|
defguard_wireguard_rs = "0.4.2"
|
||||||
|
|||||||
@@ -157,3 +157,21 @@ message GetVpnPortalInfoResponse {
|
|||||||
service VpnPortalRpc {
|
service VpnPortalRpc {
|
||||||
rpc GetVpnPortalInfo (GetVpnPortalInfoRequest) returns (GetVpnPortalInfoResponse);
|
rpc GetVpnPortalInfo (GetVpnPortalInfoRequest) returns (GetVpnPortalInfoResponse);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
message HandshakeRequest {
|
||||||
|
uint32 magic = 1;
|
||||||
|
uint32 my_peer_id = 2;
|
||||||
|
uint32 version = 3;
|
||||||
|
repeated string features = 4;
|
||||||
|
string network_name = 5;
|
||||||
|
string network_secret = 6;
|
||||||
|
}
|
||||||
|
|
||||||
|
message TaRpcPacket {
|
||||||
|
uint32 from_peer = 1;
|
||||||
|
uint32 to_peer = 2;
|
||||||
|
uint32 service_id = 3;
|
||||||
|
uint32 transact_id = 4;
|
||||||
|
bool is_req = 5;
|
||||||
|
bytes content = 6;
|
||||||
|
}
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ use std::{io, result};
|
|||||||
|
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
|
|
||||||
use crate::tunnels;
|
use crate::{tunnel, tunnels};
|
||||||
|
|
||||||
use super::PeerId;
|
use super::PeerId;
|
||||||
|
|
||||||
@@ -38,6 +38,15 @@ pub enum Error {
|
|||||||
Unknown,
|
Unknown,
|
||||||
#[error("anyhow error: {0}")]
|
#[error("anyhow error: {0}")]
|
||||||
AnyhowError(#[from] anyhow::Error),
|
AnyhowError(#[from] anyhow::Error),
|
||||||
|
|
||||||
|
#[error("wait resp error: {0}")]
|
||||||
|
WaitRespError(String),
|
||||||
|
|
||||||
|
#[error("tunnel error")]
|
||||||
|
TunnelErr(#[from] tunnel::TunnelError),
|
||||||
|
|
||||||
|
#[error("message decode error: {0}")]
|
||||||
|
MessageDecodeError(String),
|
||||||
}
|
}
|
||||||
|
|
||||||
pub type Result<T> = result::Result<T, Error>;
|
pub type Result<T> = result::Result<T, Error>;
|
||||||
|
|||||||
+13
-25
@@ -87,7 +87,7 @@ impl Stun {
|
|||||||
pub fn new(stun_server: SocketAddr) -> Self {
|
pub fn new(stun_server: SocketAddr) -> Self {
|
||||||
Self {
|
Self {
|
||||||
stun_server,
|
stun_server,
|
||||||
req_repeat: 5,
|
req_repeat: 1,
|
||||||
resp_timeout: Duration::from_millis(3000),
|
resp_timeout: Duration::from_millis(3000),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -208,6 +208,7 @@ impl Stun {
|
|||||||
let mut tids = vec![];
|
let mut tids = vec![];
|
||||||
for _ in 0..self.req_repeat {
|
for _ in 0..self.req_repeat {
|
||||||
let tid = rand::random::<u32>();
|
let tid = rand::random::<u32>();
|
||||||
|
// let tid = 1;
|
||||||
let mut buf = [0u8; 28];
|
let mut buf = [0u8; 28];
|
||||||
// memset buf
|
// memset buf
|
||||||
unsafe { std::ptr::write_bytes(buf.as_mut_ptr(), 0, buf.len()) };
|
unsafe { std::ptr::write_bytes(buf.as_mut_ptr(), 0, buf.len()) };
|
||||||
@@ -511,30 +512,17 @@ mod tests {
|
|||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_stun_bind_request() {
|
async fn test_stun_bind_request() {
|
||||||
// miwifi / qq seems not correctly responde to change_ip and change_port, they always try to change the src ip and port.
|
// miwifi / qq seems not correctly responde to change_ip and change_port, they always try to change the src ip and port.
|
||||||
let mut ips = HostResolverIter::new(vec!["stun1.l.google.com:19302".to_string()]);
|
// let mut ips = HostResolverIter::new(vec!["stun1.l.google.com:19302".to_string()]);
|
||||||
let stun = Stun::new(ips.next().await.unwrap());
|
let mut ips_ = HostResolverIter::new(vec!["stun.canets.org:3478".to_string()]);
|
||||||
// let stun = Stun::new("180.235.108.91:3478".to_string());
|
let mut ips = vec![];
|
||||||
// let stun = Stun::new("193.22.2.248:3478".to_string());
|
while let Some(ip) = ips_.next().await {
|
||||||
// let stun = Stun::new("stun.chat.bilibili.com:3478".to_string());
|
ips.push(ip);
|
||||||
// let stun = Stun::new("stun.miwifi.com:3478".to_string());
|
}
|
||||||
|
println!("ip: {:?}", ips);
|
||||||
// github actions are port restricted nat, so we only test last one.
|
for ip in ips.iter() {
|
||||||
|
let stun = Stun::new(ip.clone());
|
||||||
// let rs = stun.bind_request(12345, true, true).await.unwrap();
|
let _rs = stun.bind_request(12345, true, true).await;
|
||||||
// assert!(rs.ip_changed);
|
}
|
||||||
// assert!(rs.port_changed);
|
|
||||||
|
|
||||||
// let rs = stun.bind_request(12345, true, false).await.unwrap();
|
|
||||||
// assert!(rs.ip_changed);
|
|
||||||
// assert!(!rs.port_changed);
|
|
||||||
|
|
||||||
// let rs = stun.bind_request(12345, false, true).await.unwrap();
|
|
||||||
// assert!(!rs.ip_changed);
|
|
||||||
// assert!(rs.port_changed);
|
|
||||||
|
|
||||||
let rs = stun.bind_request(12345, false, false).await.unwrap();
|
|
||||||
assert!(!rs.ip_changed);
|
|
||||||
assert!(!rs.port_changed);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
|
|||||||
@@ -7,7 +7,9 @@ use tokio::{
|
|||||||
time::timeout,
|
time::timeout,
|
||||||
};
|
};
|
||||||
|
|
||||||
use crate::{common::PeerId, peers::peer_conn::PeerConnId, rpc as easytier_rpc};
|
use crate::{
|
||||||
|
common::PeerId, peers::zc_peer_conn::PeerConnId, rpc as easytier_rpc, tunnel::TunnelConnector,
|
||||||
|
};
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
common::{
|
common::{
|
||||||
@@ -21,13 +23,13 @@ use crate::{
|
|||||||
connector_manage_rpc_server::ConnectorManageRpc, Connector, ConnectorStatus,
|
connector_manage_rpc_server::ConnectorManageRpc, Connector, ConnectorStatus,
|
||||||
ListConnectorRequest, ManageConnectorRequest,
|
ListConnectorRequest, ManageConnectorRequest,
|
||||||
},
|
},
|
||||||
tunnels::{Tunnel, TunnelConnector},
|
|
||||||
use_global_var,
|
use_global_var,
|
||||||
};
|
};
|
||||||
|
|
||||||
use super::create_connector_by_url;
|
use super::create_connector_by_url;
|
||||||
|
|
||||||
type ConnectorMap = Arc<DashMap<String, Box<dyn TunnelConnector + Send + Sync>>>;
|
type MutexConnector = Arc<Mutex<Box<dyn TunnelConnector>>>;
|
||||||
|
type ConnectorMap = Arc<DashMap<String, MutexConnector>>;
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
struct ReconnResult {
|
struct ReconnResult {
|
||||||
@@ -81,12 +83,13 @@ impl ManualConnectorManager {
|
|||||||
|
|
||||||
pub fn add_connector<T>(&self, connector: T)
|
pub fn add_connector<T>(&self, connector: T)
|
||||||
where
|
where
|
||||||
T: TunnelConnector + Send + Sync + 'static,
|
T: TunnelConnector + 'static,
|
||||||
{
|
{
|
||||||
log::info!("add_connector: {}", connector.remote_url());
|
log::info!("add_connector: {}", connector.remote_url());
|
||||||
self.data
|
self.data.connectors.insert(
|
||||||
.connectors
|
connector.remote_url().into(),
|
||||||
.insert(connector.remote_url().into(), Box::new(connector));
|
Arc::new(Mutex::new(Box::new(connector))),
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn add_connector_by_url(&self, url: &str) -> Result<(), Error> {
|
pub async fn add_connector_by_url(&self, url: &str) -> Result<(), Error> {
|
||||||
@@ -254,7 +257,7 @@ impl ManualConnectorManager {
|
|||||||
async fn conn_reconnect(
|
async fn conn_reconnect(
|
||||||
data: Arc<ConnectorManagerData>,
|
data: Arc<ConnectorManagerData>,
|
||||||
dead_url: String,
|
dead_url: String,
|
||||||
connector: Box<dyn TunnelConnector + Send + Sync>,
|
connector: MutexConnector,
|
||||||
) -> Result<ReconnResult, Error> {
|
) -> Result<ReconnResult, Error> {
|
||||||
let connector = Arc::new(Mutex::new(Some(connector)));
|
let connector = Arc::new(Mutex::new(Some(connector)));
|
||||||
let net_ns = data.net_ns.clone();
|
let net_ns = data.net_ns.clone();
|
||||||
@@ -269,15 +272,17 @@ impl ManualConnectorManager {
|
|||||||
let mut locked = connector_clone.lock().await;
|
let mut locked = connector_clone.lock().await;
|
||||||
let conn = locked.as_mut().unwrap();
|
let conn = locked.as_mut().unwrap();
|
||||||
// TODO: should support set v6 here, use url in connector array
|
// TODO: should support set v6 here, use url in connector array
|
||||||
set_bind_addr_for_peer_connector(conn, true, &ip_collector).await;
|
set_bind_addr_for_peer_connector(conn.lock().await.as_mut(), true, &ip_collector).await;
|
||||||
|
|
||||||
data_clone
|
data_clone
|
||||||
.global_ctx
|
.global_ctx
|
||||||
.issue_event(GlobalCtxEvent::Connecting(conn.remote_url().clone()));
|
.issue_event(GlobalCtxEvent::Connecting(
|
||||||
|
conn.lock().await.remote_url().clone(),
|
||||||
|
));
|
||||||
|
|
||||||
let _g = net_ns.guard();
|
let _g = net_ns.guard();
|
||||||
log::info!("reconnect try connect... conn: {:?}", conn);
|
log::info!("reconnect try connect... conn: {:?}", conn);
|
||||||
let tunnel = conn.connect().await?;
|
let tunnel = conn.lock().await.connect().await?;
|
||||||
log::info!("reconnect get tunnel succ: {:?}", tunnel);
|
log::info!("reconnect get tunnel succ: {:?}", tunnel);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
url_clone,
|
url_clone,
|
||||||
@@ -359,7 +364,7 @@ mod tests {
|
|||||||
use crate::{
|
use crate::{
|
||||||
peers::tests::create_mock_peer_manager,
|
peers::tests::create_mock_peer_manager,
|
||||||
set_global_var,
|
set_global_var,
|
||||||
tunnels::{Tunnel, TunnelError},
|
tunnel::{Tunnel, TunnelError},
|
||||||
};
|
};
|
||||||
|
|
||||||
use super::*;
|
use super::*;
|
||||||
@@ -379,7 +384,7 @@ mod tests {
|
|||||||
}
|
}
|
||||||
async fn connect(&mut self) -> Result<Box<dyn Tunnel>, TunnelError> {
|
async fn connect(&mut self) -> Result<Box<dyn Tunnel>, TunnelError> {
|
||||||
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
|
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
|
||||||
Err(TunnelError::CommonError("fake error".into()))
|
Err(TunnelError::InvalidPacket("fake error".into()))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -5,10 +5,10 @@ use std::{
|
|||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
common::{error::Error, global_ctx::ArcGlobalCtx, network::IPCollector},
|
common::{error::Error, global_ctx::ArcGlobalCtx, network::IPCollector},
|
||||||
tunnels::{
|
tunnel::{
|
||||||
ring_tunnel::RingTunnelConnector,
|
ring::RingTunnelConnector,
|
||||||
tcp_tunnel::TcpTunnelConnector,
|
tcp::TcpTunnelConnector,
|
||||||
udp_tunnel::UdpTunnelConnector,
|
udp::UdpTunnelConnector,
|
||||||
wireguard::{WgConfig, WgTunnelConnector},
|
wireguard::{WgConfig, WgTunnelConnector},
|
||||||
TunnelConnector,
|
TunnelConnector,
|
||||||
},
|
},
|
||||||
@@ -19,7 +19,7 @@ pub mod manual;
|
|||||||
pub mod udp_hole_punch;
|
pub mod udp_hole_punch;
|
||||||
|
|
||||||
async fn set_bind_addr_for_peer_connector(
|
async fn set_bind_addr_for_peer_connector(
|
||||||
connector: &mut impl TunnelConnector,
|
connector: &mut (impl TunnelConnector + ?Sized),
|
||||||
is_ipv4: bool,
|
is_ipv4: bool,
|
||||||
ip_collector: &Arc<IPCollector>,
|
ip_collector: &Arc<IPCollector>,
|
||||||
) {
|
) {
|
||||||
@@ -45,7 +45,7 @@ async fn set_bind_addr_for_peer_connector(
|
|||||||
pub async fn create_connector_by_url(
|
pub async fn create_connector_by_url(
|
||||||
url: &str,
|
url: &str,
|
||||||
global_ctx: &ArcGlobalCtx,
|
global_ctx: &ArcGlobalCtx,
|
||||||
) -> Result<Box<dyn TunnelConnector + Send + Sync + 'static>, Error> {
|
) -> Result<Box<dyn TunnelConnector + 'static>, Error> {
|
||||||
let url = url::Url::parse(url).map_err(|_| Error::InvalidUrl(url.to_owned()))?;
|
let url = url::Url::parse(url).map_err(|_| Error::InvalidUrl(url.to_owned()))?;
|
||||||
match url.scheme() {
|
match url.scheme() {
|
||||||
"tcp" => {
|
"tcp" => {
|
||||||
|
|||||||
@@ -2,20 +2,21 @@ use std::{net::SocketAddr, sync::Arc};
|
|||||||
|
|
||||||
use anyhow::Context;
|
use anyhow::Context;
|
||||||
use crossbeam::atomic::AtomicCell;
|
use crossbeam::atomic::AtomicCell;
|
||||||
use rand::{seq::SliceRandom, Rng, SeedableRng};
|
use rand::{seq::SliceRandom, SeedableRng};
|
||||||
use tokio::{net::UdpSocket, sync::Mutex, task::JoinSet};
|
use tokio::{net::UdpSocket, sync::Mutex, task::JoinSet};
|
||||||
use tracing::Instrument;
|
use tracing::Instrument;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
common::{
|
common::{
|
||||||
constants, error::Error, global_ctx::ArcGlobalCtx, join_joinset_background,
|
constants, error::Error, global_ctx::ArcGlobalCtx, join_joinset_background,
|
||||||
rkyv_util::encode_to_bytes, stun::StunInfoCollectorTrait, PeerId,
|
stun::StunInfoCollectorTrait, PeerId,
|
||||||
},
|
},
|
||||||
peers::peer_manager::PeerManager,
|
peers::peer_manager::PeerManager,
|
||||||
rpc::NatType,
|
rpc::NatType,
|
||||||
tunnels::{
|
tunnel::{
|
||||||
common::setup_sokcet2,
|
common::setup_sokcet2,
|
||||||
udp_tunnel::{UdpPacket, UdpTunnelConnector, UdpTunnelListener},
|
packet_def::ZCPacketType,
|
||||||
|
udp::{new_hole_punch_packet, UdpTunnelConnector, UdpTunnelListener},
|
||||||
Tunnel, TunnelConnCounter, TunnelListener,
|
Tunnel, TunnelConnCounter, TunnelListener,
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
@@ -149,15 +150,10 @@ impl UdpHolePunchService for UdpHolePunchRpcServer {
|
|||||||
self.tasks.lock().unwrap().spawn(async move {
|
self.tasks.lock().unwrap().spawn(async move {
|
||||||
for _ in 0..10 {
|
for _ in 0..10 {
|
||||||
tracing::info!(?local_mapped_addr, "sending hole punching packet");
|
tracing::info!(?local_mapped_addr, "sending hole punching packet");
|
||||||
// generate a 128 bytes vec with random data
|
|
||||||
let mut rng = rand::rngs::StdRng::from_entropy();
|
|
||||||
let mut buf = vec![0u8; 128];
|
|
||||||
rng.fill(&mut buf[..]);
|
|
||||||
|
|
||||||
let udp_packet = UdpPacket::new_hole_punch_packet(buf);
|
let udp_packet = new_hole_punch_packet();
|
||||||
let udp_packet_bytes = encode_to_bytes::<_, 256>(&udp_packet);
|
|
||||||
let _ = socket
|
let _ = socket
|
||||||
.send_to(udp_packet_bytes.as_ref(), local_mapped_addr)
|
.send_to(&udp_packet.into_bytes(ZCPacketType::UDP), local_mapped_addr)
|
||||||
.await;
|
.await;
|
||||||
tokio::time::sleep(std::time::Duration::from_millis(300)).await;
|
tokio::time::sleep(std::time::Duration::from_millis(300)).await;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ use utils::{list_peer_route_pair, PeerRoutePair};
|
|||||||
mod arch;
|
mod arch;
|
||||||
mod common;
|
mod common;
|
||||||
mod rpc;
|
mod rpc;
|
||||||
|
mod tunnel;
|
||||||
mod tunnels;
|
mod tunnels;
|
||||||
mod utils;
|
mod utils;
|
||||||
|
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ mod instance;
|
|||||||
mod peer_center;
|
mod peer_center;
|
||||||
mod peers;
|
mod peers;
|
||||||
mod rpc;
|
mod rpc;
|
||||||
|
mod tunnel;
|
||||||
mod tunnels;
|
mod tunnels;
|
||||||
mod vpn_portal;
|
mod vpn_portal;
|
||||||
|
|
||||||
|
|||||||
@@ -16,12 +16,13 @@ use tokio::{
|
|||||||
sync::{mpsc::UnboundedSender, Mutex},
|
sync::{mpsc::UnboundedSender, Mutex},
|
||||||
task::JoinSet,
|
task::JoinSet,
|
||||||
};
|
};
|
||||||
use tokio_util::bytes::Bytes;
|
|
||||||
use tracing::Instrument;
|
use tracing::Instrument;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
common::{error::Error, global_ctx::ArcGlobalCtx, PeerId},
|
common::{error::Error, global_ctx::ArcGlobalCtx, PeerId},
|
||||||
peers::{packet, peer_manager::PeerManager, PeerPacketFilter},
|
peers::{peer_manager::PeerManager, PeerPacketFilter},
|
||||||
|
tunnel::packet_def::{PacketType, ZCPacket},
|
||||||
};
|
};
|
||||||
|
|
||||||
use super::CidrSet;
|
use super::CidrSet;
|
||||||
@@ -78,11 +79,7 @@ fn socket_recv(socket: &Socket, buf: &mut [MaybeUninit<u8>]) -> Result<(usize, I
|
|||||||
Ok((size, addr))
|
Ok((size, addr))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn socket_recv_loop(
|
fn socket_recv_loop(socket: Socket, nat_table: IcmpNatTable, sender: UnboundedSender<ZCPacket>) {
|
||||||
socket: Socket,
|
|
||||||
nat_table: IcmpNatTable,
|
|
||||||
sender: UnboundedSender<packet::Packet>,
|
|
||||||
) {
|
|
||||||
let mut buf = [0u8; 4096];
|
let mut buf = [0u8; 4096];
|
||||||
let data: &mut [MaybeUninit<u8>] = unsafe { std::mem::transmute(&mut buf[12..]) };
|
let data: &mut [MaybeUninit<u8>] = unsafe { std::mem::transmute(&mut buf[12..]) };
|
||||||
|
|
||||||
@@ -126,13 +123,14 @@ fn socket_recv_loop(
|
|||||||
ipv4_packet.set_destination(dest_ip);
|
ipv4_packet.set_destination(dest_ip);
|
||||||
ipv4_packet.set_checksum(ipv4::checksum(&ipv4_packet.to_immutable()));
|
ipv4_packet.set_checksum(ipv4::checksum(&ipv4_packet.to_immutable()));
|
||||||
|
|
||||||
let peer_packet = packet::Packet::new_data_packet(
|
let mut p = ZCPacket::new_with_payload(ipv4_packet.packet());
|
||||||
v.my_peer_id,
|
p.fill_peer_manager_hdr(
|
||||||
v.src_peer_id,
|
v.my_peer_id.into(),
|
||||||
&ipv4_packet.to_immutable().packet(),
|
v.src_peer_id.into(),
|
||||||
|
PacketType::Data as u8,
|
||||||
);
|
);
|
||||||
|
|
||||||
if let Err(e) = sender.send(peer_packet) {
|
if let Err(e) = sender.send(p) {
|
||||||
tracing::error!("send icmp packet to peer failed: {:?}, may exiting..", e);
|
tracing::error!("send icmp packet to peer failed: {:?}, may exiting..", e);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
@@ -141,61 +139,12 @@ fn socket_recv_loop(
|
|||||||
|
|
||||||
#[async_trait::async_trait]
|
#[async_trait::async_trait]
|
||||||
impl PeerPacketFilter for IcmpProxy {
|
impl PeerPacketFilter for IcmpProxy {
|
||||||
async fn try_process_packet_from_peer(
|
async fn try_process_packet_from_peer(&self, packet: ZCPacket) -> Option<ZCPacket> {
|
||||||
&self,
|
if let Some(_) = self.try_handle_peer_packet(&packet).await {
|
||||||
packet: &packet::ArchivedPacket,
|
|
||||||
_: &Bytes,
|
|
||||||
) -> Option<()> {
|
|
||||||
let _ = self.global_ctx.get_ipv4()?;
|
|
||||||
|
|
||||||
if packet.packet_type != packet::PacketType::Data {
|
|
||||||
return None;
|
|
||||||
};
|
|
||||||
|
|
||||||
let ipv4 = Ipv4Packet::new(&packet.payload.as_bytes())?;
|
|
||||||
|
|
||||||
if ipv4.get_version() != 4 || ipv4.get_next_level_protocol() != IpNextHeaderProtocols::Icmp
|
|
||||||
{
|
|
||||||
return None;
|
return None;
|
||||||
|
} else {
|
||||||
|
return Some(packet);
|
||||||
}
|
}
|
||||||
|
|
||||||
if !self.cidr_set.contains_v4(ipv4.get_destination()) {
|
|
||||||
return None;
|
|
||||||
}
|
|
||||||
|
|
||||||
let icmp_packet = icmp::echo_request::EchoRequestPacket::new(&ipv4.payload())?;
|
|
||||||
|
|
||||||
if icmp_packet.get_icmp_type() != IcmpTypes::EchoRequest {
|
|
||||||
// drop it because we do not support other icmp types
|
|
||||||
tracing::trace!("unsupported icmp type: {:?}", icmp_packet.get_icmp_type());
|
|
||||||
return Some(());
|
|
||||||
}
|
|
||||||
|
|
||||||
let icmp_id = icmp_packet.get_identifier();
|
|
||||||
let icmp_seq = icmp_packet.get_sequence_number();
|
|
||||||
|
|
||||||
let key = IcmpNatKey {
|
|
||||||
dst_ip: ipv4.get_destination().into(),
|
|
||||||
icmp_id,
|
|
||||||
icmp_seq,
|
|
||||||
};
|
|
||||||
|
|
||||||
let value = IcmpNatEntry::new(
|
|
||||||
packet.from_peer.into(),
|
|
||||||
packet.to_peer.into(),
|
|
||||||
ipv4.get_source().into(),
|
|
||||||
)
|
|
||||||
.ok()?;
|
|
||||||
|
|
||||||
if let Some(old) = self.nat_table.insert(key, value) {
|
|
||||||
tracing::info!("icmp nat table entry replaced: {:?}", old);
|
|
||||||
}
|
|
||||||
|
|
||||||
if let Err(e) = self.send_icmp_packet(ipv4.get_destination(), &icmp_packet) {
|
|
||||||
tracing::error!("send icmp packet failed: {:?}", e);
|
|
||||||
}
|
|
||||||
|
|
||||||
Some(())
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -262,8 +211,9 @@ impl IcmpProxy {
|
|||||||
self.tasks.lock().await.spawn(
|
self.tasks.lock().await.spawn(
|
||||||
async move {
|
async move {
|
||||||
while let Some(msg) = receiver.recv().await {
|
while let Some(msg) = receiver.recv().await {
|
||||||
let to_peer_id = msg.to_peer.into();
|
let hdr = msg.peer_manager_header().unwrap();
|
||||||
let ret = peer_manager.send_msg(msg.into(), to_peer_id).await;
|
let to_peer_id = hdr.to_peer_id.into();
|
||||||
|
let ret = peer_manager.send_msg(msg, to_peer_id).await;
|
||||||
if ret.is_err() {
|
if ret.is_err() {
|
||||||
tracing::error!("send icmp packet to peer failed: {:?}", ret);
|
tracing::error!("send icmp packet to peer failed: {:?}", ret);
|
||||||
}
|
}
|
||||||
@@ -290,4 +240,58 @@ impl IcmpProxy {
|
|||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn try_handle_peer_packet(&self, packet: &ZCPacket) -> Option<()> {
|
||||||
|
let _ = self.global_ctx.get_ipv4()?;
|
||||||
|
let hdr = packet.peer_manager_header().unwrap();
|
||||||
|
|
||||||
|
if hdr.packet_type != PacketType::Data as u8 {
|
||||||
|
return None;
|
||||||
|
};
|
||||||
|
|
||||||
|
let ipv4 = Ipv4Packet::new(&packet.payload())?;
|
||||||
|
|
||||||
|
if ipv4.get_version() != 4 || ipv4.get_next_level_protocol() != IpNextHeaderProtocols::Icmp
|
||||||
|
{
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
if !self.cidr_set.contains_v4(ipv4.get_destination()) {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
let icmp_packet = icmp::echo_request::EchoRequestPacket::new(&ipv4.payload())?;
|
||||||
|
|
||||||
|
if icmp_packet.get_icmp_type() != IcmpTypes::EchoRequest {
|
||||||
|
// drop it because we do not support other icmp types
|
||||||
|
tracing::trace!("unsupported icmp type: {:?}", icmp_packet.get_icmp_type());
|
||||||
|
return Some(());
|
||||||
|
}
|
||||||
|
|
||||||
|
let icmp_id = icmp_packet.get_identifier();
|
||||||
|
let icmp_seq = icmp_packet.get_sequence_number();
|
||||||
|
|
||||||
|
let key = IcmpNatKey {
|
||||||
|
dst_ip: ipv4.get_destination().into(),
|
||||||
|
icmp_id,
|
||||||
|
icmp_seq,
|
||||||
|
};
|
||||||
|
|
||||||
|
let value = IcmpNatEntry::new(
|
||||||
|
hdr.from_peer_id.into(),
|
||||||
|
hdr.to_peer_id.into(),
|
||||||
|
ipv4.get_source().into(),
|
||||||
|
)
|
||||||
|
.ok()?;
|
||||||
|
|
||||||
|
if let Some(old) = self.nat_table.insert(key, value) {
|
||||||
|
tracing::info!("icmp nat table entry replaced: {:?}", old);
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Err(e) = self.send_icmp_packet(ipv4.get_destination(), &icmp_packet) {
|
||||||
|
tracing::error!("send icmp packet failed: {:?}", e);
|
||||||
|
}
|
||||||
|
|
||||||
|
Some(())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,7 +2,9 @@ use crossbeam::atomic::AtomicCell;
|
|||||||
use dashmap::DashMap;
|
use dashmap::DashMap;
|
||||||
use pnet::packet::ip::IpNextHeaderProtocols;
|
use pnet::packet::ip::IpNextHeaderProtocols;
|
||||||
use pnet::packet::ipv4::{Ipv4Packet, MutableIpv4Packet};
|
use pnet::packet::ipv4::{Ipv4Packet, MutableIpv4Packet};
|
||||||
use pnet::packet::tcp::{ipv4_checksum, MutableTcpPacket};
|
use pnet::packet::tcp::{ipv4_checksum, MutableTcpPacket, TcpPacket};
|
||||||
|
use pnet::packet::MutablePacket;
|
||||||
|
use pnet::packet::Packet;
|
||||||
use std::net::{IpAddr, Ipv4Addr, SocketAddr, SocketAddrV4};
|
use std::net::{IpAddr, Ipv4Addr, SocketAddr, SocketAddrV4};
|
||||||
use std::sync::atomic::AtomicU16;
|
use std::sync::atomic::AtomicU16;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
@@ -11,16 +13,16 @@ use tokio::io::copy_bidirectional;
|
|||||||
use tokio::net::{TcpListener, TcpSocket, TcpStream};
|
use tokio::net::{TcpListener, TcpSocket, TcpStream};
|
||||||
use tokio::sync::Mutex;
|
use tokio::sync::Mutex;
|
||||||
use tokio::task::JoinSet;
|
use tokio::task::JoinSet;
|
||||||
use tokio_util::bytes::{Bytes, BytesMut};
|
|
||||||
use tracing::Instrument;
|
use tracing::Instrument;
|
||||||
|
|
||||||
use crate::common::error::Result;
|
use crate::common::error::Result;
|
||||||
use crate::common::global_ctx::GlobalCtx;
|
use crate::common::global_ctx::GlobalCtx;
|
||||||
use crate::common::join_joinset_background;
|
use crate::common::join_joinset_background;
|
||||||
use crate::common::netns::NetNS;
|
use crate::common::netns::NetNS;
|
||||||
use crate::peers::packet::{self, ArchivedPacket};
|
|
||||||
use crate::peers::peer_manager::PeerManager;
|
use crate::peers::peer_manager::PeerManager;
|
||||||
use crate::peers::{NicPacketFilter, PeerPacketFilter};
|
use crate::peers::{NicPacketFilter, PeerPacketFilter};
|
||||||
|
use crate::tunnel::packet_def::{PacketType, ZCPacket};
|
||||||
|
|
||||||
use super::CidrSet;
|
use super::CidrSet;
|
||||||
|
|
||||||
@@ -83,98 +85,37 @@ pub struct TcpProxy {
|
|||||||
|
|
||||||
#[async_trait::async_trait]
|
#[async_trait::async_trait]
|
||||||
impl PeerPacketFilter for TcpProxy {
|
impl PeerPacketFilter for TcpProxy {
|
||||||
async fn try_process_packet_from_peer(&self, packet: &ArchivedPacket, _: &Bytes) -> Option<()> {
|
async fn try_process_packet_from_peer(&self, mut packet: ZCPacket) -> Option<ZCPacket> {
|
||||||
let ipv4_addr = self.global_ctx.get_ipv4()?;
|
if let Some(_) = self.try_handle_peer_packet(&mut packet).await {
|
||||||
|
if let Err(e) = self.peer_manager.get_nic_channel().send(packet).await {
|
||||||
if packet.packet_type != packet::PacketType::Data {
|
tracing::error!("send to nic failed: {:?}", e);
|
||||||
return None;
|
}
|
||||||
};
|
|
||||||
|
|
||||||
let payload_bytes = packet.payload.as_bytes();
|
|
||||||
|
|
||||||
let ipv4 = Ipv4Packet::new(payload_bytes)?;
|
|
||||||
if ipv4.get_version() != 4 || ipv4.get_next_level_protocol() != IpNextHeaderProtocols::Tcp {
|
|
||||||
return None;
|
return None;
|
||||||
|
} else {
|
||||||
|
Some(packet)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !self.cidr_set.contains_v4(ipv4.get_destination()) {
|
|
||||||
return None;
|
|
||||||
}
|
|
||||||
|
|
||||||
tracing::trace!(ipv4 = ?ipv4, cidr_set = ?self.cidr_set, "proxy tcp packet received");
|
|
||||||
|
|
||||||
let mut packet_buffer = BytesMut::with_capacity(payload_bytes.len());
|
|
||||||
packet_buffer.extend_from_slice(&payload_bytes.to_vec());
|
|
||||||
|
|
||||||
let (ip_buffer, tcp_buffer) =
|
|
||||||
packet_buffer.split_at_mut(ipv4.get_header_length() as usize * 4);
|
|
||||||
|
|
||||||
let mut ip_packet = MutableIpv4Packet::new(ip_buffer).unwrap();
|
|
||||||
let mut tcp_packet = MutableTcpPacket::new(tcp_buffer).unwrap();
|
|
||||||
|
|
||||||
let is_tcp_syn = tcp_packet.get_flags() & pnet::packet::tcp::TcpFlags::SYN != 0;
|
|
||||||
if is_tcp_syn {
|
|
||||||
let source_ip = ip_packet.get_source();
|
|
||||||
let source_port = tcp_packet.get_source();
|
|
||||||
let src = SocketAddr::V4(SocketAddrV4::new(source_ip, source_port));
|
|
||||||
|
|
||||||
let dest_ip = ip_packet.get_destination();
|
|
||||||
let dest_port = tcp_packet.get_destination();
|
|
||||||
let dst = SocketAddr::V4(SocketAddrV4::new(dest_ip, dest_port));
|
|
||||||
|
|
||||||
let old_val = self
|
|
||||||
.syn_map
|
|
||||||
.insert(src, Arc::new(NatDstEntry::new(src, dst)));
|
|
||||||
tracing::trace!(src = ?src, dst = ?dst, old_entry = ?old_val, "tcp syn received");
|
|
||||||
}
|
|
||||||
|
|
||||||
ip_packet.set_destination(ipv4_addr);
|
|
||||||
tcp_packet.set_destination(self.get_local_port());
|
|
||||||
Self::update_ipv4_packet_checksum(&mut ip_packet, &mut tcp_packet);
|
|
||||||
|
|
||||||
tracing::trace!(ip_packet = ?ip_packet, tcp_packet = ?tcp_packet, "tcp packet forwarded");
|
|
||||||
|
|
||||||
if let Err(e) = self
|
|
||||||
.peer_manager
|
|
||||||
.get_nic_channel()
|
|
||||||
.send(packet_buffer.freeze())
|
|
||||||
.await
|
|
||||||
{
|
|
||||||
tracing::error!("send to nic failed: {:?}", e);
|
|
||||||
}
|
|
||||||
|
|
||||||
Some(())
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[async_trait::async_trait]
|
#[async_trait::async_trait]
|
||||||
impl NicPacketFilter for TcpProxy {
|
impl NicPacketFilter for TcpProxy {
|
||||||
async fn try_process_packet_from_nic(&self, mut data: BytesMut) -> BytesMut {
|
async fn try_process_packet_from_nic(&self, zc_packet: &mut ZCPacket) {
|
||||||
let Some(my_ipv4) = self.global_ctx.get_ipv4() else {
|
let Some(my_ipv4) = self.global_ctx.get_ipv4() else {
|
||||||
return data;
|
return;
|
||||||
};
|
};
|
||||||
|
|
||||||
let header_len = {
|
let data = zc_packet.payload();
|
||||||
let Some(ipv4) = &Ipv4Packet::new(&data[..]) else {
|
let ip_packet = Ipv4Packet::new(data).unwrap();
|
||||||
return data;
|
if ip_packet.get_version() != 4
|
||||||
};
|
|| ip_packet.get_source() != my_ipv4
|
||||||
|
|| ip_packet.get_next_level_protocol() != IpNextHeaderProtocols::Tcp
|
||||||
if ipv4.get_version() != 4
|
{
|
||||||
|| ipv4.get_source() != my_ipv4
|
return;
|
||||||
|| ipv4.get_next_level_protocol() != IpNextHeaderProtocols::Tcp
|
}
|
||||||
{
|
|
||||||
return data;
|
|
||||||
}
|
|
||||||
|
|
||||||
ipv4.get_header_length() as usize * 4
|
|
||||||
};
|
|
||||||
|
|
||||||
let (ip_buffer, tcp_buffer) = data.split_at_mut(header_len);
|
|
||||||
let mut ip_packet = MutableIpv4Packet::new(ip_buffer).unwrap();
|
|
||||||
let mut tcp_packet = MutableTcpPacket::new(tcp_buffer).unwrap();
|
|
||||||
|
|
||||||
|
let tcp_packet = TcpPacket::new(ip_packet.payload()).unwrap();
|
||||||
if tcp_packet.get_source() != self.get_local_port() {
|
if tcp_packet.get_source() != self.get_local_port() {
|
||||||
return data;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
let dst_addr = SocketAddr::V4(SocketAddrV4::new(
|
let dst_addr = SocketAddr::V4(SocketAddrV4::new(
|
||||||
@@ -187,7 +128,7 @@ impl NicPacketFilter for TcpProxy {
|
|||||||
entry
|
entry
|
||||||
} else {
|
} else {
|
||||||
let Some(syn_entry) = self.syn_map.get(&dst_addr) else {
|
let Some(syn_entry) = self.syn_map.get(&dst_addr) else {
|
||||||
return data;
|
return;
|
||||||
};
|
};
|
||||||
syn_entry
|
syn_entry
|
||||||
};
|
};
|
||||||
@@ -199,13 +140,18 @@ impl NicPacketFilter for TcpProxy {
|
|||||||
panic!("v4 nat entry src ip is not v4");
|
panic!("v4 nat entry src ip is not v4");
|
||||||
};
|
};
|
||||||
|
|
||||||
|
let mut ip_packet = MutableIpv4Packet::new(zc_packet.mut_payload()).unwrap();
|
||||||
ip_packet.set_source(ip);
|
ip_packet.set_source(ip);
|
||||||
|
let dst = ip_packet.get_destination();
|
||||||
|
|
||||||
|
let mut tcp_packet = MutableTcpPacket::new(ip_packet.payload_mut()).unwrap();
|
||||||
tcp_packet.set_source(nat_entry.dst.port());
|
tcp_packet.set_source(nat_entry.dst.port());
|
||||||
Self::update_ipv4_packet_checksum(&mut ip_packet, &mut tcp_packet);
|
|
||||||
|
Self::update_tcp_packet_checksum(&mut tcp_packet, &ip, &dst);
|
||||||
|
drop(tcp_packet);
|
||||||
|
Self::update_ip_packet_checksum(&mut ip_packet);
|
||||||
|
|
||||||
tracing::trace!(dst_addr = ?dst_addr, nat_entry = ?nat_entry, packet = ?ip_packet, "tcp packet after modified");
|
tracing::trace!(dst_addr = ?dst_addr, nat_entry = ?nat_entry, packet = ?ip_packet, "tcp packet after modified");
|
||||||
|
|
||||||
data
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -226,17 +172,20 @@ impl TcpProxy {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn update_ipv4_packet_checksum(
|
fn update_tcp_packet_checksum(
|
||||||
ipv4_packet: &mut MutableIpv4Packet,
|
|
||||||
tcp_packet: &mut MutableTcpPacket,
|
tcp_packet: &mut MutableTcpPacket,
|
||||||
|
ipv4_src: &Ipv4Addr,
|
||||||
|
ipv4_dst: &Ipv4Addr,
|
||||||
) {
|
) {
|
||||||
tcp_packet.set_checksum(ipv4_checksum(
|
tcp_packet.set_checksum(ipv4_checksum(
|
||||||
&tcp_packet.to_immutable(),
|
&tcp_packet.to_immutable(),
|
||||||
&ipv4_packet.get_source(),
|
ipv4_src,
|
||||||
&ipv4_packet.get_destination(),
|
ipv4_dst,
|
||||||
));
|
));
|
||||||
|
}
|
||||||
|
|
||||||
ipv4_packet.set_checksum(pnet::packet::ipv4::checksum(&ipv4_packet.to_immutable()));
|
fn update_ip_packet_checksum(ip_packet: &mut MutableIpv4Packet) {
|
||||||
|
ip_packet.set_checksum(pnet::packet::ipv4::checksum(&ip_packet.to_immutable()));
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn start(self: &Arc<Self>) -> Result<()> {
|
pub async fn start(self: &Arc<Self>) -> Result<()> {
|
||||||
@@ -302,6 +251,7 @@ impl TcpProxy {
|
|||||||
tracing::error!("tcp connection from unknown source: {:?}", socket_addr);
|
tracing::error!("tcp connection from unknown source: {:?}", socket_addr);
|
||||||
continue;
|
continue;
|
||||||
};
|
};
|
||||||
|
tracing::info!(?socket_addr, "tcp connection accepted for proxy");
|
||||||
assert_eq!(entry.state.load(), NatDstEntryState::SynReceived);
|
assert_eq!(entry.state.load(), NatDstEntryState::SynReceived);
|
||||||
|
|
||||||
let entry_clone = entry.clone();
|
let entry_clone = entry.clone();
|
||||||
@@ -404,4 +354,60 @@ impl TcpProxy {
|
|||||||
pub fn get_local_port(&self) -> u16 {
|
pub fn get_local_port(&self) -> u16 {
|
||||||
self.local_port.load(std::sync::atomic::Ordering::Relaxed)
|
self.local_port.load(std::sync::atomic::Ordering::Relaxed)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn try_handle_peer_packet(&self, packet: &mut ZCPacket) -> Option<()> {
|
||||||
|
let ipv4_addr = self.global_ctx.get_ipv4()?;
|
||||||
|
let hdr = packet.peer_manager_header().unwrap();
|
||||||
|
|
||||||
|
if hdr.packet_type != PacketType::Data as u8 {
|
||||||
|
return None;
|
||||||
|
};
|
||||||
|
|
||||||
|
let payload_bytes = packet.mut_payload();
|
||||||
|
|
||||||
|
let ipv4 = Ipv4Packet::new(payload_bytes)?;
|
||||||
|
if ipv4.get_version() != 4 || ipv4.get_next_level_protocol() != IpNextHeaderProtocols::Tcp {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
if !self.cidr_set.contains_v4(ipv4.get_destination()) {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
tracing::info!(ipv4 = ?ipv4, cidr_set = ?self.cidr_set, "proxy tcp packet received");
|
||||||
|
|
||||||
|
let ip_packet = Ipv4Packet::new(payload_bytes).unwrap();
|
||||||
|
let tcp_packet = TcpPacket::new(ip_packet.payload()).unwrap();
|
||||||
|
|
||||||
|
let is_tcp_syn = tcp_packet.get_flags() & pnet::packet::tcp::TcpFlags::SYN != 0;
|
||||||
|
if is_tcp_syn {
|
||||||
|
let source_ip = ip_packet.get_source();
|
||||||
|
let source_port = tcp_packet.get_source();
|
||||||
|
let src = SocketAddr::V4(SocketAddrV4::new(source_ip, source_port));
|
||||||
|
|
||||||
|
let dest_ip = ip_packet.get_destination();
|
||||||
|
let dest_port = tcp_packet.get_destination();
|
||||||
|
let dst = SocketAddr::V4(SocketAddrV4::new(dest_ip, dest_port));
|
||||||
|
|
||||||
|
let old_val = self
|
||||||
|
.syn_map
|
||||||
|
.insert(src, Arc::new(NatDstEntry::new(src, dst)));
|
||||||
|
tracing::trace!(src = ?src, dst = ?dst, old_entry = ?old_val, "tcp syn received");
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut ip_packet = MutableIpv4Packet::new(payload_bytes).unwrap();
|
||||||
|
ip_packet.set_destination(ipv4_addr);
|
||||||
|
let source = ip_packet.get_source();
|
||||||
|
|
||||||
|
let mut tcp_packet = MutableTcpPacket::new(ip_packet.payload_mut()).unwrap();
|
||||||
|
tcp_packet.set_destination(self.get_local_port());
|
||||||
|
|
||||||
|
Self::update_tcp_packet_checksum(&mut tcp_packet, &source, &ipv4_addr);
|
||||||
|
drop(tcp_packet);
|
||||||
|
Self::update_ip_packet_checksum(&mut ip_packet);
|
||||||
|
|
||||||
|
tracing::info!(?source, ?ipv4_addr, ?packet, "tcp packet after modified");
|
||||||
|
|
||||||
|
Some(())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -21,12 +21,12 @@ use tokio::{
|
|||||||
time::timeout,
|
time::timeout,
|
||||||
};
|
};
|
||||||
|
|
||||||
use tokio_util::bytes::Bytes;
|
|
||||||
use tracing::Level;
|
use tracing::Level;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
common::{error::Error, global_ctx::ArcGlobalCtx, PeerId},
|
common::{error::Error, global_ctx::ArcGlobalCtx, PeerId},
|
||||||
peers::{packet, peer_manager::PeerManager, PeerPacketFilter},
|
peers::{peer_manager::PeerManager, PeerPacketFilter},
|
||||||
|
tunnel::packet_def::{PacketType, ZCPacket},
|
||||||
tunnels::common::setup_sokcet2,
|
tunnels::common::setup_sokcet2,
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -79,7 +79,7 @@ impl UdpNatEntry {
|
|||||||
|
|
||||||
async fn compose_ipv4_packet(
|
async fn compose_ipv4_packet(
|
||||||
self: &Arc<Self>,
|
self: &Arc<Self>,
|
||||||
packet_sender: &mut UnboundedSender<packet::Packet>,
|
packet_sender: &mut UnboundedSender<ZCPacket>,
|
||||||
buf: &mut [u8],
|
buf: &mut [u8],
|
||||||
src_v4: &SocketAddrV4,
|
src_v4: &SocketAddrV4,
|
||||||
payload_len: usize,
|
payload_len: usize,
|
||||||
@@ -140,13 +140,10 @@ impl UdpNatEntry {
|
|||||||
|
|
||||||
tracing::trace!(?ipv4_packet, "udp nat packet response send");
|
tracing::trace!(?ipv4_packet, "udp nat packet response send");
|
||||||
|
|
||||||
let peer_packet = packet::Packet::new_data_packet(
|
let mut p = ZCPacket::new_with_payload(ipv4_packet.packet());
|
||||||
self.my_peer_id,
|
p.fill_peer_manager_hdr(self.my_peer_id, self.src_peer_id, PacketType::Data as u8);
|
||||||
self.src_peer_id,
|
|
||||||
&ipv4_packet.to_immutable().packet(),
|
|
||||||
);
|
|
||||||
|
|
||||||
if let Err(e) = packet_sender.send(peer_packet) {
|
if let Err(e) = packet_sender.send(p) {
|
||||||
tracing::error!("send icmp packet to peer failed: {:?}, may exiting..", e);
|
tracing::error!("send icmp packet to peer failed: {:?}, may exiting..", e);
|
||||||
return Err(Error::AnyhowError(e.into()));
|
return Err(Error::AnyhowError(e.into()));
|
||||||
}
|
}
|
||||||
@@ -158,7 +155,7 @@ impl UdpNatEntry {
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn forward_task(self: Arc<Self>, mut packet_sender: UnboundedSender<packet::Packet>) {
|
async fn forward_task(self: Arc<Self>, mut packet_sender: UnboundedSender<ZCPacket>) {
|
||||||
let mut buf = [0u8; 8192];
|
let mut buf = [0u8; 8192];
|
||||||
let mut udp_body: &mut [u8] = unsafe { std::mem::transmute(&mut buf[20 + 8..]) };
|
let mut udp_body: &mut [u8] = unsafe { std::mem::transmute(&mut buf[20 + 8..]) };
|
||||||
let mut ip_id = 1;
|
let mut ip_id = 1;
|
||||||
@@ -220,31 +217,25 @@ pub struct UdpProxy {
|
|||||||
|
|
||||||
nat_table: Arc<DashMap<UdpNatKey, Arc<UdpNatEntry>>>,
|
nat_table: Arc<DashMap<UdpNatKey, Arc<UdpNatEntry>>>,
|
||||||
|
|
||||||
sender: UnboundedSender<packet::Packet>,
|
sender: UnboundedSender<ZCPacket>,
|
||||||
receiver: Mutex<Option<UnboundedReceiver<packet::Packet>>>,
|
receiver: Mutex<Option<UnboundedReceiver<ZCPacket>>>,
|
||||||
|
|
||||||
tasks: Mutex<JoinSet<()>>,
|
tasks: Mutex<JoinSet<()>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[async_trait::async_trait]
|
impl UdpProxy {
|
||||||
impl PeerPacketFilter for UdpProxy {
|
async fn try_handle_packet(&self, packet: &ZCPacket) -> Option<()> {
|
||||||
async fn try_process_packet_from_peer(
|
|
||||||
&self,
|
|
||||||
packet: &packet::ArchivedPacket,
|
|
||||||
_: &Bytes,
|
|
||||||
) -> Option<()> {
|
|
||||||
if self.cidr_set.is_empty() {
|
if self.cidr_set.is_empty() {
|
||||||
return None;
|
return None;
|
||||||
}
|
}
|
||||||
|
|
||||||
let _ = self.global_ctx.get_ipv4()?;
|
let _ = self.global_ctx.get_ipv4()?;
|
||||||
|
let hdr = packet.peer_manager_header().unwrap();
|
||||||
if packet.packet_type != packet::PacketType::Data {
|
if hdr.packet_type != PacketType::Data as u8 {
|
||||||
return None;
|
return None;
|
||||||
};
|
};
|
||||||
|
|
||||||
let ipv4 = Ipv4Packet::new(packet.payload.as_bytes())?;
|
let ipv4 = Ipv4Packet::new(packet.payload())?;
|
||||||
|
|
||||||
if ipv4.get_version() != 4 || ipv4.get_next_level_protocol() != IpNextHeaderProtocols::Udp {
|
if ipv4.get_version() != 4 || ipv4.get_next_level_protocol() != IpNextHeaderProtocols::Udp {
|
||||||
return None;
|
return None;
|
||||||
}
|
}
|
||||||
@@ -272,8 +263,8 @@ impl PeerPacketFilter for UdpProxy {
|
|||||||
tracing::info!(?packet, ?ipv4, ?udp_packet, "udp nat table entry created");
|
tracing::info!(?packet, ?ipv4, ?udp_packet, "udp nat table entry created");
|
||||||
let _g = self.global_ctx.net_ns.guard();
|
let _g = self.global_ctx.net_ns.guard();
|
||||||
Ok(Arc::new(UdpNatEntry::new(
|
Ok(Arc::new(UdpNatEntry::new(
|
||||||
packet.from_peer.into(),
|
hdr.from_peer_id.get(),
|
||||||
packet.to_peer.into(),
|
hdr.to_peer_id.get(),
|
||||||
nat_key.src_socket,
|
nat_key.src_socket,
|
||||||
)?))
|
)?))
|
||||||
})
|
})
|
||||||
@@ -316,6 +307,17 @@ impl PeerPacketFilter for UdpProxy {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[async_trait::async_trait]
|
||||||
|
impl PeerPacketFilter for UdpProxy {
|
||||||
|
async fn try_process_packet_from_peer(&self, packet: ZCPacket) -> Option<ZCPacket> {
|
||||||
|
if let Some(_) = self.try_handle_packet(&packet).await {
|
||||||
|
return None;
|
||||||
|
} else {
|
||||||
|
return Some(packet);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl UdpProxy {
|
impl UdpProxy {
|
||||||
pub fn new(
|
pub fn new(
|
||||||
global_ctx: ArcGlobalCtx,
|
global_ctx: ArcGlobalCtx,
|
||||||
@@ -362,9 +364,9 @@ impl UdpProxy {
|
|||||||
let peer_manager = self.peer_manager.clone();
|
let peer_manager = self.peer_manager.clone();
|
||||||
self.tasks.lock().await.spawn(async move {
|
self.tasks.lock().await.spawn(async move {
|
||||||
while let Some(msg) = receiver.recv().await {
|
while let Some(msg) = receiver.recv().await {
|
||||||
let to_peer_id: PeerId = msg.to_peer.into();
|
let to_peer_id: PeerId = msg.peer_manager_header().unwrap().to_peer_id.get();
|
||||||
tracing::trace!(?msg, ?to_peer_id, "udp nat packet response send");
|
tracing::trace!(?msg, ?to_peer_id, "udp nat packet response send");
|
||||||
let ret = peer_manager.send_msg(msg.into(), to_peer_id).await;
|
let ret = peer_manager.send_msg(msg, to_peer_id).await;
|
||||||
if ret.is_err() {
|
if ret.is_err() {
|
||||||
tracing::error!("send icmp packet to peer failed: {:?}", ret);
|
tracing::error!("send icmp packet to peer failed: {:?}", ret);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,12 +3,12 @@ use std::net::Ipv4Addr;
|
|||||||
use std::sync::{Arc, Weak};
|
use std::sync::{Arc, Weak};
|
||||||
|
|
||||||
use anyhow::Context;
|
use anyhow::Context;
|
||||||
use futures::StreamExt;
|
use futures::{SinkExt, StreamExt};
|
||||||
use pnet::packet::ethernet::EthernetPacket;
|
use pnet::packet::ethernet::EthernetPacket;
|
||||||
use pnet::packet::ipv4::Ipv4Packet;
|
use pnet::packet::ipv4::Ipv4Packet;
|
||||||
|
|
||||||
|
use bytes::BytesMut;
|
||||||
use tokio::{sync::Mutex, task::JoinSet};
|
use tokio::{sync::Mutex, task::JoinSet};
|
||||||
use tokio_util::bytes::{Bytes, BytesMut};
|
|
||||||
use tonic::transport::Server;
|
use tonic::transport::Server;
|
||||||
|
|
||||||
use crate::common::config::ConfigLoader;
|
use crate::common::config::ConfigLoader;
|
||||||
@@ -22,15 +22,15 @@ use crate::gateway::icmp_proxy::IcmpProxy;
|
|||||||
use crate::gateway::tcp_proxy::TcpProxy;
|
use crate::gateway::tcp_proxy::TcpProxy;
|
||||||
use crate::gateway::udp_proxy::UdpProxy;
|
use crate::gateway::udp_proxy::UdpProxy;
|
||||||
use crate::peer_center::instance::PeerCenterInstance;
|
use crate::peer_center::instance::PeerCenterInstance;
|
||||||
use crate::peers::peer_conn::PeerConnId;
|
|
||||||
use crate::peers::peer_manager::{PeerManager, RouteAlgoType};
|
use crate::peers::peer_manager::{PeerManager, RouteAlgoType};
|
||||||
use crate::peers::rpc_service::PeerManagerRpcService;
|
use crate::peers::rpc_service::PeerManagerRpcService;
|
||||||
|
use crate::peers::zc_peer_conn::PeerConnId;
|
||||||
|
use crate::peers::PacketRecvChanReceiver;
|
||||||
use crate::rpc::vpn_portal_rpc_server::VpnPortalRpc;
|
use crate::rpc::vpn_portal_rpc_server::VpnPortalRpc;
|
||||||
use crate::rpc::{GetVpnPortalInfoRequest, GetVpnPortalInfoResponse, VpnPortalInfo};
|
use crate::rpc::{GetVpnPortalInfoRequest, GetVpnPortalInfoResponse, VpnPortalInfo};
|
||||||
use crate::tunnels::SinkItem;
|
use crate::tunnel::packet_def::ZCPacket;
|
||||||
use crate::vpn_portal::{self, VpnPortal};
|
|
||||||
|
|
||||||
use tokio_stream::wrappers::ReceiverStream;
|
use crate::vpn_portal::{self, VpnPortal};
|
||||||
|
|
||||||
use super::listeners::ListenerManager;
|
use super::listeners::ListenerManager;
|
||||||
use super::virtual_nic;
|
use super::virtual_nic;
|
||||||
@@ -70,7 +70,7 @@ pub struct Instance {
|
|||||||
id: uuid::Uuid,
|
id: uuid::Uuid,
|
||||||
|
|
||||||
virtual_nic: Option<Arc<virtual_nic::VirtualNic>>,
|
virtual_nic: Option<Arc<virtual_nic::VirtualNic>>,
|
||||||
peer_packet_receiver: Option<ReceiverStream<SinkItem>>,
|
peer_packet_receiver: Option<PacketRecvChanReceiver>,
|
||||||
|
|
||||||
tasks: JoinSet<()>,
|
tasks: JoinSet<()>,
|
||||||
|
|
||||||
@@ -133,7 +133,7 @@ impl Instance {
|
|||||||
id,
|
id,
|
||||||
|
|
||||||
virtual_nic: None,
|
virtual_nic: None,
|
||||||
peer_packet_receiver: Some(ReceiverStream::new(peer_packet_receiver)),
|
peer_packet_receiver: Some(peer_packet_receiver),
|
||||||
|
|
||||||
tasks: JoinSet::new(),
|
tasks: JoinSet::new(),
|
||||||
peer_manager,
|
peer_manager,
|
||||||
@@ -167,7 +167,11 @@ impl Instance {
|
|||||||
?ret,
|
?ret,
|
||||||
"[USER_PACKET] recv new packet from tun device and forward to peers."
|
"[USER_PACKET] recv new packet from tun device and forward to peers."
|
||||||
);
|
);
|
||||||
let send_ret = mgr.send_msg_ipv4(ret, dst_ipv4).await;
|
|
||||||
|
// TODO: use zero-copy
|
||||||
|
let send_ret = mgr
|
||||||
|
.send_msg_ipv4(ZCPacket::new_with_payload(ret.as_ref()), dst_ipv4)
|
||||||
|
.await;
|
||||||
if send_ret.is_err() {
|
if send_ret.is_err() {
|
||||||
tracing::trace!(?send_ret, "[USER_PACKET] send_msg_ipv4 failed")
|
tracing::trace!(?send_ret, "[USER_PACKET] send_msg_ipv4 failed")
|
||||||
}
|
}
|
||||||
@@ -209,23 +213,23 @@ impl Instance {
|
|||||||
fn do_forward_peers_to_nic(
|
fn do_forward_peers_to_nic(
|
||||||
tasks: &mut JoinSet<()>,
|
tasks: &mut JoinSet<()>,
|
||||||
nic: Arc<virtual_nic::VirtualNic>,
|
nic: Arc<virtual_nic::VirtualNic>,
|
||||||
channel: Option<ReceiverStream<Bytes>>,
|
channel: Option<PacketRecvChanReceiver>,
|
||||||
) {
|
) {
|
||||||
tasks.spawn(async move {
|
tasks.spawn(async move {
|
||||||
let send = nic.pin_send_stream();
|
let mut send = nic.pin_send_stream();
|
||||||
let channel = channel.unwrap();
|
let mut channel = channel.unwrap();
|
||||||
let ret = channel
|
while let Some(packet) = channel.recv().await {
|
||||||
.map(|packet| {
|
tracing::trace!(
|
||||||
log::trace!(
|
"[USER_PACKET] forward packet from peers to nic. packet: {:?}",
|
||||||
"[USER_PACKET] forward packet from peers to nic. packet: {:?}",
|
packet
|
||||||
packet
|
);
|
||||||
);
|
let mut b = BytesMut::new();
|
||||||
Ok(packet)
|
b.extend_from_slice(packet.payload());
|
||||||
})
|
|
||||||
.forward(send)
|
let ret = send.send(b.freeze()).await;
|
||||||
.await;
|
if ret.is_err() {
|
||||||
if ret.is_err() {
|
panic!("do_forward_tunnel_to_nic");
|
||||||
panic!("do_forward_tunnel_to_nic");
|
}
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
@@ -300,17 +304,25 @@ impl Instance {
|
|||||||
|
|
||||||
self.add_initial_peers().await?;
|
self.add_initial_peers().await?;
|
||||||
|
|
||||||
if let Some(_) = self.global_ctx.get_vpn_portal_cidr() {
|
if self.global_ctx.get_vpn_portal_cidr().is_some() {
|
||||||
self.vpn_portal
|
self.run_vpn_portal().await?;
|
||||||
.lock()
|
|
||||||
.await
|
|
||||||
.start(self.get_global_ctx(), self.get_peer_manager())
|
|
||||||
.await?;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub async fn run_vpn_portal(&mut self) -> Result<(), Error> {
|
||||||
|
if self.global_ctx.get_vpn_portal_cidr().is_none() {
|
||||||
|
return Err(anyhow::anyhow!("vpn portal cidr not set.").into());
|
||||||
|
}
|
||||||
|
self.vpn_portal
|
||||||
|
.lock()
|
||||||
|
.await
|
||||||
|
.start(self.get_global_ctx(), self.get_peer_manager())
|
||||||
|
.await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
pub fn get_peer_manager(&self) -> Arc<PeerManager> {
|
pub fn get_peer_manager(&self) -> Arc<PeerManager> {
|
||||||
self.peer_manager.clone()
|
self.peer_manager.clone()
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -11,10 +11,10 @@ use crate::{
|
|||||||
netns::NetNS,
|
netns::NetNS,
|
||||||
},
|
},
|
||||||
peers::peer_manager::PeerManager,
|
peers::peer_manager::PeerManager,
|
||||||
tunnels::{
|
tunnel::{
|
||||||
ring_tunnel::RingTunnelListener,
|
ring::RingTunnelListener,
|
||||||
tcp_tunnel::TcpTunnelListener,
|
tcp::TcpTunnelListener,
|
||||||
udp_tunnel::UdpTunnelListener,
|
udp::UdpTunnelListener,
|
||||||
wireguard::{WgConfig, WgTunnelListener},
|
wireguard::{WgConfig, WgTunnelListener},
|
||||||
Tunnel, TunnelListener,
|
Tunnel, TunnelListener,
|
||||||
},
|
},
|
||||||
@@ -155,7 +155,7 @@ mod tests {
|
|||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
common::global_ctx::tests::get_mock_global_ctx,
|
common::global_ctx::tests::get_mock_global_ctx,
|
||||||
tunnels::{ring_tunnel::RingTunnelConnector, TunnelConnector},
|
tunnel::{packet_def::ZCPacket, ring::RingTunnelConnector, TunnelConnector},
|
||||||
};
|
};
|
||||||
|
|
||||||
use super::*;
|
use super::*;
|
||||||
@@ -165,9 +165,12 @@ mod tests {
|
|||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl TunnelHandlerForListener for MockListenerHandler {
|
impl TunnelHandlerForListener for MockListenerHandler {
|
||||||
async fn handle_tunnel(&self, _tunnel: Box<dyn Tunnel>) -> Result<(), Error> {
|
async fn handle_tunnel(&self, tunnel: Box<dyn Tunnel>) -> Result<(), Error> {
|
||||||
let data = "abc";
|
let data = "abc";
|
||||||
_tunnel.pin_sink().send(data.into()).await.unwrap();
|
let (_recv, mut send) = tunnel.split();
|
||||||
|
|
||||||
|
let zc_packet = ZCPacket::new_with_payload(data.as_bytes());
|
||||||
|
send.send(zc_packet).await.unwrap();
|
||||||
Err(Error::Unknown)
|
Err(Error::Unknown)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -187,7 +190,11 @@ mod tests {
|
|||||||
|
|
||||||
let connect_once = |ring_id| async move {
|
let connect_once = |ring_id| async move {
|
||||||
let tunnel = RingTunnelConnector::new(ring_id).connect().await.unwrap();
|
let tunnel = RingTunnelConnector::new(ring_id).connect().await.unwrap();
|
||||||
assert_eq!(tunnel.pin_stream().next().await.unwrap().unwrap(), "abc");
|
let (mut recv, _send) = tunnel.split();
|
||||||
|
assert_eq!(
|
||||||
|
recv.next().await.unwrap().unwrap().payload(),
|
||||||
|
"abc".as_bytes()
|
||||||
|
);
|
||||||
tunnel
|
tunnel
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ pub mod instance;
|
|||||||
pub mod peer_center;
|
pub mod peer_center;
|
||||||
pub mod peers;
|
pub mod peers;
|
||||||
pub mod rpc;
|
pub mod rpc;
|
||||||
|
pub mod tunnel;
|
||||||
pub mod tunnels;
|
pub mod tunnels;
|
||||||
pub mod utils;
|
pub mod utils;
|
||||||
pub mod vpn_portal;
|
pub mod vpn_portal;
|
||||||
|
|||||||
@@ -4,23 +4,23 @@ use std::{
|
|||||||
};
|
};
|
||||||
|
|
||||||
use dashmap::DashMap;
|
use dashmap::DashMap;
|
||||||
use tokio::{
|
use tokio::{sync::Mutex, task::JoinSet};
|
||||||
sync::{mpsc, Mutex},
|
|
||||||
task::JoinSet,
|
|
||||||
};
|
|
||||||
use tokio_util::bytes::Bytes;
|
|
||||||
|
|
||||||
use crate::common::{
|
use crate::{
|
||||||
error::Error,
|
common::{
|
||||||
global_ctx::{ArcGlobalCtx, NetworkIdentity},
|
error::Error,
|
||||||
PeerId,
|
global_ctx::{ArcGlobalCtx, NetworkIdentity},
|
||||||
|
PeerId,
|
||||||
|
},
|
||||||
|
tunnel::packet_def::ZCPacket,
|
||||||
};
|
};
|
||||||
|
|
||||||
use super::{
|
use super::{
|
||||||
foreign_network_manager::{ForeignNetworkServiceClient, FOREIGN_NETWORK_SERVICE_ID},
|
foreign_network_manager::{ForeignNetworkServiceClient, FOREIGN_NETWORK_SERVICE_ID},
|
||||||
peer_conn::PeerConn,
|
|
||||||
peer_map::PeerMap,
|
peer_map::PeerMap,
|
||||||
peer_rpc::PeerRpcManager,
|
peer_rpc::PeerRpcManager,
|
||||||
|
zc_peer_conn::PeerConn,
|
||||||
|
PacketRecvChan,
|
||||||
};
|
};
|
||||||
|
|
||||||
pub struct ForeignNetworkClient {
|
pub struct ForeignNetworkClient {
|
||||||
@@ -37,7 +37,7 @@ pub struct ForeignNetworkClient {
|
|||||||
impl ForeignNetworkClient {
|
impl ForeignNetworkClient {
|
||||||
pub fn new(
|
pub fn new(
|
||||||
global_ctx: ArcGlobalCtx,
|
global_ctx: ArcGlobalCtx,
|
||||||
packet_sender_to_mgr: mpsc::Sender<Bytes>,
|
packet_sender_to_mgr: PacketRecvChan,
|
||||||
peer_rpc: Arc<PeerRpcManager>,
|
peer_rpc: Arc<PeerRpcManager>,
|
||||||
my_peer_id: PeerId,
|
my_peer_id: PeerId,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
@@ -148,7 +148,7 @@ impl ForeignNetworkClient {
|
|||||||
self.next_hop.get(&peer_id).map(|v| v.clone())
|
self.next_hop.get(&peer_id).map(|v| v.clone())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn send_msg(&self, msg: Bytes, peer_id: PeerId) -> Result<(), Error> {
|
pub async fn send_msg(&self, msg: ZCPacket, peer_id: PeerId) -> Result<(), Error> {
|
||||||
if let Some(next_hop) = self.get_next_hop(peer_id) {
|
if let Some(next_hop) = self.get_next_hop(peer_id) {
|
||||||
let ret = self.peer_map.send_msg_directly(msg, next_hop).await;
|
let ret = self.peer_map.send_msg_directly(msg, next_hop).await;
|
||||||
if ret.is_err() {
|
if ret.is_err() {
|
||||||
|
|||||||
@@ -15,19 +15,21 @@ use tokio::{
|
|||||||
},
|
},
|
||||||
task::JoinSet,
|
task::JoinSet,
|
||||||
};
|
};
|
||||||
use tokio_util::bytes::Bytes;
|
|
||||||
|
|
||||||
use crate::common::{
|
use crate::{
|
||||||
error::Error,
|
common::{
|
||||||
global_ctx::{ArcGlobalCtx, GlobalCtxEvent, NetworkIdentity},
|
error::Error,
|
||||||
PeerId,
|
global_ctx::{ArcGlobalCtx, GlobalCtxEvent, NetworkIdentity},
|
||||||
|
PeerId,
|
||||||
|
},
|
||||||
|
tunnel::packet_def::{PacketType, ZCPacket},
|
||||||
};
|
};
|
||||||
|
|
||||||
use super::{
|
use super::{
|
||||||
packet::{self},
|
|
||||||
peer_conn::PeerConn,
|
|
||||||
peer_map::PeerMap,
|
peer_map::PeerMap,
|
||||||
peer_rpc::{PeerRpcManager, PeerRpcManagerTransport},
|
peer_rpc::{PeerRpcManager, PeerRpcManagerTransport},
|
||||||
|
zc_peer_conn::PeerConn,
|
||||||
|
PacketRecvChan, PacketRecvChanReceiver,
|
||||||
};
|
};
|
||||||
|
|
||||||
struct ForeignNetworkEntry {
|
struct ForeignNetworkEntry {
|
||||||
@@ -38,7 +40,7 @@ struct ForeignNetworkEntry {
|
|||||||
impl ForeignNetworkEntry {
|
impl ForeignNetworkEntry {
|
||||||
fn new(
|
fn new(
|
||||||
network: NetworkIdentity,
|
network: NetworkIdentity,
|
||||||
packet_sender: mpsc::Sender<Bytes>,
|
packet_sender: PacketRecvChan,
|
||||||
global_ctx: ArcGlobalCtx,
|
global_ctx: ArcGlobalCtx,
|
||||||
my_peer_id: PeerId,
|
my_peer_id: PeerId,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
@@ -53,7 +55,7 @@ struct ForeignNetworkManagerData {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl ForeignNetworkManagerData {
|
impl ForeignNetworkManagerData {
|
||||||
async fn send_msg(&self, msg: Bytes, dst_peer_id: PeerId) -> Result<(), Error> {
|
async fn send_msg(&self, msg: ZCPacket, dst_peer_id: PeerId) -> Result<(), Error> {
|
||||||
let network_name = self
|
let network_name = self
|
||||||
.peer_network_map
|
.peer_network_map
|
||||||
.get(&dst_peer_id)
|
.get(&dst_peer_id)
|
||||||
@@ -94,7 +96,7 @@ struct RpcTransport {
|
|||||||
my_peer_id: PeerId,
|
my_peer_id: PeerId,
|
||||||
data: Arc<ForeignNetworkManagerData>,
|
data: Arc<ForeignNetworkManagerData>,
|
||||||
|
|
||||||
packet_recv: Mutex<UnboundedReceiver<Bytes>>,
|
packet_recv: Mutex<UnboundedReceiver<ZCPacket>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[async_trait::async_trait]
|
#[async_trait::async_trait]
|
||||||
@@ -103,11 +105,11 @@ impl PeerRpcManagerTransport for RpcTransport {
|
|||||||
self.my_peer_id
|
self.my_peer_id
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn send(&self, msg: Bytes, dst_peer_id: PeerId) -> Result<(), Error> {
|
async fn send(&self, msg: ZCPacket, dst_peer_id: PeerId) -> Result<(), Error> {
|
||||||
self.data.send_msg(msg, dst_peer_id).await
|
self.data.send_msg(msg, dst_peer_id).await
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn recv(&self) -> Result<Bytes, Error> {
|
async fn recv(&self) -> Result<ZCPacket, Error> {
|
||||||
if let Some(o) = self.packet_recv.lock().await.recv().await {
|
if let Some(o) = self.packet_recv.lock().await.recv().await {
|
||||||
Ok(o)
|
Ok(o)
|
||||||
} else {
|
} else {
|
||||||
@@ -138,14 +140,14 @@ impl ForeignNetworkService for Arc<ForeignNetworkManagerData> {
|
|||||||
pub struct ForeignNetworkManager {
|
pub struct ForeignNetworkManager {
|
||||||
my_peer_id: PeerId,
|
my_peer_id: PeerId,
|
||||||
global_ctx: ArcGlobalCtx,
|
global_ctx: ArcGlobalCtx,
|
||||||
packet_sender_to_mgr: mpsc::Sender<Bytes>,
|
packet_sender_to_mgr: PacketRecvChan,
|
||||||
|
|
||||||
packet_sender: mpsc::Sender<Bytes>,
|
packet_sender: PacketRecvChan,
|
||||||
packet_recv: Mutex<Option<mpsc::Receiver<Bytes>>>,
|
packet_recv: Mutex<Option<PacketRecvChanReceiver>>,
|
||||||
|
|
||||||
data: Arc<ForeignNetworkManagerData>,
|
data: Arc<ForeignNetworkManagerData>,
|
||||||
rpc_mgr: Arc<PeerRpcManager>,
|
rpc_mgr: Arc<PeerRpcManager>,
|
||||||
rpc_transport_sender: UnboundedSender<Bytes>,
|
rpc_transport_sender: UnboundedSender<ZCPacket>,
|
||||||
|
|
||||||
tasks: Mutex<JoinSet<()>>,
|
tasks: Mutex<JoinSet<()>>,
|
||||||
}
|
}
|
||||||
@@ -154,7 +156,7 @@ impl ForeignNetworkManager {
|
|||||||
pub fn new(
|
pub fn new(
|
||||||
my_peer_id: PeerId,
|
my_peer_id: PeerId,
|
||||||
global_ctx: ArcGlobalCtx,
|
global_ctx: ArcGlobalCtx,
|
||||||
packet_sender_to_mgr: mpsc::Sender<Bytes>,
|
packet_sender_to_mgr: PacketRecvChan,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
// recv packet from all foreign networks
|
// recv packet from all foreign networks
|
||||||
let (packet_sender, packet_recv) = mpsc::channel(1000);
|
let (packet_sender, packet_recv) = mpsc::channel(1000);
|
||||||
@@ -242,12 +244,15 @@ impl ForeignNetworkManager {
|
|||||||
|
|
||||||
self.tasks.lock().await.spawn(async move {
|
self.tasks.lock().await.spawn(async move {
|
||||||
while let Some(packet_bytes) = recv.recv().await {
|
while let Some(packet_bytes) = recv.recv().await {
|
||||||
let packet = packet::Packet::decode(&packet_bytes);
|
let Some(hdr) = packet_bytes.peer_manager_header() else {
|
||||||
let from_peer_id = packet.from_peer.into();
|
tracing::warn!("invalid packet, skip");
|
||||||
let to_peer_id = packet.to_peer.into();
|
continue;
|
||||||
|
};
|
||||||
|
let from_peer_id = hdr.from_peer_id.get();
|
||||||
|
let to_peer_id = hdr.to_peer_id.get();
|
||||||
if to_peer_id == my_node_id {
|
if to_peer_id == my_node_id {
|
||||||
if packet.packet_type == packet::PacketType::TaRpc {
|
if hdr.packet_type == PacketType::TaRpc as u8 {
|
||||||
rpc_sender.send(packet_bytes.clone()).unwrap();
|
rpc_sender.send(packet_bytes).unwrap();
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
if let Err(e) = sender_to_mgr.send(packet_bytes).await {
|
if let Err(e) = sender_to_mgr.send(packet_bytes).await {
|
||||||
@@ -343,6 +348,27 @@ mod tests {
|
|||||||
peer_mgr
|
peer_mgr
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn foreign_network_basic() {
|
||||||
|
let pm_center = create_mock_peer_manager_with_mock_stun(crate::rpc::NatType::Unknown).await;
|
||||||
|
tracing::debug!("pm_center: {:?}", pm_center.my_peer_id());
|
||||||
|
|
||||||
|
let pma_net1 = create_mock_peer_manager_for_foreign_network("net1").await;
|
||||||
|
let pmb_net1 = create_mock_peer_manager_for_foreign_network("net1").await;
|
||||||
|
tracing::debug!(
|
||||||
|
"pma_net1: {:?}, pmb_net1: {:?}",
|
||||||
|
pma_net1.my_peer_id(),
|
||||||
|
pmb_net1.my_peer_id()
|
||||||
|
);
|
||||||
|
connect_peer_manager(pma_net1.clone(), pm_center.clone()).await;
|
||||||
|
connect_peer_manager(pmb_net1.clone(), pm_center.clone()).await;
|
||||||
|
wait_route_appear(pma_net1.clone(), pmb_net1.clone())
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(1, pma_net1.list_routes().await.len());
|
||||||
|
assert_eq!(1, pmb_net1.list_routes().await.len());
|
||||||
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_foreign_network_manager() {
|
async fn test_foreign_network_manager() {
|
||||||
let pm_center = create_mock_peer_manager_with_mock_stun(crate::rpc::NatType::Unknown).await;
|
let pm_center = create_mock_peer_manager_with_mock_stun(crate::rpc::NatType::Unknown).await;
|
||||||
@@ -350,11 +376,23 @@ mod tests {
|
|||||||
create_mock_peer_manager_with_mock_stun(crate::rpc::NatType::Unknown).await;
|
create_mock_peer_manager_with_mock_stun(crate::rpc::NatType::Unknown).await;
|
||||||
connect_peer_manager(pm_center.clone(), pm_center2.clone()).await;
|
connect_peer_manager(pm_center.clone(), pm_center2.clone()).await;
|
||||||
|
|
||||||
|
tracing::debug!(
|
||||||
|
"pm_center: {:?}, pm_center2: {:?}",
|
||||||
|
pm_center.my_peer_id(),
|
||||||
|
pm_center2.my_peer_id()
|
||||||
|
);
|
||||||
|
|
||||||
let pma_net1 = create_mock_peer_manager_for_foreign_network("net1").await;
|
let pma_net1 = create_mock_peer_manager_for_foreign_network("net1").await;
|
||||||
let pmb_net1 = create_mock_peer_manager_for_foreign_network("net1").await;
|
let pmb_net1 = create_mock_peer_manager_for_foreign_network("net1").await;
|
||||||
connect_peer_manager(pma_net1.clone(), pm_center.clone()).await;
|
connect_peer_manager(pma_net1.clone(), pm_center.clone()).await;
|
||||||
connect_peer_manager(pmb_net1.clone(), pm_center.clone()).await;
|
connect_peer_manager(pmb_net1.clone(), pm_center.clone()).await;
|
||||||
|
|
||||||
|
tracing::debug!(
|
||||||
|
"pma_net1: {:?}, pmb_net1: {:?}",
|
||||||
|
pma_net1.my_peer_id(),
|
||||||
|
pmb_net1.my_peer_id()
|
||||||
|
);
|
||||||
|
|
||||||
let now = std::time::Instant::now();
|
let now = std::time::Instant::now();
|
||||||
let mut succ = false;
|
let mut succ = false;
|
||||||
while now.elapsed().as_secs() < 10 {
|
while now.elapsed().as_secs() < 10 {
|
||||||
@@ -399,8 +437,15 @@ mod tests {
|
|||||||
.unwrap();
|
.unwrap();
|
||||||
assert_eq!(2, pmc_net1.list_routes().await.len());
|
assert_eq!(2, pmc_net1.list_routes().await.len());
|
||||||
|
|
||||||
|
tracing::debug!("pmc_net1: {:?}", pmc_net1.my_peer_id());
|
||||||
|
|
||||||
let pma_net2 = create_mock_peer_manager_for_foreign_network("net2").await;
|
let pma_net2 = create_mock_peer_manager_for_foreign_network("net2").await;
|
||||||
let pmb_net2 = create_mock_peer_manager_for_foreign_network("net2").await;
|
let pmb_net2 = create_mock_peer_manager_for_foreign_network("net2").await;
|
||||||
|
tracing::debug!(
|
||||||
|
"pma_net2: {:?}, pmb_net2: {:?}",
|
||||||
|
pma_net2.my_peer_id(),
|
||||||
|
pmb_net2.my_peer_id()
|
||||||
|
);
|
||||||
connect_peer_manager(pma_net2.clone(), pm_center.clone()).await;
|
connect_peer_manager(pma_net2.clone(), pm_center.clone()).await;
|
||||||
connect_peer_manager(pmb_net2.clone(), pm_center.clone()).await;
|
connect_peer_manager(pmb_net2.clone(), pm_center.clone()).await;
|
||||||
wait_route_appear(pma_net2.clone(), pmb_net2.clone())
|
wait_route_appear(pma_net2.clone(), pmb_net2.clone())
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
pub mod packet;
|
pub mod packet;
|
||||||
pub mod peer;
|
pub mod peer;
|
||||||
pub mod peer_conn;
|
// pub mod peer_conn;
|
||||||
|
pub mod peer_conn_ping;
|
||||||
pub mod peer_manager;
|
pub mod peer_manager;
|
||||||
pub mod peer_map;
|
pub mod peer_map;
|
||||||
pub mod peer_ospf_route;
|
pub mod peer_ospf_route;
|
||||||
@@ -8,6 +9,7 @@ pub mod peer_rip_route;
|
|||||||
pub mod peer_rpc;
|
pub mod peer_rpc;
|
||||||
pub mod route_trait;
|
pub mod route_trait;
|
||||||
pub mod rpc_service;
|
pub mod rpc_service;
|
||||||
|
pub mod zc_peer_conn;
|
||||||
|
|
||||||
pub mod foreign_network_client;
|
pub mod foreign_network_client;
|
||||||
pub mod foreign_network_manager;
|
pub mod foreign_network_manager;
|
||||||
@@ -15,25 +17,24 @@ pub mod foreign_network_manager;
|
|||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
pub mod tests;
|
pub mod tests;
|
||||||
|
|
||||||
use tokio_util::bytes::{Bytes, BytesMut};
|
use crate::tunnel::packet_def::ZCPacket;
|
||||||
|
|
||||||
#[async_trait::async_trait]
|
#[async_trait::async_trait]
|
||||||
#[auto_impl::auto_impl(Arc)]
|
#[auto_impl::auto_impl(Arc)]
|
||||||
pub trait PeerPacketFilter {
|
pub trait PeerPacketFilter {
|
||||||
async fn try_process_packet_from_peer(
|
async fn try_process_packet_from_peer(&self, _zc_packet: ZCPacket) -> Option<ZCPacket> {
|
||||||
&self,
|
Some(_zc_packet)
|
||||||
_packet: &packet::ArchivedPacket,
|
|
||||||
_data: &Bytes,
|
|
||||||
) -> Option<()> {
|
|
||||||
None
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[async_trait::async_trait]
|
#[async_trait::async_trait]
|
||||||
#[auto_impl::auto_impl(Arc)]
|
#[auto_impl::auto_impl(Arc)]
|
||||||
pub trait NicPacketFilter {
|
pub trait NicPacketFilter {
|
||||||
async fn try_process_packet_from_nic(&self, data: BytesMut) -> BytesMut;
|
async fn try_process_packet_from_nic(&self, data: &mut ZCPacket);
|
||||||
}
|
}
|
||||||
|
|
||||||
type BoxPeerPacketFilter = Box<dyn PeerPacketFilter + Send + Sync>;
|
type BoxPeerPacketFilter = Box<dyn PeerPacketFilter + Send + Sync>;
|
||||||
type BoxNicPacketFilter = Box<dyn NicPacketFilter + Send + Sync>;
|
type BoxNicPacketFilter = Box<dyn NicPacketFilter + Send + Sync>;
|
||||||
|
|
||||||
|
pub type PacketRecvChan = tokio::sync::mpsc::Sender<ZCPacket>;
|
||||||
|
pub type PacketRecvChanReceiver = tokio::sync::mpsc::Receiver<ZCPacket>;
|
||||||
|
|||||||
+17
-11
@@ -7,16 +7,22 @@ use tokio::{
|
|||||||
sync::{mpsc, Mutex},
|
sync::{mpsc, Mutex},
|
||||||
task::JoinHandle,
|
task::JoinHandle,
|
||||||
};
|
};
|
||||||
use tokio_util::bytes::Bytes;
|
|
||||||
use tracing::Instrument;
|
use tracing::Instrument;
|
||||||
|
|
||||||
use super::peer_conn::{PeerConn, PeerConnId};
|
use super::{
|
||||||
use crate::common::{
|
zc_peer_conn::{PeerConn, PeerConnId},
|
||||||
error::Error,
|
PacketRecvChan,
|
||||||
global_ctx::{ArcGlobalCtx, GlobalCtxEvent},
|
|
||||||
PeerId,
|
|
||||||
};
|
};
|
||||||
use crate::rpc::PeerConnInfo;
|
use crate::rpc::PeerConnInfo;
|
||||||
|
use crate::{
|
||||||
|
common::{
|
||||||
|
error::Error,
|
||||||
|
global_ctx::{ArcGlobalCtx, GlobalCtxEvent},
|
||||||
|
PeerId,
|
||||||
|
},
|
||||||
|
tunnel::packet_def::ZCPacket,
|
||||||
|
};
|
||||||
|
|
||||||
type ArcPeerConn = Arc<Mutex<PeerConn>>;
|
type ArcPeerConn = Arc<Mutex<PeerConn>>;
|
||||||
type ConnMap = Arc<DashMap<PeerConnId, ArcPeerConn>>;
|
type ConnMap = Arc<DashMap<PeerConnId, ArcPeerConn>>;
|
||||||
@@ -26,7 +32,7 @@ pub struct Peer {
|
|||||||
conns: ConnMap,
|
conns: ConnMap,
|
||||||
global_ctx: ArcGlobalCtx,
|
global_ctx: ArcGlobalCtx,
|
||||||
|
|
||||||
packet_recv_chan: mpsc::Sender<Bytes>,
|
packet_recv_chan: PacketRecvChan,
|
||||||
|
|
||||||
close_event_sender: mpsc::Sender<PeerConnId>,
|
close_event_sender: mpsc::Sender<PeerConnId>,
|
||||||
close_event_listener: JoinHandle<()>,
|
close_event_listener: JoinHandle<()>,
|
||||||
@@ -37,7 +43,7 @@ pub struct Peer {
|
|||||||
impl Peer {
|
impl Peer {
|
||||||
pub fn new(
|
pub fn new(
|
||||||
peer_node_id: PeerId,
|
peer_node_id: PeerId,
|
||||||
packet_recv_chan: mpsc::Sender<Bytes>,
|
packet_recv_chan: PacketRecvChan,
|
||||||
global_ctx: ArcGlobalCtx,
|
global_ctx: ArcGlobalCtx,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
let conns: ConnMap = Arc::new(DashMap::new());
|
let conns: ConnMap = Arc::new(DashMap::new());
|
||||||
@@ -106,7 +112,7 @@ impl Peer {
|
|||||||
.insert(conn.get_conn_id(), Arc::new(Mutex::new(conn)));
|
.insert(conn.get_conn_id(), Arc::new(Mutex::new(conn)));
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn send_msg(&self, msg: Bytes) -> Result<(), Error> {
|
pub async fn send_msg(&self, msg: ZCPacket) -> Result<(), Error> {
|
||||||
let Some(conn) = self.conns.iter().next() else {
|
let Some(conn) = self.conns.iter().next() else {
|
||||||
return Err(Error::PeerNoConnectionError(self.peer_node_id));
|
return Err(Error::PeerNoConnectionError(self.peer_node_id));
|
||||||
};
|
};
|
||||||
@@ -157,8 +163,8 @@ mod tests {
|
|||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
common::{global_ctx::tests::get_mock_global_ctx, new_peer_id},
|
common::{global_ctx::tests::get_mock_global_ctx, new_peer_id},
|
||||||
peers::peer_conn::PeerConn,
|
peers::zc_peer_conn::PeerConn,
|
||||||
tunnels::ring_tunnel::create_ring_tunnel_pair,
|
tunnel::ring::create_ring_tunnel_pair,
|
||||||
};
|
};
|
||||||
|
|
||||||
use super::Peer;
|
use super::Peer;
|
||||||
|
|||||||
@@ -0,0 +1,219 @@
|
|||||||
|
use std::{
|
||||||
|
sync::{
|
||||||
|
atomic::{AtomicU32, Ordering},
|
||||||
|
Arc,
|
||||||
|
},
|
||||||
|
time::Duration,
|
||||||
|
};
|
||||||
|
|
||||||
|
use tokio::{sync::broadcast, task::JoinSet, time::timeout};
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
common::{error::Error, PeerId},
|
||||||
|
tunnel::{
|
||||||
|
mpsc::MpscTunnelSender,
|
||||||
|
packet_def::{PacketType, ZCPacket},
|
||||||
|
stats::WindowLatency,
|
||||||
|
TunnelError,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
pub struct PeerConnPinger {
|
||||||
|
my_peer_id: PeerId,
|
||||||
|
peer_id: PeerId,
|
||||||
|
sink: MpscTunnelSender,
|
||||||
|
ctrl_sender: broadcast::Sender<ZCPacket>,
|
||||||
|
latency_stats: Arc<WindowLatency>,
|
||||||
|
loss_rate_stats: Arc<AtomicU32>,
|
||||||
|
tasks: JoinSet<Result<(), TunnelError>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::fmt::Debug for PeerConnPinger {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
f.debug_struct("PeerConnPinger")
|
||||||
|
.field("my_peer_id", &self.my_peer_id)
|
||||||
|
.field("peer_id", &self.peer_id)
|
||||||
|
.finish()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PeerConnPinger {
|
||||||
|
pub fn new(
|
||||||
|
my_peer_id: PeerId,
|
||||||
|
peer_id: PeerId,
|
||||||
|
sink: MpscTunnelSender,
|
||||||
|
ctrl_sender: broadcast::Sender<ZCPacket>,
|
||||||
|
latency_stats: Arc<WindowLatency>,
|
||||||
|
loss_rate_stats: Arc<AtomicU32>,
|
||||||
|
) -> Self {
|
||||||
|
Self {
|
||||||
|
my_peer_id,
|
||||||
|
peer_id,
|
||||||
|
sink,
|
||||||
|
tasks: JoinSet::new(),
|
||||||
|
latency_stats,
|
||||||
|
ctrl_sender,
|
||||||
|
loss_rate_stats,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn new_ping_packet(my_node_id: PeerId, peer_id: PeerId, seq: u32) -> ZCPacket {
|
||||||
|
let mut packet = ZCPacket::new_with_payload(&seq.to_le_bytes());
|
||||||
|
packet.fill_peer_manager_hdr(my_node_id, peer_id, PacketType::Ping as u8);
|
||||||
|
packet
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn do_pingpong_once(
|
||||||
|
my_node_id: PeerId,
|
||||||
|
peer_id: PeerId,
|
||||||
|
sink: &mut MpscTunnelSender,
|
||||||
|
receiver: &mut broadcast::Receiver<ZCPacket>,
|
||||||
|
seq: u32,
|
||||||
|
) -> Result<u128, Error> {
|
||||||
|
// should add seq here. so latency can be calculated more accurately
|
||||||
|
let req = Self::new_ping_packet(my_node_id, peer_id, seq);
|
||||||
|
sink.send(req).await?;
|
||||||
|
|
||||||
|
let now = std::time::Instant::now();
|
||||||
|
// wait until we get a pong packet in ctrl_resp_receiver
|
||||||
|
let resp = timeout(Duration::from_secs(1), async {
|
||||||
|
loop {
|
||||||
|
match receiver.recv().await {
|
||||||
|
Ok(p) => {
|
||||||
|
let payload = p.payload();
|
||||||
|
let Ok(seq_buf) = payload[0..4].try_into() else {
|
||||||
|
tracing::debug!("pingpong recv invalid packet, continue");
|
||||||
|
continue;
|
||||||
|
};
|
||||||
|
let resp_seq = u32::from_le_bytes(seq_buf);
|
||||||
|
if resp_seq == seq {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
return Err(Error::WaitRespError(format!(
|
||||||
|
"wait ping response error: {:?}",
|
||||||
|
e
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
})
|
||||||
|
.await;
|
||||||
|
|
||||||
|
tracing::trace!(?resp, "wait ping response done");
|
||||||
|
|
||||||
|
if resp.is_err() {
|
||||||
|
return Err(Error::WaitRespError(
|
||||||
|
"wait ping response timeout".to_owned(),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.as_ref().unwrap().is_err() {
|
||||||
|
return Err(resp.unwrap().err().unwrap());
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(now.elapsed().as_micros())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn pingpong(&mut self) {
|
||||||
|
let sink = self.sink.clone();
|
||||||
|
let my_node_id = self.my_peer_id;
|
||||||
|
let peer_id = self.peer_id;
|
||||||
|
let latency_stats = self.latency_stats.clone();
|
||||||
|
|
||||||
|
let (ping_res_sender, mut ping_res_receiver) = tokio::sync::mpsc::channel(100);
|
||||||
|
|
||||||
|
let stopped = Arc::new(AtomicU32::new(0));
|
||||||
|
|
||||||
|
// generate a pingpong task every 200ms
|
||||||
|
let mut pingpong_tasks = JoinSet::new();
|
||||||
|
let ctrl_resp_sender = self.ctrl_sender.clone();
|
||||||
|
let stopped_clone = stopped.clone();
|
||||||
|
self.tasks.spawn(async move {
|
||||||
|
let mut req_seq = 0;
|
||||||
|
loop {
|
||||||
|
let receiver = ctrl_resp_sender.subscribe();
|
||||||
|
let ping_res_sender = ping_res_sender.clone();
|
||||||
|
|
||||||
|
if stopped_clone.load(Ordering::Relaxed) != 0 {
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
while pingpong_tasks.len() > 5 {
|
||||||
|
pingpong_tasks.join_next().await;
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut sink = sink.clone();
|
||||||
|
pingpong_tasks.spawn(async move {
|
||||||
|
let mut receiver = receiver.resubscribe();
|
||||||
|
let pingpong_once_ret = Self::do_pingpong_once(
|
||||||
|
my_node_id,
|
||||||
|
peer_id,
|
||||||
|
&mut sink,
|
||||||
|
&mut receiver,
|
||||||
|
req_seq,
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
if let Err(e) = ping_res_sender.send(pingpong_once_ret).await {
|
||||||
|
tracing::info!(?e, "pingpong task send result error, exit..");
|
||||||
|
};
|
||||||
|
});
|
||||||
|
|
||||||
|
req_seq = req_seq.wrapping_add(1);
|
||||||
|
tokio::time::sleep(Duration::from_millis(1000)).await;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
// one with 1% precision
|
||||||
|
let loss_rate_stats_1 = WindowLatency::new(100);
|
||||||
|
// one with 20% precision, so we can fast fail this conn.
|
||||||
|
let loss_rate_stats_20 = WindowLatency::new(5);
|
||||||
|
|
||||||
|
let mut counter: u64 = 0;
|
||||||
|
|
||||||
|
while let Some(ret) = ping_res_receiver.recv().await {
|
||||||
|
counter += 1;
|
||||||
|
|
||||||
|
if let Ok(lat) = ret {
|
||||||
|
latency_stats.record_latency(lat as u32);
|
||||||
|
|
||||||
|
loss_rate_stats_1.record_latency(0);
|
||||||
|
loss_rate_stats_20.record_latency(0);
|
||||||
|
} else {
|
||||||
|
loss_rate_stats_1.record_latency(1);
|
||||||
|
loss_rate_stats_20.record_latency(1);
|
||||||
|
}
|
||||||
|
|
||||||
|
let loss_rate_20: f64 = loss_rate_stats_20.get_latency_us();
|
||||||
|
let loss_rate_1: f64 = loss_rate_stats_1.get_latency_us();
|
||||||
|
|
||||||
|
tracing::trace!(
|
||||||
|
?ret,
|
||||||
|
?self,
|
||||||
|
?loss_rate_1,
|
||||||
|
?loss_rate_20,
|
||||||
|
"pingpong task recv pingpong_once result"
|
||||||
|
);
|
||||||
|
|
||||||
|
if (counter > 5 && loss_rate_20 > 0.74) || (counter > 150 && loss_rate_1 > 0.20) {
|
||||||
|
tracing::warn!(
|
||||||
|
?ret,
|
||||||
|
?self,
|
||||||
|
?loss_rate_1,
|
||||||
|
?loss_rate_20,
|
||||||
|
"pingpong loss rate too high, closing"
|
||||||
|
);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
self.loss_rate_stats
|
||||||
|
.store((loss_rate_1 * 100.0) as u32, Ordering::Relaxed);
|
||||||
|
}
|
||||||
|
|
||||||
|
stopped.store(1, Ordering::Relaxed);
|
||||||
|
ping_res_receiver.close();
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -5,6 +5,7 @@ use std::{
|
|||||||
};
|
};
|
||||||
|
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
|
|
||||||
use futures::StreamExt;
|
use futures::StreamExt;
|
||||||
|
|
||||||
use tokio::{
|
use tokio::{
|
||||||
@@ -15,30 +16,30 @@ use tokio::{
|
|||||||
task::JoinSet,
|
task::JoinSet,
|
||||||
};
|
};
|
||||||
use tokio_stream::wrappers::ReceiverStream;
|
use tokio_stream::wrappers::ReceiverStream;
|
||||||
use tokio_util::bytes::{Bytes, BytesMut};
|
use tokio_util::bytes::Bytes;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
common::{
|
common::{error::Error, global_ctx::ArcGlobalCtx, PeerId},
|
||||||
error::Error, global_ctx::ArcGlobalCtx, rkyv_util::extract_bytes_from_archived_string,
|
|
||||||
PeerId,
|
|
||||||
},
|
|
||||||
peers::{
|
peers::{
|
||||||
packet, peer_conn::PeerConn, peer_rpc::PeerRpcManagerTransport,
|
packet, peer_rpc::PeerRpcManagerTransport, route_trait::RouteInterface,
|
||||||
route_trait::RouteInterface, PeerPacketFilter,
|
zc_peer_conn::PeerConn, PeerPacketFilter,
|
||||||
|
},
|
||||||
|
tunnel::{
|
||||||
|
packet_def::{PacketType, ZCPacket},
|
||||||
|
SinkItem, Tunnel, TunnelConnector,
|
||||||
},
|
},
|
||||||
tunnels::{SinkItem, Tunnel, TunnelConnector},
|
|
||||||
};
|
};
|
||||||
|
|
||||||
use super::{
|
use super::{
|
||||||
foreign_network_client::ForeignNetworkClient,
|
foreign_network_client::ForeignNetworkClient,
|
||||||
foreign_network_manager::ForeignNetworkManager,
|
foreign_network_manager::ForeignNetworkManager,
|
||||||
peer_conn::PeerConnId,
|
|
||||||
peer_map::PeerMap,
|
peer_map::PeerMap,
|
||||||
peer_ospf_route::PeerRoute,
|
peer_ospf_route::PeerRoute,
|
||||||
peer_rip_route::BasicRoute,
|
peer_rip_route::BasicRoute,
|
||||||
peer_rpc::PeerRpcManager,
|
peer_rpc::PeerRpcManager,
|
||||||
route_trait::{ArcRoute, Route},
|
route_trait::{ArcRoute, Route},
|
||||||
BoxNicPacketFilter, BoxPeerPacketFilter,
|
zc_peer_conn::PeerConnId,
|
||||||
|
BoxNicPacketFilter, BoxPeerPacketFilter, PacketRecvChanReceiver,
|
||||||
};
|
};
|
||||||
|
|
||||||
struct RpcTransport {
|
struct RpcTransport {
|
||||||
@@ -46,8 +47,8 @@ struct RpcTransport {
|
|||||||
peers: Weak<PeerMap>,
|
peers: Weak<PeerMap>,
|
||||||
foreign_peers: Mutex<Option<Weak<ForeignNetworkClient>>>,
|
foreign_peers: Mutex<Option<Weak<ForeignNetworkClient>>>,
|
||||||
|
|
||||||
packet_recv: Mutex<UnboundedReceiver<Bytes>>,
|
packet_recv: Mutex<UnboundedReceiver<ZCPacket>>,
|
||||||
peer_rpc_tspt_sender: UnboundedSender<Bytes>,
|
peer_rpc_tspt_sender: UnboundedSender<ZCPacket>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[async_trait::async_trait]
|
#[async_trait::async_trait]
|
||||||
@@ -56,7 +57,7 @@ impl PeerRpcManagerTransport for RpcTransport {
|
|||||||
self.my_peer_id
|
self.my_peer_id
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn send(&self, msg: Bytes, dst_peer_id: PeerId) -> Result<(), Error> {
|
async fn send(&self, msg: ZCPacket, dst_peer_id: PeerId) -> Result<(), Error> {
|
||||||
let foreign_peers = self
|
let foreign_peers = self
|
||||||
.foreign_peers
|
.foreign_peers
|
||||||
.lock()
|
.lock()
|
||||||
@@ -67,21 +68,30 @@ impl PeerRpcManagerTransport for RpcTransport {
|
|||||||
.ok_or(Error::Unknown)?;
|
.ok_or(Error::Unknown)?;
|
||||||
let peers = self.peers.upgrade().ok_or(Error::Unknown)?;
|
let peers = self.peers.upgrade().ok_or(Error::Unknown)?;
|
||||||
|
|
||||||
let ret = peers.send_msg(msg.clone(), dst_peer_id).await;
|
if let Some(gateway_id) = peers.get_gateway_peer_id(dst_peer_id).await {
|
||||||
|
tracing::trace!(
|
||||||
if matches!(ret, Err(Error::RouteError(..))) && foreign_peers.has_next_hop(dst_peer_id) {
|
?dst_peer_id,
|
||||||
tracing::info!(
|
?gateway_id,
|
||||||
|
?self.my_peer_id,
|
||||||
|
"send msg to peer via gateway",
|
||||||
|
);
|
||||||
|
peers.send_msg_directly(msg, gateway_id).await
|
||||||
|
} else if foreign_peers.has_next_hop(dst_peer_id) {
|
||||||
|
tracing::debug!(
|
||||||
?dst_peer_id,
|
?dst_peer_id,
|
||||||
?self.my_peer_id,
|
?self.my_peer_id,
|
||||||
"failed to send msg to peer, try foreign network",
|
"failed to send msg to peer, try foreign network",
|
||||||
);
|
);
|
||||||
return foreign_peers.send_msg(msg, dst_peer_id).await;
|
foreign_peers.send_msg(msg, dst_peer_id).await
|
||||||
|
} else {
|
||||||
|
Err(Error::RouteError(Some(format!(
|
||||||
|
"peermgr RpcTransport no route for dst_peer_id: {}",
|
||||||
|
dst_peer_id
|
||||||
|
))))
|
||||||
}
|
}
|
||||||
|
|
||||||
ret
|
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn recv(&self) -> Result<Bytes, Error> {
|
async fn recv(&self) -> Result<ZCPacket, Error> {
|
||||||
if let Some(o) = self.packet_recv.lock().await.recv().await {
|
if let Some(o) = self.packet_recv.lock().await.recv().await {
|
||||||
Ok(o)
|
Ok(o)
|
||||||
} else {
|
} else {
|
||||||
@@ -110,7 +120,7 @@ pub struct PeerManager {
|
|||||||
|
|
||||||
tasks: Arc<Mutex<JoinSet<()>>>,
|
tasks: Arc<Mutex<JoinSet<()>>>,
|
||||||
|
|
||||||
packet_recv: Arc<Mutex<Option<mpsc::Receiver<Bytes>>>>,
|
packet_recv: Arc<Mutex<Option<PacketRecvChanReceiver>>>,
|
||||||
|
|
||||||
peers: Arc<PeerMap>,
|
peers: Arc<PeerMap>,
|
||||||
|
|
||||||
@@ -261,17 +271,20 @@ impl PeerManager {
|
|||||||
self.tasks.lock().await.spawn(async move {
|
self.tasks.lock().await.spawn(async move {
|
||||||
log::trace!("start_peer_recv");
|
log::trace!("start_peer_recv");
|
||||||
while let Some(ret) = recv.next().await {
|
while let Some(ret) = recv.next().await {
|
||||||
log::trace!("peer recv a packet...: {:?}", ret);
|
let Some(hdr) = ret.peer_manager_header() else {
|
||||||
let packet = packet::Packet::decode(&ret);
|
tracing::warn!(?ret, "invalid packet, skip");
|
||||||
let from_peer_id: PeerId = packet.from_peer.into();
|
continue;
|
||||||
let to_peer_id: PeerId = packet.to_peer.into();
|
};
|
||||||
|
tracing::trace!(?hdr, ?ret, "peer recv a packet...");
|
||||||
|
let from_peer_id = hdr.from_peer_id.get();
|
||||||
|
let to_peer_id = hdr.to_peer_id.get();
|
||||||
if to_peer_id != my_peer_id {
|
if to_peer_id != my_peer_id {
|
||||||
log::trace!(
|
log::trace!(
|
||||||
"need forward: to_peer_id: {:?}, my_peer_id: {:?}",
|
"need forward: to_peer_id: {:?}, my_peer_id: {:?}",
|
||||||
to_peer_id,
|
to_peer_id,
|
||||||
my_peer_id
|
my_peer_id
|
||||||
);
|
);
|
||||||
let ret = peers.send_msg(ret.clone(), to_peer_id).await;
|
let ret = peers.send_msg(ret, to_peer_id).await;
|
||||||
if ret.is_err() {
|
if ret.is_err() {
|
||||||
log::error!(
|
log::error!(
|
||||||
"forward packet error: {:?}, dst: {:?}, from: {:?}",
|
"forward packet error: {:?}, dst: {:?}, from: {:?}",
|
||||||
@@ -282,15 +295,21 @@ impl PeerManager {
|
|||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
let mut processed = false;
|
let mut processed = false;
|
||||||
|
let mut zc_packet = Some(ret);
|
||||||
|
let mut idx = 0;
|
||||||
for pipeline in pipe_line.read().await.iter().rev() {
|
for pipeline in pipe_line.read().await.iter().rev() {
|
||||||
if let Some(_) = pipeline.try_process_packet_from_peer(&packet, &ret).await
|
tracing::debug!(?zc_packet, ?idx, "try_process_packet_from_peer");
|
||||||
{
|
idx += 1;
|
||||||
|
zc_packet = pipeline
|
||||||
|
.try_process_packet_from_peer(zc_packet.unwrap())
|
||||||
|
.await;
|
||||||
|
if zc_packet.is_none() {
|
||||||
processed = true;
|
processed = true;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if !processed {
|
if !processed {
|
||||||
tracing::error!("unexpected packet: {:?}", ret);
|
tracing::error!(?zc_packet, "unhandled packet");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -321,20 +340,15 @@ impl PeerManager {
|
|||||||
}
|
}
|
||||||
#[async_trait::async_trait]
|
#[async_trait::async_trait]
|
||||||
impl PeerPacketFilter for NicPacketProcessor {
|
impl PeerPacketFilter for NicPacketProcessor {
|
||||||
async fn try_process_packet_from_peer(
|
async fn try_process_packet_from_peer(&self, packet: ZCPacket) -> Option<ZCPacket> {
|
||||||
&self,
|
let hdr = packet.peer_manager_header().unwrap();
|
||||||
packet: &packet::ArchivedPacket,
|
if hdr.packet_type == PacketType::Data as u8 {
|
||||||
data: &Bytes,
|
tracing::trace!(?packet, "send packet to nic channel");
|
||||||
) -> Option<()> {
|
|
||||||
if packet.packet_type == packet::PacketType::Data {
|
|
||||||
// TODO: use a function to get the body ref directly for zero copy
|
// TODO: use a function to get the body ref directly for zero copy
|
||||||
self.nic_channel
|
self.nic_channel.send(packet).await.unwrap();
|
||||||
.send(extract_bytes_from_archived_string(data, &packet.payload))
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
Some(())
|
|
||||||
} else {
|
|
||||||
None
|
None
|
||||||
|
} else {
|
||||||
|
Some(packet)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -345,21 +359,18 @@ impl PeerManager {
|
|||||||
|
|
||||||
// for peer rpc packet
|
// for peer rpc packet
|
||||||
struct PeerRpcPacketProcessor {
|
struct PeerRpcPacketProcessor {
|
||||||
peer_rpc_tspt_sender: UnboundedSender<Bytes>,
|
peer_rpc_tspt_sender: UnboundedSender<ZCPacket>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[async_trait::async_trait]
|
#[async_trait::async_trait]
|
||||||
impl PeerPacketFilter for PeerRpcPacketProcessor {
|
impl PeerPacketFilter for PeerRpcPacketProcessor {
|
||||||
async fn try_process_packet_from_peer(
|
async fn try_process_packet_from_peer(&self, packet: ZCPacket) -> Option<ZCPacket> {
|
||||||
&self,
|
let hdr = packet.peer_manager_header().unwrap();
|
||||||
packet: &packet::ArchivedPacket,
|
if hdr.packet_type == PacketType::TaRpc as u8 {
|
||||||
data: &Bytes,
|
self.peer_rpc_tspt_sender.send(packet).unwrap();
|
||||||
) -> Option<()> {
|
|
||||||
if packet.packet_type == packet::PacketType::TaRpc {
|
|
||||||
self.peer_rpc_tspt_sender.send(data.clone()).unwrap();
|
|
||||||
Some(())
|
|
||||||
} else {
|
|
||||||
None
|
None
|
||||||
|
} else {
|
||||||
|
Some(packet)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -401,7 +412,7 @@ impl PeerManager {
|
|||||||
async fn send_route_packet(
|
async fn send_route_packet(
|
||||||
&self,
|
&self,
|
||||||
msg: Bytes,
|
msg: Bytes,
|
||||||
route_id: u8,
|
_route_id: u8,
|
||||||
dst_peer_id: PeerId,
|
dst_peer_id: PeerId,
|
||||||
) -> Result<(), Error> {
|
) -> Result<(), Error> {
|
||||||
let foreign_client = self
|
let foreign_client = self
|
||||||
@@ -409,15 +420,17 @@ impl PeerManager {
|
|||||||
.upgrade()
|
.upgrade()
|
||||||
.ok_or(Error::Unknown)?;
|
.ok_or(Error::Unknown)?;
|
||||||
let peer_map = self.peers.upgrade().ok_or(Error::Unknown)?;
|
let peer_map = self.peers.upgrade().ok_or(Error::Unknown)?;
|
||||||
|
let mut zc_packet = ZCPacket::new_with_payload(&msg);
|
||||||
let packet_bytes: Bytes =
|
zc_packet.fill_peer_manager_hdr(
|
||||||
packet::Packet::new_route_packet(self.my_peer_id, dst_peer_id, route_id, &msg)
|
self.my_peer_id,
|
||||||
.into();
|
dst_peer_id,
|
||||||
|
PacketType::Route as u8,
|
||||||
|
);
|
||||||
if foreign_client.has_next_hop(dst_peer_id) {
|
if foreign_client.has_next_hop(dst_peer_id) {
|
||||||
return foreign_client.send_msg(packet_bytes, dst_peer_id).await;
|
foreign_client.send_msg(zc_packet, dst_peer_id).await
|
||||||
|
} else {
|
||||||
|
peer_map.send_msg_directly(zc_packet, dst_peer_id).await
|
||||||
}
|
}
|
||||||
|
|
||||||
peer_map.send_msg_directly(packet_bytes, dst_peer_id).await
|
|
||||||
}
|
}
|
||||||
fn my_peer_id(&self) -> PeerId {
|
fn my_peer_id(&self) -> PeerId {
|
||||||
self.my_peer_id
|
self.my_peer_id
|
||||||
@@ -450,18 +463,17 @@ impl PeerManager {
|
|||||||
self.get_route().list_routes().await
|
self.get_route().list_routes().await
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn run_nic_packet_process_pipeline(&self, mut data: BytesMut) -> BytesMut {
|
async fn run_nic_packet_process_pipeline(&self, data: &mut ZCPacket) {
|
||||||
for pipeline in self.nic_packet_process_pipeline.read().await.iter().rev() {
|
for pipeline in self.nic_packet_process_pipeline.read().await.iter().rev() {
|
||||||
data = pipeline.try_process_packet_from_nic(data).await;
|
pipeline.try_process_packet_from_nic(data).await;
|
||||||
}
|
}
|
||||||
data
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn send_msg(&self, msg: Bytes, dst_peer_id: PeerId) -> Result<(), Error> {
|
pub async fn send_msg(&self, msg: ZCPacket, dst_peer_id: PeerId) -> Result<(), Error> {
|
||||||
self.peers.send_msg(msg, dst_peer_id).await
|
self.peers.send_msg(msg, dst_peer_id).await
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn send_msg_ipv4(&self, msg: BytesMut, ipv4_addr: Ipv4Addr) -> Result<(), Error> {
|
pub async fn send_msg_ipv4(&self, mut msg: ZCPacket, ipv4_addr: Ipv4Addr) -> Result<(), Error> {
|
||||||
log::trace!(
|
log::trace!(
|
||||||
"do send_msg in peer manager, msg: {:?}, ipv4_addr: {}",
|
"do send_msg in peer manager, msg: {:?}, ipv4_addr: {}",
|
||||||
msg,
|
msg,
|
||||||
@@ -487,25 +499,34 @@ impl PeerManager {
|
|||||||
return Ok(());
|
return Ok(());
|
||||||
}
|
}
|
||||||
|
|
||||||
let msg = self.run_nic_packet_process_pipeline(msg).await;
|
self.run_nic_packet_process_pipeline(&mut msg).await;
|
||||||
let mut errs: Vec<Error> = vec![];
|
let mut errs: Vec<Error> = vec![];
|
||||||
|
|
||||||
for peer_id in dst_peers.iter() {
|
let mut msg = Some(msg);
|
||||||
let msg: Bytes =
|
let total_dst_peers = dst_peers.len();
|
||||||
packet::Packet::new_data_packet(self.my_peer_id, peer_id.clone(), &msg).into();
|
for i in 0..total_dst_peers {
|
||||||
let send_ret = self.peers.send_msg(msg.clone(), *peer_id).await;
|
let mut msg = if i == total_dst_peers - 1 {
|
||||||
|
msg.take().unwrap()
|
||||||
|
} else {
|
||||||
|
msg.clone().unwrap()
|
||||||
|
};
|
||||||
|
|
||||||
if matches!(send_ret, Err(Error::RouteError(..)))
|
let peer_id = &dst_peers[i];
|
||||||
&& self.foreign_network_client.has_next_hop(*peer_id)
|
|
||||||
{
|
msg.fill_peer_manager_hdr(self.my_peer_id, *peer_id, packet::PacketType::Data as u8);
|
||||||
let foreign_send_ret = self.foreign_network_client.send_msg(msg, *peer_id).await;
|
|
||||||
if foreign_send_ret.is_ok() {
|
if let Some(gateway) = self.peers.get_gateway_peer_id(*peer_id).await {
|
||||||
continue;
|
if let Err(e) = self.peers.send_msg_directly(msg.clone(), gateway).await {
|
||||||
|
errs.push(e);
|
||||||
|
}
|
||||||
|
} else if self.foreign_network_client.has_next_hop(*peer_id) {
|
||||||
|
if let Err(e) = self
|
||||||
|
.foreign_network_client
|
||||||
|
.send_msg(msg.clone(), *peer_id)
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
errs.push(e);
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
if let Err(send_ret) = send_ret {
|
|
||||||
errs.push(send_ret);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -2,8 +2,7 @@ use std::{net::Ipv4Addr, sync::Arc};
|
|||||||
|
|
||||||
use anyhow::Context;
|
use anyhow::Context;
|
||||||
use dashmap::DashMap;
|
use dashmap::DashMap;
|
||||||
use tokio::sync::{mpsc, RwLock};
|
use tokio::sync::RwLock;
|
||||||
use tokio_util::bytes::Bytes;
|
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
common::{
|
common::{
|
||||||
@@ -12,29 +11,27 @@ use crate::{
|
|||||||
PeerId,
|
PeerId,
|
||||||
},
|
},
|
||||||
rpc::PeerConnInfo,
|
rpc::PeerConnInfo,
|
||||||
|
tunnel::packet_def::ZCPacket,
|
||||||
tunnels::TunnelError,
|
tunnels::TunnelError,
|
||||||
};
|
};
|
||||||
|
|
||||||
use super::{
|
use super::{
|
||||||
peer::Peer,
|
peer::Peer,
|
||||||
peer_conn::{PeerConn, PeerConnId},
|
|
||||||
route_trait::ArcRoute,
|
route_trait::ArcRoute,
|
||||||
|
zc_peer_conn::{PeerConn, PeerConnId},
|
||||||
|
PacketRecvChan,
|
||||||
};
|
};
|
||||||
|
|
||||||
pub struct PeerMap {
|
pub struct PeerMap {
|
||||||
global_ctx: ArcGlobalCtx,
|
global_ctx: ArcGlobalCtx,
|
||||||
my_peer_id: PeerId,
|
my_peer_id: PeerId,
|
||||||
peer_map: DashMap<PeerId, Arc<Peer>>,
|
peer_map: DashMap<PeerId, Arc<Peer>>,
|
||||||
packet_send: mpsc::Sender<Bytes>,
|
packet_send: PacketRecvChan,
|
||||||
routes: RwLock<Vec<ArcRoute>>,
|
routes: RwLock<Vec<ArcRoute>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl PeerMap {
|
impl PeerMap {
|
||||||
pub fn new(
|
pub fn new(packet_send: PacketRecvChan, global_ctx: ArcGlobalCtx, my_peer_id: PeerId) -> Self {
|
||||||
packet_send: mpsc::Sender<Bytes>,
|
|
||||||
global_ctx: ArcGlobalCtx,
|
|
||||||
my_peer_id: PeerId,
|
|
||||||
) -> Self {
|
|
||||||
PeerMap {
|
PeerMap {
|
||||||
global_ctx,
|
global_ctx,
|
||||||
my_peer_id,
|
my_peer_id,
|
||||||
@@ -72,7 +69,7 @@ impl PeerMap {
|
|||||||
self.peer_map.contains_key(&peer_id)
|
self.peer_map.contains_key(&peer_id)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn send_msg_directly(&self, msg: Bytes, dst_peer_id: PeerId) -> Result<(), Error> {
|
pub async fn send_msg_directly(&self, msg: ZCPacket, dst_peer_id: PeerId) -> Result<(), Error> {
|
||||||
if dst_peer_id == self.my_peer_id {
|
if dst_peer_id == self.my_peer_id {
|
||||||
return Ok(self
|
return Ok(self
|
||||||
.packet_send
|
.packet_send
|
||||||
@@ -87,48 +84,53 @@ impl PeerMap {
|
|||||||
}
|
}
|
||||||
None => {
|
None => {
|
||||||
log::error!("no peer for dst_peer_id: {}", dst_peer_id);
|
log::error!("no peer for dst_peer_id: {}", dst_peer_id);
|
||||||
return Err(Error::RouteError(None));
|
return Err(Error::RouteError(Some(format!(
|
||||||
|
"peer map sengmsg directly no connected dst_peer_id: {}",
|
||||||
|
dst_peer_id
|
||||||
|
))));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn send_msg(&self, msg: Bytes, dst_peer_id: PeerId) -> Result<(), Error> {
|
pub async fn get_gateway_peer_id(&self, dst_peer_id: PeerId) -> Option<PeerId> {
|
||||||
if dst_peer_id == self.my_peer_id {
|
if dst_peer_id == self.my_peer_id {
|
||||||
return Ok(self
|
return Some(dst_peer_id);
|
||||||
.packet_send
|
}
|
||||||
.send(msg)
|
|
||||||
.await
|
if self.has_peer(dst_peer_id) {
|
||||||
.with_context(|| "send msg to self failed")?);
|
return Some(dst_peer_id);
|
||||||
}
|
}
|
||||||
|
|
||||||
// get route info
|
// get route info
|
||||||
let mut gateway_peer_id = None;
|
|
||||||
for route in self.routes.read().await.iter() {
|
for route in self.routes.read().await.iter() {
|
||||||
gateway_peer_id = route.get_next_hop(dst_peer_id).await;
|
if let Some(gateway_peer_id) = route.get_next_hop(dst_peer_id).await {
|
||||||
if gateway_peer_id.is_none() {
|
// for foreign network, gateway_peer_id may not connect to me
|
||||||
continue;
|
if self.has_peer(gateway_peer_id) {
|
||||||
} else {
|
return Some(gateway_peer_id);
|
||||||
break;
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if gateway_peer_id.is_none() && self.has_peer(dst_peer_id) {
|
None
|
||||||
gateway_peer_id = Some(dst_peer_id);
|
}
|
||||||
}
|
|
||||||
|
|
||||||
let Some(gateway_peer_id) = gateway_peer_id else {
|
pub async fn send_msg(&self, msg: ZCPacket, dst_peer_id: PeerId) -> Result<(), Error> {
|
||||||
|
let Some(gateway_peer_id) = self.get_gateway_peer_id(dst_peer_id).await else {
|
||||||
tracing::trace!(
|
tracing::trace!(
|
||||||
"no gateway for dst_peer_id: {}, peers: {:?}, my_peer_id: {}",
|
"no gateway for dst_peer_id: {}, peers: {:?}, my_peer_id: {}",
|
||||||
dst_peer_id,
|
dst_peer_id,
|
||||||
self.peer_map.iter().map(|v| *v.key()).collect::<Vec<_>>(),
|
self.peer_map.iter().map(|v| *v.key()).collect::<Vec<_>>(),
|
||||||
self.my_peer_id
|
self.my_peer_id
|
||||||
);
|
);
|
||||||
return Err(Error::RouteError(None));
|
return Err(Error::RouteError(Some(format!(
|
||||||
|
"peer map sengmsg no gateway for dst_peer_id: {}",
|
||||||
|
dst_peer_id
|
||||||
|
))));
|
||||||
};
|
};
|
||||||
|
|
||||||
self.send_msg_directly(msg.clone(), gateway_peer_id).await?;
|
self.send_msg_directly(msg, gateway_peer_id).await?;
|
||||||
return Ok(());
|
return Ok(());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -203,6 +203,7 @@ struct SyncRouteInfoResponse {
|
|||||||
trait RouteService {
|
trait RouteService {
|
||||||
async fn sync_route_info(
|
async fn sync_route_info(
|
||||||
my_peer_id: PeerId,
|
my_peer_id: PeerId,
|
||||||
|
my_session_id: SessionId,
|
||||||
is_initiator: bool,
|
is_initiator: bool,
|
||||||
peer_infos: Option<Vec<RoutePeerInfo>>,
|
peer_infos: Option<Vec<RoutePeerInfo>>,
|
||||||
conn_bitmap: Option<RouteConnBitmap>,
|
conn_bitmap: Option<RouteConnBitmap>,
|
||||||
@@ -547,6 +548,15 @@ impl SyncRouteSession {
|
|||||||
self.we_are_initiator.store(is_initiator, Ordering::Relaxed);
|
self.we_are_initiator.store(is_initiator, Ordering::Relaxed);
|
||||||
self.need_sync_initiator_info.store(true, Ordering::Relaxed);
|
self.need_sync_initiator_info.store(true, Ordering::Relaxed);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn update_dst_session_id(&self, session_id: SessionId) {
|
||||||
|
if session_id != self.dst_session_id.load(Ordering::Relaxed) {
|
||||||
|
tracing::warn!(?self, ?session_id, "session id mismatch, clear saved info.");
|
||||||
|
self.dst_session_id.store(session_id, Ordering::Relaxed);
|
||||||
|
self.dst_saved_conn_bitmap_version.clear();
|
||||||
|
self.dst_saved_peer_info_versions.clear();
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
struct PeerRouteServiceImpl {
|
struct PeerRouteServiceImpl {
|
||||||
@@ -794,6 +804,7 @@ impl PeerRouteServiceImpl {
|
|||||||
.sync_route_info(
|
.sync_route_info(
|
||||||
rpc_ctx,
|
rpc_ctx,
|
||||||
my_peer_id,
|
my_peer_id,
|
||||||
|
session.my_session_id.load(Ordering::Relaxed),
|
||||||
session.we_are_initiator.load(Ordering::Relaxed),
|
session.we_are_initiator.load(Ordering::Relaxed),
|
||||||
peer_infos.clone(),
|
peer_infos.clone(),
|
||||||
conn_bitmap.clone(),
|
conn_bitmap.clone(),
|
||||||
@@ -814,19 +825,7 @@ impl PeerRouteServiceImpl {
|
|||||||
.need_sync_initiator_info
|
.need_sync_initiator_info
|
||||||
.store(false, Ordering::Relaxed);
|
.store(false, Ordering::Relaxed);
|
||||||
|
|
||||||
if ret.session_id != session.dst_session_id.load(Ordering::Relaxed) {
|
session.update_dst_session_id(ret.session_id);
|
||||||
tracing::warn!(
|
|
||||||
?ret,
|
|
||||||
?my_peer_id,
|
|
||||||
?dst_peer_id,
|
|
||||||
"session id mismatch, clear saved info."
|
|
||||||
);
|
|
||||||
session
|
|
||||||
.dst_session_id
|
|
||||||
.store(ret.session_id, Ordering::Relaxed);
|
|
||||||
session.dst_saved_conn_bitmap_version.clear();
|
|
||||||
session.dst_saved_peer_info_versions.clear();
|
|
||||||
}
|
|
||||||
|
|
||||||
if let Some(peer_infos) = &peer_infos {
|
if let Some(peer_infos) = &peer_infos {
|
||||||
session.update_dst_saved_peer_info_version(&peer_infos);
|
session.update_dst_saved_peer_info_version(&peer_infos);
|
||||||
@@ -864,6 +863,7 @@ impl RouteService for RouteSessionManager {
|
|||||||
self,
|
self,
|
||||||
_: tarpc::context::Context,
|
_: tarpc::context::Context,
|
||||||
from_peer_id: PeerId,
|
from_peer_id: PeerId,
|
||||||
|
from_session_id: SessionId,
|
||||||
is_initiator: bool,
|
is_initiator: bool,
|
||||||
peer_infos: Option<Vec<RoutePeerInfo>>,
|
peer_infos: Option<Vec<RoutePeerInfo>>,
|
||||||
conn_bitmap: Option<RouteConnBitmap>,
|
conn_bitmap: Option<RouteConnBitmap>,
|
||||||
@@ -877,6 +877,8 @@ impl RouteService for RouteSessionManager {
|
|||||||
|
|
||||||
session.rpc_rx_count.fetch_add(1, Ordering::Relaxed);
|
session.rpc_rx_count.fetch_add(1, Ordering::Relaxed);
|
||||||
|
|
||||||
|
session.update_dst_session_id(from_session_id);
|
||||||
|
|
||||||
if let Some(peer_infos) = &peer_infos {
|
if let Some(peer_infos) = &peer_infos {
|
||||||
service_impl.synced_route_info.update_peer_infos(
|
service_impl.synced_route_info.update_peer_infos(
|
||||||
my_peer_id,
|
my_peer_id,
|
||||||
@@ -1383,9 +1385,8 @@ mod tests {
|
|||||||
|
|
||||||
let i_a = get_is_initiator(&r_a, p_b.my_peer_id());
|
let i_a = get_is_initiator(&r_a, p_b.my_peer_id());
|
||||||
let i_b = get_is_initiator(&r_b, p_a.my_peer_id());
|
let i_b = get_is_initiator(&r_b, p_a.my_peer_id());
|
||||||
assert_ne!(i_a.0, i_a.1);
|
assert_eq!(i_a.0, i_b.1);
|
||||||
assert_ne!(i_b.0, i_b.1);
|
assert_eq!(i_b.0, i_a.1);
|
||||||
assert_ne!(i_a.0, i_b.0);
|
|
||||||
|
|
||||||
drop(r_b);
|
drop(r_b);
|
||||||
drop(p_b);
|
drop(p_b);
|
||||||
|
|||||||
@@ -15,14 +15,12 @@ use tracing::Instrument;
|
|||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
common::{error::Error, global_ctx::ArcGlobalCtx, stun::StunInfoCollectorTrait, PeerId},
|
common::{error::Error, global_ctx::ArcGlobalCtx, stun::StunInfoCollectorTrait, PeerId},
|
||||||
peers::{
|
peers::route_trait::{Route, RouteInterfaceBox},
|
||||||
packet,
|
|
||||||
route_trait::{Route, RouteInterfaceBox},
|
|
||||||
},
|
|
||||||
rpc::{NatType, StunInfo},
|
rpc::{NatType, StunInfo},
|
||||||
|
tunnel::packet_def::{PacketType, ZCPacket},
|
||||||
};
|
};
|
||||||
|
|
||||||
use super::{packet::CtrlPacketPayload, PeerPacketFilter};
|
use super::PeerPacketFilter;
|
||||||
|
|
||||||
const SEND_ROUTE_PERIOD_SEC: u64 = 60;
|
const SEND_ROUTE_PERIOD_SEC: u64 = 60;
|
||||||
const SEND_ROUTE_FAST_REPLY_SEC: u64 = 5;
|
const SEND_ROUTE_FAST_REPLY_SEC: u64 = 5;
|
||||||
@@ -625,26 +623,15 @@ impl Route for BasicRoute {
|
|||||||
|
|
||||||
#[async_trait::async_trait]
|
#[async_trait::async_trait]
|
||||||
impl PeerPacketFilter for BasicRoute {
|
impl PeerPacketFilter for BasicRoute {
|
||||||
async fn try_process_packet_from_peer(
|
async fn try_process_packet_from_peer(&self, packet: ZCPacket) -> Option<ZCPacket> {
|
||||||
&self,
|
let hdr = packet.peer_manager_header().unwrap();
|
||||||
packet: &packet::ArchivedPacket,
|
if hdr.packet_type == PacketType::Route as u8 {
|
||||||
_data: &Bytes,
|
let b = packet.payload().to_vec();
|
||||||
) -> Option<()> {
|
self.handle_route_packet(hdr.from_peer_id.get(), b.into())
|
||||||
if packet.packet_type == packet::PacketType::RoutePacket {
|
.await;
|
||||||
let CtrlPacketPayload::RoutePacket(route_packet) =
|
|
||||||
CtrlPacketPayload::from_packet(packet)
|
|
||||||
else {
|
|
||||||
return None;
|
|
||||||
};
|
|
||||||
|
|
||||||
self.handle_route_packet(
|
|
||||||
packet.from_peer.into(),
|
|
||||||
route_packet.body.into_boxed_slice().into(),
|
|
||||||
)
|
|
||||||
.await;
|
|
||||||
Some(())
|
|
||||||
} else {
|
|
||||||
None
|
None
|
||||||
|
} else {
|
||||||
|
Some(packet)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,22 +2,22 @@ use std::sync::{atomic::AtomicU32, Arc};
|
|||||||
|
|
||||||
use dashmap::DashMap;
|
use dashmap::DashMap;
|
||||||
use futures::{SinkExt, StreamExt};
|
use futures::{SinkExt, StreamExt};
|
||||||
use rkyv::Deserialize;
|
use prost::Message;
|
||||||
|
|
||||||
use tarpc::{server::Channel, transport::channel::UnboundedChannel};
|
use tarpc::{server::Channel, transport::channel::UnboundedChannel};
|
||||||
use tokio::{
|
use tokio::{
|
||||||
sync::mpsc::{self, UnboundedSender},
|
sync::mpsc::{self, UnboundedSender},
|
||||||
task::JoinSet,
|
task::JoinSet,
|
||||||
};
|
};
|
||||||
use tokio_util::bytes::Bytes;
|
|
||||||
use tracing::Instrument;
|
use tracing::Instrument;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
common::{error::Error, PeerId},
|
common::{error::Error, PeerId},
|
||||||
peers::packet::Packet,
|
rpc::TaRpcPacket,
|
||||||
|
tunnel::packet_def::{PacketType, ZCPacket},
|
||||||
};
|
};
|
||||||
|
|
||||||
use super::packet::CtrlPacketPayload;
|
|
||||||
|
|
||||||
type PeerRpcServiceId = u32;
|
type PeerRpcServiceId = u32;
|
||||||
type PeerRpcTransactId = u32;
|
type PeerRpcTransactId = u32;
|
||||||
|
|
||||||
@@ -25,11 +25,11 @@ type PeerRpcTransactId = u32;
|
|||||||
#[auto_impl::auto_impl(Arc)]
|
#[auto_impl::auto_impl(Arc)]
|
||||||
pub trait PeerRpcManagerTransport: Send + Sync + 'static {
|
pub trait PeerRpcManagerTransport: Send + Sync + 'static {
|
||||||
fn my_peer_id(&self) -> PeerId;
|
fn my_peer_id(&self) -> PeerId;
|
||||||
async fn send(&self, msg: Bytes, dst_peer_id: PeerId) -> Result<(), Error>;
|
async fn send(&self, msg: ZCPacket, dst_peer_id: PeerId) -> Result<(), Error>;
|
||||||
async fn recv(&self) -> Result<Bytes, Error>;
|
async fn recv(&self) -> Result<ZCPacket, Error>;
|
||||||
}
|
}
|
||||||
|
|
||||||
type PacketSender = UnboundedSender<Packet>;
|
type PacketSender = UnboundedSender<ZCPacket>;
|
||||||
|
|
||||||
struct PeerRpcEndPoint {
|
struct PeerRpcEndPoint {
|
||||||
peer_id: PeerId,
|
peer_id: PeerId,
|
||||||
@@ -63,16 +63,6 @@ impl std::fmt::Debug for PeerRpcManager {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
|
||||||
struct TaRpcPacketInfo {
|
|
||||||
from_peer: PeerId,
|
|
||||||
to_peer: PeerId,
|
|
||||||
service_id: PeerRpcServiceId,
|
|
||||||
transact_id: PeerRpcTransactId,
|
|
||||||
is_req: bool,
|
|
||||||
content: Vec<u8>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl PeerRpcManager {
|
impl PeerRpcManager {
|
||||||
pub fn new(tspt: impl PeerRpcManagerTransport) -> Self {
|
pub fn new(tspt: impl PeerRpcManagerTransport) -> Self {
|
||||||
Self {
|
Self {
|
||||||
@@ -100,7 +90,7 @@ impl PeerRpcManager {
|
|||||||
let tspt = self.tspt.clone();
|
let tspt = self.tspt.clone();
|
||||||
let creator = Box::new(move |peer_id: PeerId| {
|
let creator = Box::new(move |peer_id: PeerId| {
|
||||||
let mut tasks = JoinSet::new();
|
let mut tasks = JoinSet::new();
|
||||||
let (packet_sender, mut packet_receiver) = mpsc::unbounded_channel::<Packet>();
|
let (packet_sender, mut packet_receiver) = mpsc::unbounded_channel();
|
||||||
let (mut client_transport, server_transport) = tarpc::transport::channel::unbounded();
|
let (mut client_transport, server_transport) = tarpc::transport::channel::unbounded();
|
||||||
let server = tarpc::server::BaseChannel::with_defaults(server_transport);
|
let server = tarpc::server::BaseChannel::with_defaults(server_transport);
|
||||||
|
|
||||||
@@ -122,7 +112,7 @@ impl PeerRpcManager {
|
|||||||
continue;
|
continue;
|
||||||
};
|
};
|
||||||
|
|
||||||
tracing::trace!(resp = ?resp, "recv packet from client");
|
tracing::debug!(resp = ?resp, "server recv packet from service provider");
|
||||||
if resp.is_err() {
|
if resp.is_err() {
|
||||||
tracing::warn!(err = ?resp.err(),
|
tracing::warn!(err = ?resp.err(),
|
||||||
"[PEER RPC MGR] client_transport in server side got channel error, ignore it.");
|
"[PEER RPC MGR] client_transport in server side got channel error, ignore it.");
|
||||||
@@ -136,7 +126,7 @@ impl PeerRpcManager {
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
let msg = Packet::new_tarpc_packet(
|
let msg = Self::build_rpc_packet(
|
||||||
tspt.my_peer_id(),
|
tspt.my_peer_id(),
|
||||||
cur_req_peer_id,
|
cur_req_peer_id,
|
||||||
service_id,
|
service_id,
|
||||||
@@ -145,12 +135,13 @@ impl PeerRpcManager {
|
|||||||
serialized_resp.unwrap(),
|
serialized_resp.unwrap(),
|
||||||
);
|
);
|
||||||
|
|
||||||
if let Err(e) = tspt.send(msg.into(), peer_id).await {
|
if let Err(e) = tspt.send(msg, peer_id).await {
|
||||||
tracing::error!(error = ?e, peer_id = ?peer_id, service_id = ?service_id, "send resp to peer failed");
|
tracing::error!(error = ?e, peer_id = ?peer_id, service_id = ?service_id, "send resp to peer failed");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Some(packet) = packet_receiver.recv() => {
|
Some(packet) = packet_receiver.recv() => {
|
||||||
let info = Self::parse_rpc_packet(&packet);
|
let info = Self::parse_rpc_packet(&packet);
|
||||||
|
tracing::debug!(?info, "server recv packet from peer");
|
||||||
if let Err(e) = info {
|
if let Err(e) = info {
|
||||||
tracing::error!(error = ?e, packet = ?packet, "parse rpc packet failed");
|
tracing::error!(error = ?e, packet = ?packet, "parse rpc packet failed");
|
||||||
continue;
|
continue;
|
||||||
@@ -168,7 +159,7 @@ impl PeerRpcManager {
|
|||||||
}
|
}
|
||||||
|
|
||||||
assert_eq!(info.service_id, service_id);
|
assert_eq!(info.service_id, service_id);
|
||||||
cur_req_peer_id = Some(packet.from_peer.clone().into());
|
cur_req_peer_id = Some(info.from_peer);
|
||||||
cur_transact_id = info.transact_id;
|
cur_transact_id = info.transact_id;
|
||||||
|
|
||||||
tracing::trace!("recv packet from peer, packet: {:?}", packet);
|
tracing::trace!("recv packet from peer, packet: {:?}", packet);
|
||||||
@@ -219,19 +210,33 @@ impl PeerRpcManager {
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn parse_rpc_packet(packet: &Packet) -> Result<TaRpcPacketInfo, Error> {
|
fn parse_rpc_packet(packet: &ZCPacket) -> Result<TaRpcPacket, Error> {
|
||||||
let ctrl_packet_payload = CtrlPacketPayload::from_packet2(&packet);
|
let payload = packet.payload();
|
||||||
match &ctrl_packet_payload {
|
TaRpcPacket::decode(payload).map_err(|e| Error::MessageDecodeError(e.to_string()))
|
||||||
CtrlPacketPayload::TaRpc(id, tid, is_req, body) => Ok(TaRpcPacketInfo {
|
}
|
||||||
from_peer: packet.from_peer.into(),
|
|
||||||
to_peer: packet.to_peer.into(),
|
fn build_rpc_packet(
|
||||||
service_id: *id,
|
from_peer: PeerId,
|
||||||
transact_id: *tid,
|
to_peer: PeerId,
|
||||||
is_req: *is_req,
|
service_id: PeerRpcServiceId,
|
||||||
content: body.clone(),
|
transact_id: PeerRpcTransactId,
|
||||||
}),
|
is_req: bool,
|
||||||
_ => Err(Error::ShellCommandError("invalid packet".to_owned())),
|
content: Vec<u8>,
|
||||||
}
|
) -> ZCPacket {
|
||||||
|
let packet = TaRpcPacket {
|
||||||
|
from_peer,
|
||||||
|
to_peer,
|
||||||
|
service_id,
|
||||||
|
transact_id,
|
||||||
|
is_req,
|
||||||
|
content,
|
||||||
|
};
|
||||||
|
let mut buf = Vec::new();
|
||||||
|
packet.encode(&mut buf).unwrap();
|
||||||
|
|
||||||
|
let mut zc_packet = ZCPacket::new_with_payload(&buf);
|
||||||
|
zc_packet.fill_peer_manager_hdr(from_peer, to_peer, PacketType::TaRpc as u8);
|
||||||
|
zc_packet
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn run(&self) {
|
pub fn run(&self) {
|
||||||
@@ -245,9 +250,9 @@ impl PeerRpcManager {
|
|||||||
tracing::warn!("peer rpc transport read aborted, exiting");
|
tracing::warn!("peer rpc transport read aborted, exiting");
|
||||||
break;
|
break;
|
||||||
};
|
};
|
||||||
let packet = Packet::decode(&o);
|
|
||||||
let packet: Packet = packet.deserialize(&mut rkyv::Infallible).unwrap();
|
let info = Self::parse_rpc_packet(&o).unwrap();
|
||||||
let info = Self::parse_rpc_packet(&packet).unwrap();
|
tracing::debug!(?info, "recv rpc packet from peer");
|
||||||
|
|
||||||
if info.is_req {
|
if info.is_req {
|
||||||
if !service_registry.contains_key(&info.service_id) {
|
if !service_registry.contains_key(&info.service_id) {
|
||||||
@@ -265,15 +270,15 @@ impl PeerRpcManager {
|
|||||||
service_registry.get(&info.service_id).unwrap()(info.from_peer)
|
service_registry.get(&info.service_id).unwrap()(info.from_peer)
|
||||||
});
|
});
|
||||||
|
|
||||||
endpoint.packet_sender.send(packet).unwrap();
|
endpoint.packet_sender.send(o).unwrap();
|
||||||
} else {
|
} else {
|
||||||
if let Some(a) = client_resp_receivers.get(&PeerRpcClientCtxKey(
|
if let Some(a) = client_resp_receivers.get(&PeerRpcClientCtxKey(
|
||||||
info.from_peer,
|
info.from_peer,
|
||||||
info.service_id,
|
info.service_id,
|
||||||
info.transact_id,
|
info.transact_id,
|
||||||
)) {
|
)) {
|
||||||
log::trace!("recv resp: {:?}", packet);
|
tracing::trace!("recv resp: {:?}", info);
|
||||||
if let Err(e) = a.send(packet) {
|
if let Err(e) = a.send(o) {
|
||||||
tracing::error!(error = ?e, "send resp to client failed");
|
tracing::error!(error = ?e, "send resp to client failed");
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
@@ -297,7 +302,8 @@ impl PeerRpcManager {
|
|||||||
Fut: std::future::Future<Output = RpcRet>,
|
Fut: std::future::Future<Output = RpcRet>,
|
||||||
{
|
{
|
||||||
let mut tasks = JoinSet::new();
|
let mut tasks = JoinSet::new();
|
||||||
let (packet_sender, mut packet_receiver) = mpsc::unbounded_channel::<Packet>();
|
let (packet_sender, mut packet_receiver) = mpsc::unbounded_channel();
|
||||||
|
|
||||||
let (client_transport, server_transport) =
|
let (client_transport, server_transport) =
|
||||||
tarpc::transport::channel::unbounded::<CM, Req>();
|
tarpc::transport::channel::unbounded::<CM, Req>();
|
||||||
|
|
||||||
@@ -321,7 +327,7 @@ impl PeerRpcManager {
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
let a = Packet::new_tarpc_packet(
|
let packet = Self::build_rpc_packet(
|
||||||
tspt.my_peer_id(),
|
tspt.my_peer_id(),
|
||||||
dst_peer_id,
|
dst_peer_id,
|
||||||
service_id,
|
service_id,
|
||||||
@@ -330,7 +336,9 @@ impl PeerRpcManager {
|
|||||||
a.unwrap(),
|
a.unwrap(),
|
||||||
);
|
);
|
||||||
|
|
||||||
if let Err(e) = tspt.send(a.into(), dst_peer_id).await {
|
tracing::debug!(?packet, "client send rpc packet to peer");
|
||||||
|
|
||||||
|
if let Err(e) = tspt.send(packet, dst_peer_id).await {
|
||||||
tracing::error!(error = ?e, dst_peer_id = ?dst_peer_id, "send to peer failed");
|
tracing::error!(error = ?e, dst_peer_id = ?dst_peer_id, "send to peer failed");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -342,11 +350,12 @@ impl PeerRpcManager {
|
|||||||
while let Some(packet) = packet_receiver.recv().await {
|
while let Some(packet) = packet_receiver.recv().await {
|
||||||
tracing::trace!("tunnel recv: {:?}", packet);
|
tracing::trace!("tunnel recv: {:?}", packet);
|
||||||
|
|
||||||
let info = PeerRpcManager::parse_rpc_packet(&packet);
|
let info = Self::parse_rpc_packet(&packet);
|
||||||
if let Err(e) = info {
|
if let Err(e) = info {
|
||||||
tracing::error!(error = ?e, "parse rpc packet failed");
|
tracing::error!(error = ?e, "parse rpc packet failed");
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
tracing::debug!(?info, "client recv rpc packet from peer");
|
||||||
|
|
||||||
let decoded = postcard::from_bytes(&info.unwrap().content.as_slice());
|
let decoded = postcard::from_bytes(&info.unwrap().content.as_slice());
|
||||||
if let Err(e) = decoded {
|
if let Err(e) = decoded {
|
||||||
@@ -381,8 +390,10 @@ impl PeerRpcManager {
|
|||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
|
use std::{pin::Pin, sync::Arc};
|
||||||
|
|
||||||
use futures::{SinkExt, StreamExt};
|
use futures::{SinkExt, StreamExt};
|
||||||
use tokio_util::bytes::Bytes;
|
use tokio::sync::Mutex;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
common::{error::Error, new_peer_id, PeerId},
|
common::{error::Error, new_peer_id, PeerId},
|
||||||
@@ -390,7 +401,10 @@ mod tests {
|
|||||||
peer_rpc::PeerRpcManager,
|
peer_rpc::PeerRpcManager,
|
||||||
tests::{connect_peer_manager, create_mock_peer_manager, wait_route_appear},
|
tests::{connect_peer_manager, create_mock_peer_manager, wait_route_appear},
|
||||||
},
|
},
|
||||||
tunnels::{self, ring_tunnel::create_ring_tunnel_pair},
|
tunnel::{
|
||||||
|
packet_def::ZCPacket, ring::create_ring_tunnel_pair, Tunnel, ZCPacketSink,
|
||||||
|
ZCPacketStream,
|
||||||
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
use super::PeerRpcManagerTransport;
|
use super::PeerRpcManagerTransport;
|
||||||
@@ -415,7 +429,8 @@ mod tests {
|
|||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn peer_rpc_basic_test() {
|
async fn peer_rpc_basic_test() {
|
||||||
struct MockTransport {
|
struct MockTransport {
|
||||||
tunnel: Box<dyn tunnels::Tunnel>,
|
sink: Arc<Mutex<Pin<Box<dyn ZCPacketSink>>>>,
|
||||||
|
stream: Arc<Mutex<Pin<Box<dyn ZCPacketStream>>>>,
|
||||||
my_peer_id: PeerId,
|
my_peer_id: PeerId,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -424,22 +439,25 @@ mod tests {
|
|||||||
fn my_peer_id(&self) -> PeerId {
|
fn my_peer_id(&self) -> PeerId {
|
||||||
self.my_peer_id
|
self.my_peer_id
|
||||||
}
|
}
|
||||||
async fn send(&self, msg: Bytes, _dst_peer_id: PeerId) -> Result<(), Error> {
|
async fn send(&self, msg: ZCPacket, _dst_peer_id: PeerId) -> Result<(), Error> {
|
||||||
println!("rpc mgr send: {:?}", msg);
|
println!("rpc mgr send: {:?}", msg);
|
||||||
self.tunnel.pin_sink().send(msg).await.unwrap();
|
self.sink.lock().await.send(msg).await.unwrap();
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
async fn recv(&self) -> Result<Bytes, Error> {
|
async fn recv(&self) -> Result<ZCPacket, Error> {
|
||||||
let ret = self.tunnel.pin_stream().next().await.unwrap();
|
let ret = self.stream.lock().await.next().await.unwrap();
|
||||||
println!("rpc mgr recv: {:?}", ret);
|
println!("rpc mgr recv: {:?}", ret);
|
||||||
return ret.map(|v| v.freeze()).map_err(|_| Error::Unknown);
|
return ret.map_err(|e| e.into());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let (ct, st) = create_ring_tunnel_pair();
|
let (ct, st) = create_ring_tunnel_pair();
|
||||||
|
let (cts, ctsr) = ct.split();
|
||||||
|
let (sts, stsr) = st.split();
|
||||||
|
|
||||||
let server_rpc_mgr = PeerRpcManager::new(MockTransport {
|
let server_rpc_mgr = PeerRpcManager::new(MockTransport {
|
||||||
tunnel: st,
|
sink: Arc::new(Mutex::new(ctsr)),
|
||||||
|
stream: Arc::new(Mutex::new(cts)),
|
||||||
my_peer_id: new_peer_id(),
|
my_peer_id: new_peer_id(),
|
||||||
});
|
});
|
||||||
server_rpc_mgr.run();
|
server_rpc_mgr.run();
|
||||||
@@ -449,7 +467,8 @@ mod tests {
|
|||||||
server_rpc_mgr.run_service(1, s.serve());
|
server_rpc_mgr.run_service(1, s.serve());
|
||||||
|
|
||||||
let client_rpc_mgr = PeerRpcManager::new(MockTransport {
|
let client_rpc_mgr = PeerRpcManager::new(MockTransport {
|
||||||
tunnel: ct,
|
sink: Arc::new(Mutex::new(stsr)),
|
||||||
|
stream: Arc::new(Mutex::new(sts)),
|
||||||
my_peer_id: new_peer_id(),
|
my_peer_id: new_peer_id(),
|
||||||
});
|
});
|
||||||
client_rpc_mgr.run();
|
client_rpc_mgr.run();
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ use futures::Future;
|
|||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
common::{error::Error, global_ctx::tests::get_mock_global_ctx, PeerId},
|
common::{error::Error, global_ctx::tests::get_mock_global_ctx, PeerId},
|
||||||
tunnels::ring_tunnel::create_ring_tunnel_pair,
|
tunnel::ring::create_ring_tunnel_pair,
|
||||||
};
|
};
|
||||||
|
|
||||||
use super::peer_manager::{PeerManager, RouteAlgoType};
|
use super::peer_manager::{PeerManager, RouteAlgoType};
|
||||||
|
|||||||
@@ -0,0 +1,748 @@
|
|||||||
|
use std::{
|
||||||
|
any::Any,
|
||||||
|
fmt::Debug,
|
||||||
|
pin::Pin,
|
||||||
|
sync::{
|
||||||
|
atomic::{AtomicU32, Ordering},
|
||||||
|
Arc,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
use futures::{SinkExt, StreamExt, TryFutureExt};
|
||||||
|
|
||||||
|
use prost::Message;
|
||||||
|
|
||||||
|
use tokio::{
|
||||||
|
sync::{broadcast, mpsc},
|
||||||
|
task::JoinSet,
|
||||||
|
time::{timeout, Duration},
|
||||||
|
};
|
||||||
|
|
||||||
|
use tokio_util::sync::PollSender;
|
||||||
|
use tracing::Instrument;
|
||||||
|
use zerocopy::AsBytes;
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
common::{
|
||||||
|
error::Error,
|
||||||
|
global_ctx::{ArcGlobalCtx, NetworkIdentity},
|
||||||
|
PeerId,
|
||||||
|
},
|
||||||
|
peers::packet::PacketType,
|
||||||
|
rpc::{HandshakeRequest, PeerConnInfo, PeerConnStats, TunnelInfo},
|
||||||
|
tunnel::{
|
||||||
|
filter::{StatsRecorderTunnelFilter, TunnelFilter, TunnelWithFilter},
|
||||||
|
mpsc::{MpscTunnel, MpscTunnelSender},
|
||||||
|
packet_def::ZCPacket,
|
||||||
|
stats::{Throughput, WindowLatency},
|
||||||
|
Tunnel, TunnelError, ZCPacketStream,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
use super::{peer_conn_ping::PeerConnPinger, PacketRecvChan};
|
||||||
|
|
||||||
|
pub type PeerConnId = uuid::Uuid;
|
||||||
|
|
||||||
|
const MAGIC: u32 = 0xd1e1a5e1;
|
||||||
|
const VERSION: u32 = 1;
|
||||||
|
|
||||||
|
pub struct PeerConn {
|
||||||
|
conn_id: PeerConnId,
|
||||||
|
|
||||||
|
my_peer_id: PeerId,
|
||||||
|
global_ctx: ArcGlobalCtx,
|
||||||
|
|
||||||
|
tunnel: Box<dyn Any + Send + 'static>,
|
||||||
|
sink: MpscTunnelSender,
|
||||||
|
recv: Option<Pin<Box<dyn ZCPacketStream>>>,
|
||||||
|
tunnel_info: Option<TunnelInfo>,
|
||||||
|
|
||||||
|
tasks: JoinSet<Result<(), TunnelError>>,
|
||||||
|
|
||||||
|
info: Option<HandshakeRequest>,
|
||||||
|
|
||||||
|
close_event_sender: Option<mpsc::Sender<PeerConnId>>,
|
||||||
|
|
||||||
|
ctrl_resp_sender: broadcast::Sender<ZCPacket>,
|
||||||
|
|
||||||
|
latency_stats: Arc<WindowLatency>,
|
||||||
|
throughput: Arc<Throughput>,
|
||||||
|
loss_rate_stats: Arc<AtomicU32>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Debug for PeerConn {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
f.debug_struct("PeerConn")
|
||||||
|
.field("conn_id", &self.conn_id)
|
||||||
|
.field("my_peer_id", &self.my_peer_id)
|
||||||
|
.field("info", &self.info)
|
||||||
|
.finish()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PeerConn {
|
||||||
|
pub fn new(my_peer_id: PeerId, global_ctx: ArcGlobalCtx, tunnel: Box<dyn Tunnel>) -> Self {
|
||||||
|
let tunnel_info = tunnel.info();
|
||||||
|
let (ctrl_sender, _ctrl_receiver) = broadcast::channel(100);
|
||||||
|
|
||||||
|
let peer_conn_tunnel_filter = StatsRecorderTunnelFilter::new();
|
||||||
|
let throughput = peer_conn_tunnel_filter.filter_output();
|
||||||
|
let peer_conn_tunnel = TunnelWithFilter::new(tunnel, peer_conn_tunnel_filter);
|
||||||
|
let mut mpsc_tunnel = MpscTunnel::new(peer_conn_tunnel);
|
||||||
|
|
||||||
|
let (recv, sink) = (mpsc_tunnel.get_stream(), mpsc_tunnel.get_sink());
|
||||||
|
|
||||||
|
PeerConn {
|
||||||
|
conn_id: PeerConnId::new_v4(),
|
||||||
|
|
||||||
|
my_peer_id,
|
||||||
|
global_ctx,
|
||||||
|
|
||||||
|
tunnel: Box::new(mpsc_tunnel),
|
||||||
|
sink,
|
||||||
|
recv: Some(recv),
|
||||||
|
tunnel_info,
|
||||||
|
|
||||||
|
tasks: JoinSet::new(),
|
||||||
|
|
||||||
|
info: None,
|
||||||
|
close_event_sender: None,
|
||||||
|
|
||||||
|
ctrl_resp_sender: ctrl_sender,
|
||||||
|
|
||||||
|
latency_stats: Arc::new(WindowLatency::new(15)),
|
||||||
|
throughput,
|
||||||
|
loss_rate_stats: Arc::new(AtomicU32::new(0)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_conn_id(&self) -> PeerConnId {
|
||||||
|
self.conn_id
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn wait_handshake(&mut self) -> Result<HandshakeRequest, Error> {
|
||||||
|
let recv = self.recv.as_mut().unwrap();
|
||||||
|
let Some(rsp) = recv.next().await else {
|
||||||
|
return Err(Error::WaitRespError(
|
||||||
|
"conn closed during wait handshake response".to_owned(),
|
||||||
|
));
|
||||||
|
};
|
||||||
|
let rsp = rsp?;
|
||||||
|
let rsp = HandshakeRequest::decode(rsp.payload())
|
||||||
|
.map_err(|e| Error::WaitRespError(format!("decode handshake response error: {:?}", e)));
|
||||||
|
|
||||||
|
return Ok(rsp.unwrap());
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn wait_handshake_loop(&mut self) -> Result<HandshakeRequest, Error> {
|
||||||
|
Ok(timeout(Duration::from_secs(5), async move {
|
||||||
|
loop {
|
||||||
|
match self.wait_handshake().await {
|
||||||
|
Ok(rsp) => return rsp,
|
||||||
|
Err(e) => {
|
||||||
|
log::warn!("wait handshake error: {:?}", e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.map_err(|e| Error::WaitRespError(format!("wait handshake timeout: {:?}", e)))
|
||||||
|
.await?)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn send_handshake(&mut self) -> Result<(), Error> {
|
||||||
|
let network = self.global_ctx.get_network_identity();
|
||||||
|
let req = HandshakeRequest {
|
||||||
|
magic: MAGIC,
|
||||||
|
my_peer_id: self.my_peer_id,
|
||||||
|
version: VERSION,
|
||||||
|
features: Vec::new(),
|
||||||
|
network_name: network.network_name.clone(),
|
||||||
|
network_secret: network.network_secret.clone(),
|
||||||
|
};
|
||||||
|
|
||||||
|
let hs_req = req.encode_to_vec();
|
||||||
|
let mut zc_packet = ZCPacket::new_with_payload(hs_req.as_bytes());
|
||||||
|
zc_packet.fill_peer_manager_hdr(
|
||||||
|
self.my_peer_id,
|
||||||
|
PeerId::default(),
|
||||||
|
PacketType::HandShake as u8,
|
||||||
|
);
|
||||||
|
|
||||||
|
self.sink.send(zc_packet).await.map_err(|e| {
|
||||||
|
tracing::warn!("send handshake request error: {:?}", e);
|
||||||
|
Error::WaitRespError("send handshake request error".to_owned())
|
||||||
|
})?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tracing::instrument]
|
||||||
|
pub async fn do_handshake_as_server(&mut self) -> Result<(), Error> {
|
||||||
|
let rsp = self.wait_handshake_loop().await?;
|
||||||
|
tracing::info!("handshake request: {:?}", rsp);
|
||||||
|
self.info = Some(rsp);
|
||||||
|
self.send_handshake().await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tracing::instrument]
|
||||||
|
pub async fn do_handshake_as_client(&mut self) -> Result<(), Error> {
|
||||||
|
self.send_handshake().await?;
|
||||||
|
tracing::info!("waiting for handshake request from server");
|
||||||
|
let rsp = self.wait_handshake_loop().await?;
|
||||||
|
tracing::info!("handshake response: {:?}", rsp);
|
||||||
|
self.info = Some(rsp);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn handshake_done(&self) -> bool {
|
||||||
|
self.info.is_some()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn start_recv_loop(&mut self, packet_recv_chan: PacketRecvChan) {
|
||||||
|
let mut stream = self.recv.take().unwrap();
|
||||||
|
let sink = self.sink.clone();
|
||||||
|
let mut sender = PollSender::new(packet_recv_chan.clone());
|
||||||
|
let close_event_sender = self.close_event_sender.clone().unwrap();
|
||||||
|
let conn_id = self.conn_id;
|
||||||
|
let ctrl_sender = self.ctrl_resp_sender.clone();
|
||||||
|
let _conn_info = self.get_conn_info();
|
||||||
|
let conn_info_for_instrument = self.get_conn_info();
|
||||||
|
|
||||||
|
self.tasks.spawn(
|
||||||
|
async move {
|
||||||
|
tracing::info!("start recving peer conn packet");
|
||||||
|
let mut task_ret = Ok(());
|
||||||
|
while let Some(ret) = stream.next().await {
|
||||||
|
if ret.is_err() {
|
||||||
|
tracing::error!(error = ?ret, "peer conn recv error");
|
||||||
|
task_ret = Err(ret.err().unwrap());
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut zc_packet = ret.unwrap();
|
||||||
|
let Some(peer_mgr_hdr) = zc_packet.mut_peer_manager_header() else {
|
||||||
|
tracing::error!(
|
||||||
|
"unexpected packet: {:?}, cannot decode peer manager hdr",
|
||||||
|
zc_packet
|
||||||
|
);
|
||||||
|
continue;
|
||||||
|
};
|
||||||
|
|
||||||
|
if peer_mgr_hdr.packet_type == PacketType::Ping as u8 {
|
||||||
|
peer_mgr_hdr.packet_type = PacketType::Pong as u8;
|
||||||
|
if let Err(e) = sink.send(zc_packet).await {
|
||||||
|
tracing::error!(?e, "peer conn send req error");
|
||||||
|
}
|
||||||
|
} else if peer_mgr_hdr.packet_type == PacketType::Pong as u8 {
|
||||||
|
if let Err(e) = ctrl_sender.send(zc_packet) {
|
||||||
|
tracing::error!(?e, "peer conn send ctrl resp error");
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if sender.send(zc_packet).await.is_err() {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
tracing::info!("end recving peer conn packet");
|
||||||
|
|
||||||
|
drop(sink);
|
||||||
|
if let Err(e) = close_event_sender.send(conn_id).await {
|
||||||
|
tracing::error!(error = ?e, "peer conn close event send error");
|
||||||
|
}
|
||||||
|
|
||||||
|
task_ret
|
||||||
|
}
|
||||||
|
.instrument(
|
||||||
|
tracing::info_span!("peer conn recv loop", conn_info = ?conn_info_for_instrument),
|
||||||
|
),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn start_pingpong(&mut self) {
|
||||||
|
let mut pingpong = PeerConnPinger::new(
|
||||||
|
self.my_peer_id,
|
||||||
|
self.get_peer_id(),
|
||||||
|
self.sink.clone(),
|
||||||
|
self.ctrl_resp_sender.clone(),
|
||||||
|
self.latency_stats.clone(),
|
||||||
|
self.loss_rate_stats.clone(),
|
||||||
|
);
|
||||||
|
|
||||||
|
let close_event_sender = self.close_event_sender.clone().unwrap();
|
||||||
|
let conn_id = self.conn_id;
|
||||||
|
|
||||||
|
self.tasks.spawn(async move {
|
||||||
|
pingpong.pingpong().await;
|
||||||
|
|
||||||
|
tracing::warn!(?pingpong, "pingpong task exit");
|
||||||
|
|
||||||
|
if let Err(e) = close_event_sender.send(conn_id).await {
|
||||||
|
log::warn!("close event sender error: {:?}", e);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn send_msg(&mut self, msg: ZCPacket) -> Result<(), Error> {
|
||||||
|
Ok(self.sink.send(msg).await?)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_peer_id(&self) -> PeerId {
|
||||||
|
self.info.as_ref().unwrap().my_peer_id
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_network_identity(&self) -> NetworkIdentity {
|
||||||
|
let info = self.info.as_ref().unwrap();
|
||||||
|
NetworkIdentity {
|
||||||
|
network_name: info.network_name.clone(),
|
||||||
|
network_secret: info.network_secret.clone(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn set_close_event_sender(&mut self, sender: mpsc::Sender<PeerConnId>) {
|
||||||
|
self.close_event_sender = Some(sender);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_stats(&self) -> PeerConnStats {
|
||||||
|
PeerConnStats {
|
||||||
|
latency_us: self.latency_stats.get_latency_us(),
|
||||||
|
|
||||||
|
tx_bytes: self.throughput.tx_bytes(),
|
||||||
|
rx_bytes: self.throughput.rx_bytes(),
|
||||||
|
|
||||||
|
tx_packets: self.throughput.tx_packets(),
|
||||||
|
rx_packets: self.throughput.rx_packets(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_conn_info(&self) -> PeerConnInfo {
|
||||||
|
PeerConnInfo {
|
||||||
|
conn_id: self.conn_id.to_string(),
|
||||||
|
my_peer_id: self.my_peer_id,
|
||||||
|
peer_id: self.get_peer_id(),
|
||||||
|
features: self.info.as_ref().unwrap().features.clone(),
|
||||||
|
tunnel: self.tunnel_info.clone(),
|
||||||
|
stats: Some(self.get_stats()),
|
||||||
|
loss_rate: (f64::from(self.loss_rate_stats.load(Ordering::Relaxed)) / 100.0) as f32,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use super::*;
|
||||||
|
use crate::common::global_ctx::tests::get_mock_global_ctx;
|
||||||
|
use crate::common::new_peer_id;
|
||||||
|
use crate::tunnel::filter::tests::DropSendTunnelFilter;
|
||||||
|
use crate::tunnel::filter::PacketRecorderTunnelFilter;
|
||||||
|
use crate::tunnel::ring::create_ring_tunnel_pair;
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn peer_conn_handshake() {
|
||||||
|
let (c, s) = create_ring_tunnel_pair();
|
||||||
|
|
||||||
|
let c_recorder = Arc::new(PacketRecorderTunnelFilter::new());
|
||||||
|
let s_recorder = Arc::new(PacketRecorderTunnelFilter::new());
|
||||||
|
|
||||||
|
let c = TunnelWithFilter::new(c, c_recorder.clone());
|
||||||
|
let s = TunnelWithFilter::new(s, s_recorder.clone());
|
||||||
|
|
||||||
|
let c_peer_id = new_peer_id();
|
||||||
|
let s_peer_id = new_peer_id();
|
||||||
|
|
||||||
|
let mut c_peer = PeerConn::new(c_peer_id, get_mock_global_ctx(), Box::new(c));
|
||||||
|
|
||||||
|
let mut s_peer = PeerConn::new(s_peer_id, get_mock_global_ctx(), Box::new(s));
|
||||||
|
|
||||||
|
let (c_ret, s_ret) = tokio::join!(
|
||||||
|
c_peer.do_handshake_as_client(),
|
||||||
|
s_peer.do_handshake_as_server()
|
||||||
|
);
|
||||||
|
|
||||||
|
c_ret.unwrap();
|
||||||
|
s_ret.unwrap();
|
||||||
|
|
||||||
|
assert_eq!(c_recorder.sent.lock().unwrap().len(), 1);
|
||||||
|
assert_eq!(c_recorder.received.lock().unwrap().len(), 1);
|
||||||
|
|
||||||
|
assert_eq!(s_recorder.sent.lock().unwrap().len(), 1);
|
||||||
|
assert_eq!(s_recorder.received.lock().unwrap().len(), 1);
|
||||||
|
|
||||||
|
assert_eq!(c_peer.get_peer_id(), s_peer_id);
|
||||||
|
assert_eq!(s_peer.get_peer_id(), c_peer_id);
|
||||||
|
assert_eq!(c_peer.get_network_identity(), s_peer.get_network_identity());
|
||||||
|
assert_eq!(c_peer.get_network_identity(), NetworkIdentity::default());
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn peer_conn_pingpong_test_common(drop_start: u32, drop_end: u32, conn_closed: bool) {
|
||||||
|
let (c, s) = create_ring_tunnel_pair();
|
||||||
|
|
||||||
|
// drop 1-3 packets should not affect pingpong
|
||||||
|
let c_recorder = Arc::new(DropSendTunnelFilter::new(drop_start, drop_end));
|
||||||
|
let c = TunnelWithFilter::new(c, c_recorder.clone());
|
||||||
|
|
||||||
|
let c_peer_id = new_peer_id();
|
||||||
|
let s_peer_id = new_peer_id();
|
||||||
|
|
||||||
|
let mut c_peer = PeerConn::new(c_peer_id, get_mock_global_ctx(), Box::new(c));
|
||||||
|
let mut s_peer = PeerConn::new(s_peer_id, get_mock_global_ctx(), Box::new(s));
|
||||||
|
|
||||||
|
let (c_ret, s_ret) = tokio::join!(
|
||||||
|
c_peer.do_handshake_as_client(),
|
||||||
|
s_peer.do_handshake_as_server()
|
||||||
|
);
|
||||||
|
|
||||||
|
s_peer.set_close_event_sender(tokio::sync::mpsc::channel(1).0);
|
||||||
|
s_peer.start_recv_loop(tokio::sync::mpsc::channel(200).0);
|
||||||
|
|
||||||
|
assert!(c_ret.is_ok());
|
||||||
|
assert!(s_ret.is_ok());
|
||||||
|
|
||||||
|
let (close_send, mut close_recv) = tokio::sync::mpsc::channel(1);
|
||||||
|
c_peer.set_close_event_sender(close_send);
|
||||||
|
c_peer.start_pingpong();
|
||||||
|
c_peer.start_recv_loop(tokio::sync::mpsc::channel(200).0);
|
||||||
|
|
||||||
|
// wait 5s, conn should not be disconnected
|
||||||
|
tokio::time::sleep(Duration::from_secs(15)).await;
|
||||||
|
|
||||||
|
if conn_closed {
|
||||||
|
assert!(close_recv.try_recv().is_ok());
|
||||||
|
} else {
|
||||||
|
assert!(close_recv.try_recv().is_err());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn peer_conn_pingpong_timeout() {
|
||||||
|
peer_conn_pingpong_test_common(3, 5, false).await;
|
||||||
|
peer_conn_pingpong_test_common(5, 12, true).await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
use std::{
|
||||||
|
fmt::Debug,
|
||||||
|
pin::Pin,
|
||||||
|
sync::{
|
||||||
|
atomic::{AtomicU32, Ordering},
|
||||||
|
Arc,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
use futures::{SinkExt, StreamExt};
|
||||||
|
use pnet::datalink::NetworkInterface;
|
||||||
|
|
||||||
|
use tokio::{
|
||||||
|
sync::{broadcast, mpsc, Mutex},
|
||||||
|
task::JoinSet,
|
||||||
|
time::{timeout, Duration},
|
||||||
|
};
|
||||||
|
|
||||||
|
use tokio_util::{bytes::Bytes, sync::PollSender};
|
||||||
|
use tracing::Instrument;
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
common::{
|
||||||
|
error::Error,
|
||||||
|
global_ctx::{ArcGlobalCtx, NetworkIdentity},
|
||||||
|
PeerId,
|
||||||
|
},
|
||||||
|
define_tunnel_filter_chain,
|
||||||
|
peers::packet::{ArchivedPacketType, CtrlPacketPayload, PacketType},
|
||||||
|
rpc::{PeerConnInfo, PeerConnStats},
|
||||||
|
tunnel::{mpsc::MpscTunnelSender, stats::WindowLatency, TunnelError},
|
||||||
|
};
|
||||||
|
|
||||||
|
use super::packet::{self, HandShake, Packet};
|
||||||
|
|
||||||
|
pub type PacketRecvChan = mpsc::Sender<Bytes>;
|
||||||
|
|
||||||
|
macro_rules! wait_response {
|
||||||
|
($stream: ident, $out_var:ident, $pattern:pat_param => $value:expr) => {
|
||||||
|
let Ok(rsp_vec) = timeout(Duration::from_secs(1), $stream.next()).await else {
|
||||||
|
return Err(Error::WaitRespError(
|
||||||
|
"wait handshake response timeout".to_owned(),
|
||||||
|
));
|
||||||
|
};
|
||||||
|
let Some(rsp_vec) = rsp_vec else {
|
||||||
|
return Err(Error::WaitRespError(
|
||||||
|
"wait handshake response get none".to_owned(),
|
||||||
|
));
|
||||||
|
};
|
||||||
|
let Ok(rsp_vec) = rsp_vec else {
|
||||||
|
return Err(Error::WaitRespError(format!(
|
||||||
|
"wait handshake response get error {}",
|
||||||
|
rsp_vec.err().unwrap()
|
||||||
|
)));
|
||||||
|
};
|
||||||
|
|
||||||
|
let $out_var;
|
||||||
|
let rsp_bytes = Packet::decode(&rsp_vec);
|
||||||
|
if rsp_bytes.packet_type != PacketType::HandShake {
|
||||||
|
tracing::error!("unexpected packet type: {:?}", rsp_bytes);
|
||||||
|
return Err(Error::WaitRespError("unexpected packet type".to_owned()));
|
||||||
|
}
|
||||||
|
let resp_payload = CtrlPacketPayload::from_packet(&rsp_bytes);
|
||||||
|
match &resp_payload {
|
||||||
|
$pattern => $out_var = $value,
|
||||||
|
_ => {
|
||||||
|
tracing::error!(
|
||||||
|
"unexpected packet: {:?}, pattern: {:?}",
|
||||||
|
rsp_bytes,
|
||||||
|
stringify!($pattern)
|
||||||
|
);
|
||||||
|
return Err(Error::WaitRespError("unexpected packet".to_owned()));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a> From<&HandShake> for PeerInfo {
|
||||||
|
fn from(hs: &HandShake) -> Self {
|
||||||
|
PeerInfo {
|
||||||
|
magic: hs.magic.into(),
|
||||||
|
my_peer_id: hs.my_peer_id.into(),
|
||||||
|
version: hs.version.into(),
|
||||||
|
features: hs.features.iter().map(|x| x.to_string()).collect(),
|
||||||
|
interfaces: Vec::new(),
|
||||||
|
network_identity: hs.network_identity.clone(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
define_tunnel_filter_chain!(PeerConnTunnel, stats = StatsRecorderTunnelFilter);
|
||||||
|
|
||||||
|
pub struct PeerConn {
|
||||||
|
conn_id: PeerConnId,
|
||||||
|
|
||||||
|
my_peer_id: PeerId,
|
||||||
|
global_ctx: ArcGlobalCtx,
|
||||||
|
|
||||||
|
sink: Pin<Box<dyn DatagramSink>>,
|
||||||
|
tunnel: Box<dyn Tunnel>,
|
||||||
|
|
||||||
|
tasks: JoinSet<Result<(), TunnelError>>,
|
||||||
|
|
||||||
|
info: Option<PeerInfo>,
|
||||||
|
|
||||||
|
close_event_sender: Option<mpsc::Sender<PeerConnId>>,
|
||||||
|
|
||||||
|
ctrl_resp_sender: broadcast::Sender<Bytes>,
|
||||||
|
|
||||||
|
latency_stats: Arc<WindowLatency>,
|
||||||
|
throughput: Arc<Throughput>,
|
||||||
|
loss_rate_stats: Arc<AtomicU32>,
|
||||||
|
}
|
||||||
|
|
||||||
|
enum PeerConnPacketType {
|
||||||
|
Data(Bytes),
|
||||||
|
CtrlReq(Bytes),
|
||||||
|
CtrlResp(Bytes),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PeerConn {
|
||||||
|
pub fn new(my_peer_id: PeerId, global_ctx: ArcGlobalCtx, tunnel: Box<dyn Tunnel>) -> Self {
|
||||||
|
let (ctrl_sender, _ctrl_receiver) = broadcast::channel(100);
|
||||||
|
let peer_conn_tunnel = PeerConnTunnel::new();
|
||||||
|
let tunnel = peer_conn_tunnel.wrap_tunnel(tunnel);
|
||||||
|
|
||||||
|
PeerConn {
|
||||||
|
conn_id: PeerConnId::new_v4(),
|
||||||
|
|
||||||
|
my_peer_id,
|
||||||
|
global_ctx,
|
||||||
|
|
||||||
|
sink: tunnel.pin_sink(),
|
||||||
|
tunnel: Box::new(tunnel),
|
||||||
|
|
||||||
|
tasks: JoinSet::new(),
|
||||||
|
|
||||||
|
info: None,
|
||||||
|
close_event_sender: None,
|
||||||
|
|
||||||
|
ctrl_resp_sender: ctrl_sender,
|
||||||
|
|
||||||
|
latency_stats: Arc::new(WindowLatency::new(15)),
|
||||||
|
throughput: peer_conn_tunnel.stats.get_throughput().clone(),
|
||||||
|
loss_rate_stats: Arc::new(AtomicU32::new(0)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_conn_id(&self) -> PeerConnId {
|
||||||
|
self.conn_id
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tracing::instrument]
|
||||||
|
pub async fn do_handshake_as_server(&mut self) -> Result<(), TunnelError> {
|
||||||
|
let mut stream = self.tunnel.pin_stream();
|
||||||
|
let mut sink = self.tunnel.pin_sink();
|
||||||
|
|
||||||
|
tracing::info!("waiting for handshake request from client");
|
||||||
|
wait_response!(stream, hs_req, CtrlPacketPayload::HandShake(x) => x);
|
||||||
|
self.info = Some(PeerInfo::from(hs_req));
|
||||||
|
tracing::info!("handshake request: {:?}", hs_req);
|
||||||
|
|
||||||
|
let hs_req = self
|
||||||
|
.global_ctx
|
||||||
|
.net_ns
|
||||||
|
.run(|| packet::Packet::new_handshake(self.my_peer_id, &self.global_ctx.network));
|
||||||
|
sink.send(hs_req.into()).await?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tracing::instrument]
|
||||||
|
pub async fn do_handshake_as_client(&mut self) -> Result<(), TunnelError> {
|
||||||
|
let mut stream = self.tunnel.pin_stream();
|
||||||
|
let mut sink = self.tunnel.pin_sink();
|
||||||
|
|
||||||
|
let hs_req = self
|
||||||
|
.global_ctx
|
||||||
|
.net_ns
|
||||||
|
.run(|| packet::Packet::new_handshake(self.my_peer_id, &self.global_ctx.network));
|
||||||
|
sink.send(hs_req.into()).await?;
|
||||||
|
|
||||||
|
tracing::info!("waiting for handshake request from server");
|
||||||
|
wait_response!(stream, hs_rsp, CtrlPacketPayload::HandShake(x) => x);
|
||||||
|
self.info = Some(PeerInfo::from(hs_rsp));
|
||||||
|
tracing::info!("handshake response: {:?}", hs_rsp);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn handshake_done(&self) -> bool {
|
||||||
|
self.info.is_some()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn start_recv_loop(&mut self, packet_recv_chan: PacketRecvChan) {
|
||||||
|
let mut stream = self.tunnel.pin_stream();
|
||||||
|
let mut sink = self.tunnel.pin_sink();
|
||||||
|
let mut sender = PollSender::new(packet_recv_chan.clone());
|
||||||
|
let close_event_sender = self.close_event_sender.clone().unwrap();
|
||||||
|
let conn_id = self.conn_id;
|
||||||
|
let ctrl_sender = self.ctrl_resp_sender.clone();
|
||||||
|
let conn_info = self.get_conn_info();
|
||||||
|
let conn_info_for_instrument = self.get_conn_info();
|
||||||
|
|
||||||
|
self.tasks.spawn(
|
||||||
|
async move {
|
||||||
|
tracing::info!("start recving peer conn packet");
|
||||||
|
let mut task_ret = Ok(());
|
||||||
|
while let Some(ret) = stream.next().await {
|
||||||
|
if ret.is_err() {
|
||||||
|
tracing::error!(error = ?ret, "peer conn recv error");
|
||||||
|
task_ret = Err(ret.err().unwrap());
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
let buf = ret.unwrap();
|
||||||
|
let p = Packet::decode(&buf);
|
||||||
|
match p.packet_type {
|
||||||
|
ArchivedPacketType::Ping => {
|
||||||
|
let CtrlPacketPayload::Ping(seq) = CtrlPacketPayload::from_packet(p)
|
||||||
|
else {
|
||||||
|
log::error!("unexpected packet: {:?}", p);
|
||||||
|
continue;
|
||||||
|
};
|
||||||
|
|
||||||
|
let pong = packet::Packet::new_pong_packet(
|
||||||
|
conn_info.my_peer_id,
|
||||||
|
conn_info.peer_id,
|
||||||
|
seq.into(),
|
||||||
|
);
|
||||||
|
|
||||||
|
if let Err(e) = sink.send(pong.into()).await {
|
||||||
|
tracing::error!(?e, "peer conn send req error");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ArchivedPacketType::Pong => {
|
||||||
|
if let Err(e) = ctrl_sender.send(buf.into()) {
|
||||||
|
tracing::error!(?e, "peer conn send ctrl resp error");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
if sender.send(buf.into()).await.is_err() {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
tracing::info!("end recving peer conn packet");
|
||||||
|
|
||||||
|
if let Err(close_ret) = sink.close().await {
|
||||||
|
tracing::error!(error = ?close_ret, "peer conn sink close error, ignore it");
|
||||||
|
}
|
||||||
|
if let Err(e) = close_event_sender.send(conn_id).await {
|
||||||
|
tracing::error!(error = ?e, "peer conn close event send error");
|
||||||
|
}
|
||||||
|
|
||||||
|
task_ret
|
||||||
|
}
|
||||||
|
.instrument(
|
||||||
|
tracing::info_span!("peer conn recv loop", conn_info = ?conn_info_for_instrument),
|
||||||
|
),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn send_msg(&mut self, msg: Bytes) -> Result<(), Error> {
|
||||||
|
self.sink.send(msg).await
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_peer_id(&self) -> PeerId {
|
||||||
|
self.info.as_ref().unwrap().my_peer_id
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_network_identity(&self) -> NetworkIdentity {
|
||||||
|
self.info.as_ref().unwrap().network_identity.clone()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn set_close_event_sender(&mut self, sender: mpsc::Sender<PeerConnId>) {
|
||||||
|
self.close_event_sender = Some(sender);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_stats(&self) -> PeerConnStats {
|
||||||
|
PeerConnStats {
|
||||||
|
latency_us: self.latency_stats.get_latency_us(),
|
||||||
|
|
||||||
|
tx_bytes: self.throughput.tx_bytes(),
|
||||||
|
rx_bytes: self.throughput.rx_bytes(),
|
||||||
|
|
||||||
|
tx_packets: self.throughput.tx_packets(),
|
||||||
|
rx_packets: self.throughput.rx_packets(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_conn_info(&self) -> PeerConnInfo {
|
||||||
|
PeerConnInfo {
|
||||||
|
conn_id: self.conn_id.to_string(),
|
||||||
|
my_peer_id: self.my_peer_id,
|
||||||
|
peer_id: self.get_peer_id(),
|
||||||
|
features: self.info.as_ref().unwrap().features.clone(),
|
||||||
|
tunnel: self.tunnel.info(),
|
||||||
|
stats: Some(self.get_stats()),
|
||||||
|
loss_rate: (f64::from(self.loss_rate_stats.load(Ordering::Relaxed)) / 100.0) as f32,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Drop for PeerConn {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
let mut sink = self.tunnel.pin_sink();
|
||||||
|
tokio::spawn(async move {
|
||||||
|
let ret = sink.close().await;
|
||||||
|
tracing::info!(error = ?ret, "peer conn tunnel closed.");
|
||||||
|
});
|
||||||
|
log::info!("peer conn {:?} drop", self.conn_id);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
*/
|
||||||
@@ -116,7 +116,7 @@ pub fn add_ns_to_bridge(br_name: &str, ns_name: &str) {
|
|||||||
|
|
||||||
pub fn enable_log() {
|
pub fn enable_log() {
|
||||||
let filter = tracing_subscriber::EnvFilter::builder()
|
let filter = tracing_subscriber::EnvFilter::builder()
|
||||||
.with_default_directive(tracing::level_filters::LevelFilter::INFO.into())
|
.with_default_directive(tracing::level_filters::LevelFilter::TRACE.into())
|
||||||
.from_env()
|
.from_env()
|
||||||
.unwrap()
|
.unwrap()
|
||||||
.add_directive("tarpc=error".parse().unwrap());
|
.add_directive("tarpc=error".parse().unwrap());
|
||||||
|
|||||||
@@ -9,18 +9,18 @@ use super::*;
|
|||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
common::{
|
common::{
|
||||||
config::{ConfigLoader, NetworkIdentity, TomlConfigLoader},
|
config::{ConfigLoader, NetworkIdentity, TomlConfigLoader, VpnPortalConfig},
|
||||||
netns::{NetNS, ROOT_NETNS_NAME},
|
netns::{NetNS, ROOT_NETNS_NAME},
|
||||||
},
|
},
|
||||||
instance::instance::Instance,
|
instance::instance::Instance,
|
||||||
peers::tests::wait_for_condition,
|
peers::tests::wait_for_condition,
|
||||||
tunnels::{
|
tunnel::{
|
||||||
common::tests::_tunnel_pingpong_netns,
|
ring::RingTunnelConnector,
|
||||||
ring_tunnel::RingTunnelConnector,
|
tcp::TcpTunnelConnector,
|
||||||
tcp_tunnel::{TcpTunnelConnector, TcpTunnelListener},
|
udp::UdpTunnelConnector,
|
||||||
udp_tunnel::{UdpTunnelConnector, UdpTunnelListener},
|
|
||||||
wireguard::{WgConfig, WgTunnelConnector},
|
wireguard::{WgConfig, WgTunnelConnector},
|
||||||
},
|
},
|
||||||
|
vpn_portal::wireguard::get_wg_config_for_portal,
|
||||||
};
|
};
|
||||||
|
|
||||||
pub fn prepare_linux_namespaces() {
|
pub fn prepare_linux_namespaces() {
|
||||||
@@ -113,6 +113,26 @@ pub async fn init_three_node(proto: &str) -> Vec<Instance> {
|
|||||||
vec![inst1, inst2, inst3]
|
vec![inst1, inst2, inst3]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn ping_test(from_netns: &str, target_ip: &str) -> bool {
|
||||||
|
let _g = NetNS::new(Some(ROOT_NETNS_NAME.to_owned())).guard();
|
||||||
|
let code = tokio::process::Command::new("ip")
|
||||||
|
.args(&[
|
||||||
|
"netns",
|
||||||
|
"exec",
|
||||||
|
from_netns,
|
||||||
|
"ping",
|
||||||
|
"-c",
|
||||||
|
"1",
|
||||||
|
"-W",
|
||||||
|
"1",
|
||||||
|
target_ip.to_string().as_str(),
|
||||||
|
])
|
||||||
|
.status()
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
code.code().unwrap() == 0
|
||||||
|
}
|
||||||
|
|
||||||
#[rstest::rstest]
|
#[rstest::rstest]
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
#[serial_test::serial]
|
#[serial_test::serial]
|
||||||
@@ -130,12 +150,20 @@ pub async fn basic_three_node_test(#[values("tcp", "udp", "wg")] proto: &str) {
|
|||||||
insts[2].peer_id(),
|
insts[2].peer_id(),
|
||||||
insts[0].get_peer_manager().list_routes().await,
|
insts[0].get_peer_manager().list_routes().await,
|
||||||
);
|
);
|
||||||
|
|
||||||
|
wait_for_condition(
|
||||||
|
|| async { ping_test("net_c", "10.144.144.1").await },
|
||||||
|
Duration::from_secs(5),
|
||||||
|
)
|
||||||
|
.await;
|
||||||
}
|
}
|
||||||
|
|
||||||
#[rstest::rstest]
|
#[rstest::rstest]
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
#[serial_test::serial]
|
#[serial_test::serial]
|
||||||
pub async fn tcp_proxy_three_node_test(#[values("tcp", "udp", "wg")] proto: &str) {
|
pub async fn tcp_proxy_three_node_test(#[values("tcp", "udp", "wg")] proto: &str) {
|
||||||
|
use crate::tunnel::{common::tests::_tunnel_pingpong_netns, tcp::TcpTunnelListener};
|
||||||
|
|
||||||
let insts = init_three_node(proto).await;
|
let insts = init_three_node(proto).await;
|
||||||
|
|
||||||
insts[2]
|
insts[2]
|
||||||
@@ -187,25 +215,19 @@ pub async fn icmp_proxy_three_node_test(#[values("tcp", "udp", "wg")] proto: &st
|
|||||||
)
|
)
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
// wait updater
|
wait_for_condition(
|
||||||
tokio::time::sleep(tokio::time::Duration::from_secs(6)).await;
|
|| async { ping_test("net_a", "10.1.2.4").await },
|
||||||
|
Duration::from_secs(5),
|
||||||
// send ping with shell in net_a to net_d
|
)
|
||||||
let _g = NetNS::new(Some(ROOT_NETNS_NAME.to_owned())).guard();
|
.await;
|
||||||
let code = tokio::process::Command::new("ip")
|
|
||||||
.args(&[
|
|
||||||
"netns", "exec", "net_a", "ping", "-c", "1", "-W", "1", "10.1.2.4",
|
|
||||||
])
|
|
||||||
.status()
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
assert_eq!(code.code().unwrap(), 0);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[rstest::rstest]
|
#[rstest::rstest]
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
#[serial_test::serial]
|
#[serial_test::serial]
|
||||||
pub async fn proxy_three_node_disconnect_test(#[values("tcp", "wg")] proto: &str) {
|
pub async fn proxy_three_node_disconnect_test(#[values("tcp", "wg")] proto: &str) {
|
||||||
|
use crate::tunnel::wireguard::{WgConfig, WgTunnelConnector};
|
||||||
|
|
||||||
let insts = init_three_node(proto).await;
|
let insts = init_three_node(proto).await;
|
||||||
let mut inst4 = Instance::new(get_inst_config("inst4", Some("net_d"), "10.144.144.4"));
|
let mut inst4 = Instance::new(get_inst_config("inst4", Some("net_d"), "10.144.144.4"));
|
||||||
if proto == "tcp" {
|
if proto == "tcp" {
|
||||||
@@ -266,6 +288,8 @@ pub async fn proxy_three_node_disconnect_test(#[values("tcp", "wg")] proto: &str
|
|||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
#[serial_test::serial]
|
#[serial_test::serial]
|
||||||
pub async fn udp_proxy_three_node_test(#[values("tcp", "udp", "wg")] proto: &str) {
|
pub async fn udp_proxy_three_node_test(#[values("tcp", "udp", "wg")] proto: &str) {
|
||||||
|
use crate::tunnel::{common::tests::_tunnel_pingpong_netns, udp::UdpTunnelListener};
|
||||||
|
|
||||||
let insts = init_three_node(proto).await;
|
let insts = init_three_node(proto).await;
|
||||||
|
|
||||||
insts[2]
|
insts[2]
|
||||||
@@ -389,21 +413,108 @@ pub async fn foreign_network_forward_nic_data() {
|
|||||||
)
|
)
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
let _g = NetNS::new(Some(ROOT_NETNS_NAME.to_owned())).guard();
|
wait_for_condition(
|
||||||
let code = tokio::process::Command::new("ip")
|
|| async { ping_test("net_b", "10.144.145.2").await },
|
||||||
.args(&[
|
Duration::from_secs(5),
|
||||||
"netns",
|
)
|
||||||
"exec",
|
.await;
|
||||||
"net_b",
|
}
|
||||||
"ping",
|
|
||||||
"-c",
|
use std::{net::SocketAddr, str::FromStr};
|
||||||
"1",
|
|
||||||
"-W",
|
use defguard_wireguard_rs::{
|
||||||
"1",
|
host::Peer, key::Key, net::IpAddrMask, InterfaceConfiguration, WGApi, WireguardInterfaceApi,
|
||||||
"10.144.145.2",
|
};
|
||||||
])
|
|
||||||
.status()
|
fn run_wireguard_client(
|
||||||
.await
|
endpoint: SocketAddr,
|
||||||
.unwrap();
|
peer_public_key: Key,
|
||||||
assert_eq!(code.code().unwrap(), 0);
|
client_private_key: Key,
|
||||||
|
allowed_ips: Vec<String>,
|
||||||
|
client_ip: String,
|
||||||
|
) -> Result<(), Box<dyn std::error::Error>> {
|
||||||
|
// Create new API object for interface
|
||||||
|
let ifname: String = if cfg!(target_os = "linux") || cfg!(target_os = "freebsd") {
|
||||||
|
"wg0".into()
|
||||||
|
} else {
|
||||||
|
"utun3".into()
|
||||||
|
};
|
||||||
|
let wgapi = WGApi::new(ifname.clone(), false)?;
|
||||||
|
|
||||||
|
// create interface
|
||||||
|
wgapi.create_interface()?;
|
||||||
|
|
||||||
|
// Peer secret key
|
||||||
|
let mut peer = Peer::new(peer_public_key.clone());
|
||||||
|
|
||||||
|
log::info!("endpoint");
|
||||||
|
// Peer endpoint and interval
|
||||||
|
peer.endpoint = Some(endpoint);
|
||||||
|
peer.persistent_keepalive_interval = Some(25);
|
||||||
|
for ip in allowed_ips {
|
||||||
|
peer.allowed_ips.push(IpAddrMask::from_str(ip.as_str())?);
|
||||||
|
}
|
||||||
|
|
||||||
|
// interface configuration
|
||||||
|
let interface_config = InterfaceConfiguration {
|
||||||
|
name: ifname.clone(),
|
||||||
|
prvkey: client_private_key.to_string(),
|
||||||
|
address: client_ip,
|
||||||
|
port: 12345,
|
||||||
|
peers: vec![peer],
|
||||||
|
};
|
||||||
|
|
||||||
|
#[cfg(not(windows))]
|
||||||
|
wgapi.configure_interface(&interface_config)?;
|
||||||
|
#[cfg(windows)]
|
||||||
|
wgapi.configure_interface(&interface_config, &[])?;
|
||||||
|
wgapi.configure_peer_routing(&interface_config.peers)?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
#[serial_test::serial]
|
||||||
|
pub async fn wireguard_vpn_portal() {
|
||||||
|
let mut insts = init_three_node("tcp").await;
|
||||||
|
let net_ns = NetNS::new(Some("net_d".into()));
|
||||||
|
let _g = net_ns.guard();
|
||||||
|
insts[2]
|
||||||
|
.get_global_ctx()
|
||||||
|
.config
|
||||||
|
.set_vpn_portal_config(VpnPortalConfig {
|
||||||
|
wireguard_listen: "0.0.0.0:22121".parse().unwrap(),
|
||||||
|
client_cidr: "10.14.14.0/24".parse().unwrap(),
|
||||||
|
});
|
||||||
|
insts[2].run_vpn_portal().await.unwrap();
|
||||||
|
|
||||||
|
let net_ns = NetNS::new(Some("net_d".into()));
|
||||||
|
let _g = net_ns.guard();
|
||||||
|
let wg_cfg = get_wg_config_for_portal(&insts[2].get_global_ctx().get_network_identity());
|
||||||
|
run_wireguard_client(
|
||||||
|
"10.1.2.3:22121".parse().unwrap(),
|
||||||
|
Key::try_from(wg_cfg.my_public_key()).unwrap(),
|
||||||
|
Key::try_from(wg_cfg.peer_secret_key()).unwrap(),
|
||||||
|
vec!["10.14.14.0/24".to_string(), "10.144.144.0/24".to_string()],
|
||||||
|
"10.14.14.2".to_string(),
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
// ping other node in network
|
||||||
|
wait_for_condition(
|
||||||
|
|| async { ping_test("net_d", "10.144.144.1").await },
|
||||||
|
Duration::from_secs(5),
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
wait_for_condition(
|
||||||
|
|| async { ping_test("net_d", "10.144.144.2").await },
|
||||||
|
Duration::from_secs(5),
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
// ping portal node
|
||||||
|
wait_for_condition(
|
||||||
|
|| async { ping_test("net_d", "10.144.144.3").await },
|
||||||
|
Duration::from_secs(500),
|
||||||
|
)
|
||||||
|
.await;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,92 @@
|
|||||||
|
use std::collections::VecDeque;
|
||||||
|
use std::io::IoSlice;
|
||||||
|
|
||||||
|
use bytes::{Buf, BufMut, Bytes, BytesMut};
|
||||||
|
|
||||||
|
pub(crate) struct BufList<T> {
|
||||||
|
bufs: VecDeque<T>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: Buf> BufList<T> {
|
||||||
|
pub(crate) fn new() -> BufList<T> {
|
||||||
|
BufList {
|
||||||
|
bufs: VecDeque::new(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
pub(crate) fn push(&mut self, buf: T) {
|
||||||
|
debug_assert!(buf.has_remaining());
|
||||||
|
self.bufs.push_back(buf);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
pub(crate) fn bufs_cnt(&self) -> usize {
|
||||||
|
self.bufs.len()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: Buf> Buf for BufList<T> {
|
||||||
|
#[inline]
|
||||||
|
fn remaining(&self) -> usize {
|
||||||
|
self.bufs.iter().map(|buf| buf.remaining()).sum()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
fn chunk(&self) -> &[u8] {
|
||||||
|
self.bufs.front().map(Buf::chunk).unwrap_or_default()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
fn advance(&mut self, mut cnt: usize) {
|
||||||
|
while cnt > 0 {
|
||||||
|
{
|
||||||
|
let front = &mut self.bufs[0];
|
||||||
|
let rem = front.remaining();
|
||||||
|
if rem > cnt {
|
||||||
|
front.advance(cnt);
|
||||||
|
return;
|
||||||
|
} else {
|
||||||
|
front.advance(rem);
|
||||||
|
cnt -= rem;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
self.bufs.pop_front();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
fn chunks_vectored<'t>(&'t self, dst: &mut [IoSlice<'t>]) -> usize {
|
||||||
|
if dst.is_empty() {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
let mut vecs = 0;
|
||||||
|
for buf in &self.bufs {
|
||||||
|
vecs += buf.chunks_vectored(&mut dst[vecs..]);
|
||||||
|
if vecs == dst.len() {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
vecs
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
fn copy_to_bytes(&mut self, len: usize) -> Bytes {
|
||||||
|
// Our inner buffer may have an optimized version of copy_to_bytes, and if the whole
|
||||||
|
// request can be fulfilled by the front buffer, we can take advantage.
|
||||||
|
match self.bufs.front_mut() {
|
||||||
|
Some(front) if front.remaining() == len => {
|
||||||
|
let b = front.copy_to_bytes(len);
|
||||||
|
self.bufs.pop_front();
|
||||||
|
b
|
||||||
|
}
|
||||||
|
Some(front) if front.remaining() > len => front.copy_to_bytes(len),
|
||||||
|
_ => {
|
||||||
|
assert!(len <= self.remaining(), "`len` greater than remaining");
|
||||||
|
let mut bm = BytesMut::with_capacity(len);
|
||||||
|
bm.put(self.take(len));
|
||||||
|
bm.freeze()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,539 @@
|
|||||||
|
use std::{
|
||||||
|
any::Any,
|
||||||
|
net::{IpAddr, SocketAddr},
|
||||||
|
pin::Pin,
|
||||||
|
sync::{Arc, Mutex},
|
||||||
|
task::{ready, Poll},
|
||||||
|
};
|
||||||
|
|
||||||
|
use futures::{stream::FuturesUnordered, Future, Sink, Stream};
|
||||||
|
use network_interface::NetworkInterfaceConfig as _;
|
||||||
|
use pin_project_lite::pin_project;
|
||||||
|
use tokio::io::{AsyncRead, AsyncWrite};
|
||||||
|
|
||||||
|
use bytes::{Buf, Bytes, BytesMut};
|
||||||
|
use tokio_stream::StreamExt;
|
||||||
|
use tokio_util::io::{poll_read_buf, poll_write_buf};
|
||||||
|
use zerocopy::FromBytes as _;
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
rpc::TunnelInfo,
|
||||||
|
tunnel::packet_def::{ZCPacket, PEER_MANAGER_HEADER_SIZE},
|
||||||
|
};
|
||||||
|
|
||||||
|
use super::{
|
||||||
|
buf::BufList,
|
||||||
|
packet_def::{TCPTunnelHeader, ZCPacketType, TCP_TUNNEL_HEADER_SIZE},
|
||||||
|
SinkItem, StreamItem, Tunnel, TunnelError, ZCPacketSink, ZCPacketStream,
|
||||||
|
};
|
||||||
|
|
||||||
|
pub struct TunnelWrapper<R, W> {
|
||||||
|
reader: Arc<Mutex<Option<R>>>,
|
||||||
|
writer: Arc<Mutex<Option<W>>>,
|
||||||
|
info: Option<TunnelInfo>,
|
||||||
|
associate_data: Option<Box<dyn Any + Send + 'static>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<R, W> TunnelWrapper<R, W> {
|
||||||
|
pub fn new(reader: R, writer: W, info: Option<TunnelInfo>) -> Self {
|
||||||
|
Self::new_with_associate_data(reader, writer, info, None)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn new_with_associate_data(
|
||||||
|
reader: R,
|
||||||
|
writer: W,
|
||||||
|
info: Option<TunnelInfo>,
|
||||||
|
associate_data: Option<Box<dyn Any + Send + 'static>>,
|
||||||
|
) -> Self {
|
||||||
|
TunnelWrapper {
|
||||||
|
reader: Arc::new(Mutex::new(Some(reader))),
|
||||||
|
writer: Arc::new(Mutex::new(Some(writer))),
|
||||||
|
info,
|
||||||
|
associate_data,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<R, W> Tunnel for TunnelWrapper<R, W>
|
||||||
|
where
|
||||||
|
R: ZCPacketStream + Send + 'static,
|
||||||
|
W: ZCPacketSink + Send + 'static,
|
||||||
|
{
|
||||||
|
fn split(&self) -> (Pin<Box<dyn ZCPacketStream>>, Pin<Box<dyn ZCPacketSink>>) {
|
||||||
|
let reader = self.reader.lock().unwrap().take().unwrap();
|
||||||
|
let writer = self.writer.lock().unwrap().take().unwrap();
|
||||||
|
(Box::pin(reader), Box::pin(writer))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn info(&self) -> Option<TunnelInfo> {
|
||||||
|
self.info.clone()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// a length delimited codec for async reader
|
||||||
|
pin_project! {
|
||||||
|
pub struct FramedReader<R> {
|
||||||
|
#[pin]
|
||||||
|
reader: R,
|
||||||
|
buf: BytesMut,
|
||||||
|
state: FrameReaderState,
|
||||||
|
max_packet_size: usize,
|
||||||
|
associate_data: Option<Box<dyn Any + Send + 'static>>,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// usize means the size remaining to read
|
||||||
|
enum FrameReaderState {
|
||||||
|
ReadingHeader(usize),
|
||||||
|
ReadingBody(usize),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<R> FramedReader<R> {
|
||||||
|
pub fn new(reader: R, max_packet_size: usize) -> Self {
|
||||||
|
Self::new_with_associate_data(reader, max_packet_size, None)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn new_with_associate_data(
|
||||||
|
reader: R,
|
||||||
|
max_packet_size: usize,
|
||||||
|
associate_data: Option<Box<dyn Any + Send + 'static>>,
|
||||||
|
) -> Self {
|
||||||
|
FramedReader {
|
||||||
|
reader,
|
||||||
|
buf: BytesMut::with_capacity(max_packet_size),
|
||||||
|
state: FrameReaderState::ReadingHeader(4),
|
||||||
|
max_packet_size,
|
||||||
|
associate_data,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn extract_one_packet(buf: &mut BytesMut) -> Option<ZCPacket> {
|
||||||
|
if buf.len() < TCP_TUNNEL_HEADER_SIZE {
|
||||||
|
// header is not complete
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
let header = TCPTunnelHeader::ref_from_prefix(&buf[..]).unwrap();
|
||||||
|
let body_len = header.len.get() as usize;
|
||||||
|
if buf.len() < TCP_TUNNEL_HEADER_SIZE + body_len {
|
||||||
|
// body is not complete
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
// extract one packet
|
||||||
|
let packet_buf = buf.split_to(TCP_TUNNEL_HEADER_SIZE + body_len);
|
||||||
|
Some(ZCPacket::new_from_buf(packet_buf, ZCPacketType::TCP))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<R> Stream for FramedReader<R>
|
||||||
|
where
|
||||||
|
R: AsyncRead + Send + 'static + Unpin,
|
||||||
|
{
|
||||||
|
type Item = StreamItem;
|
||||||
|
|
||||||
|
fn poll_next(
|
||||||
|
self: Pin<&mut Self>,
|
||||||
|
cx: &mut std::task::Context<'_>,
|
||||||
|
) -> std::task::Poll<Option<Self::Item>> {
|
||||||
|
let mut self_mut = self.project();
|
||||||
|
|
||||||
|
loop {
|
||||||
|
while let Some(packet) = Self::extract_one_packet(self_mut.buf) {
|
||||||
|
return Poll::Ready(Some(Ok(packet)));
|
||||||
|
}
|
||||||
|
|
||||||
|
reserve_buf(
|
||||||
|
&mut self_mut.buf,
|
||||||
|
*self_mut.max_packet_size,
|
||||||
|
*self_mut.max_packet_size * 64,
|
||||||
|
);
|
||||||
|
|
||||||
|
match ready!(poll_read_buf(
|
||||||
|
self_mut.reader.as_mut(),
|
||||||
|
cx,
|
||||||
|
&mut self_mut.buf
|
||||||
|
)) {
|
||||||
|
Ok(size) => {
|
||||||
|
if size == 0 {
|
||||||
|
return Poll::Ready(None);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
return Poll::Ready(Some(Err(TunnelError::IOError(e))));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pin_project! {
|
||||||
|
pub struct FramedWriter<W> {
|
||||||
|
#[pin]
|
||||||
|
writer: W,
|
||||||
|
sending_bufs: BufList<Bytes>,
|
||||||
|
associate_data: Option<Box<dyn Any + Send + 'static>>,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<W> FramedWriter<W> {
|
||||||
|
pub fn new(writer: W) -> Self {
|
||||||
|
Self::new_with_associate_data(writer, None)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn new_with_associate_data(
|
||||||
|
writer: W,
|
||||||
|
associate_data: Option<Box<dyn Any + Send + 'static>>,
|
||||||
|
) -> Self {
|
||||||
|
FramedWriter {
|
||||||
|
writer,
|
||||||
|
sending_bufs: BufList::new(),
|
||||||
|
associate_data: associate_data,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn max_buffer_count(&self) -> usize {
|
||||||
|
64
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<W> Sink<SinkItem> for FramedWriter<W>
|
||||||
|
where
|
||||||
|
W: AsyncWrite + Send + 'static,
|
||||||
|
{
|
||||||
|
type Error = TunnelError;
|
||||||
|
|
||||||
|
fn poll_ready(
|
||||||
|
mut self: Pin<&mut Self>,
|
||||||
|
cx: &mut std::task::Context<'_>,
|
||||||
|
) -> std::task::Poll<Result<(), Self::Error>> {
|
||||||
|
let max_buffer_count = self.max_buffer_count();
|
||||||
|
if self.sending_bufs.bufs_cnt() >= max_buffer_count {
|
||||||
|
self.as_mut().poll_flush(cx)
|
||||||
|
} else {
|
||||||
|
tracing::trace!(bufs_cnt = self.sending_bufs.bufs_cnt(), "ready to send");
|
||||||
|
Poll::Ready(Ok(()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn start_send(self: Pin<&mut Self>, mut item: ZCPacket) -> Result<(), Self::Error> {
|
||||||
|
let tcp_len = PEER_MANAGER_HEADER_SIZE + item.payload_len();
|
||||||
|
let Some(header) = item.mut_tcp_tunnel_header() else {
|
||||||
|
return Err(TunnelError::InvalidPacket("packet too short".to_string()));
|
||||||
|
};
|
||||||
|
header.len.set(tcp_len.try_into().unwrap());
|
||||||
|
|
||||||
|
let item = item.into_bytes(ZCPacketType::TCP);
|
||||||
|
self.project().sending_bufs.push(item);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_flush(
|
||||||
|
self: Pin<&mut Self>,
|
||||||
|
cx: &mut std::task::Context<'_>,
|
||||||
|
) -> Poll<Result<(), Self::Error>> {
|
||||||
|
let mut pinned = self.project();
|
||||||
|
let mut remaining = pinned.sending_bufs.remaining();
|
||||||
|
while remaining != 0 {
|
||||||
|
let n = ready!(poll_write_buf(
|
||||||
|
pinned.writer.as_mut(),
|
||||||
|
cx,
|
||||||
|
pinned.sending_bufs
|
||||||
|
))?;
|
||||||
|
if n == 0 {
|
||||||
|
return Poll::Ready(Err(TunnelError::IOError(std::io::Error::new(
|
||||||
|
std::io::ErrorKind::WriteZero,
|
||||||
|
"failed to \
|
||||||
|
write frame to transport",
|
||||||
|
))));
|
||||||
|
}
|
||||||
|
remaining -= n;
|
||||||
|
}
|
||||||
|
|
||||||
|
tracing::trace!(?remaining, "flushed");
|
||||||
|
|
||||||
|
// Try flushing the underlying IO
|
||||||
|
ready!(pinned.writer.poll_flush(cx))?;
|
||||||
|
|
||||||
|
Poll::Ready(Ok(()))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_close(
|
||||||
|
mut self: Pin<&mut Self>,
|
||||||
|
cx: &mut std::task::Context<'_>,
|
||||||
|
) -> Poll<Result<(), Self::Error>> {
|
||||||
|
ready!(self.as_mut().poll_flush(cx))?;
|
||||||
|
ready!(self.project().writer.poll_shutdown(cx))?;
|
||||||
|
|
||||||
|
Poll::Ready(Ok(()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn get_interface_name_by_ip(local_ip: &IpAddr) -> Option<String> {
|
||||||
|
if local_ip.is_unspecified() || local_ip.is_multicast() {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
let ifaces = network_interface::NetworkInterface::show().ok()?;
|
||||||
|
for iface in ifaces {
|
||||||
|
for addr in iface.addr {
|
||||||
|
if addr.ip() == *local_ip {
|
||||||
|
return Some(iface.name);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
tracing::error!(?local_ip, "can not find interface name by ip");
|
||||||
|
None
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn setup_sokcet2_ext(
|
||||||
|
socket2_socket: &socket2::Socket,
|
||||||
|
bind_addr: &SocketAddr,
|
||||||
|
bind_dev: Option<String>,
|
||||||
|
) -> Result<(), TunnelError> {
|
||||||
|
#[cfg(target_os = "windows")]
|
||||||
|
{
|
||||||
|
let is_udp = matches!(socket2_socket.r#type()?, socket2::Type::DGRAM);
|
||||||
|
crate::arch::windows::setup_socket_for_win(socket2_socket, bind_addr, bind_dev, is_udp)?;
|
||||||
|
}
|
||||||
|
|
||||||
|
socket2_socket.set_nonblocking(true)?;
|
||||||
|
socket2_socket.set_reuse_address(true)?;
|
||||||
|
socket2_socket.bind(&socket2::SockAddr::from(*bind_addr))?;
|
||||||
|
|
||||||
|
// #[cfg(all(unix, not(target_os = "solaris"), not(target_os = "illumos")))]
|
||||||
|
// socket2_socket.set_reuse_port(true)?;
|
||||||
|
|
||||||
|
if bind_addr.ip().is_unspecified() {
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
// linux/mac does not use interface of bind_addr to send packet, so we need to bind device
|
||||||
|
// win can handle this with bind correctly
|
||||||
|
#[cfg(any(target_os = "ios", target_os = "macos"))]
|
||||||
|
if let Some(dev_name) = bind_dev {
|
||||||
|
// use IP_BOUND_IF to bind device
|
||||||
|
unsafe {
|
||||||
|
let dev_idx = nix::libc::if_nametoindex(dev_name.as_str().as_ptr() as *const i8);
|
||||||
|
tracing::warn!(?dev_idx, ?dev_name, "bind device");
|
||||||
|
socket2_socket.bind_device_by_index_v4(std::num::NonZeroU32::new(dev_idx))?;
|
||||||
|
tracing::warn!(?dev_idx, ?dev_name, "bind device doen");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
|
||||||
|
if let Some(dev_name) = bind_dev {
|
||||||
|
tracing::trace!(dev_name = ?dev_name, "bind device");
|
||||||
|
socket2_socket.bind_device(Some(dev_name.as_bytes()))?;
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) async fn wait_for_connect_futures<Fut, Ret, E>(
|
||||||
|
mut futures: FuturesUnordered<Fut>,
|
||||||
|
) -> Result<Ret, TunnelError>
|
||||||
|
where
|
||||||
|
Fut: Future<Output = Result<Ret, E>> + Send + Sync,
|
||||||
|
E: std::error::Error + Into<TunnelError> + Send + Sync + 'static,
|
||||||
|
{
|
||||||
|
// return last error
|
||||||
|
let mut last_err = None;
|
||||||
|
|
||||||
|
while let Some(ret) = futures.next().await {
|
||||||
|
if let Err(e) = ret {
|
||||||
|
last_err = Some(e.into());
|
||||||
|
} else {
|
||||||
|
return ret.map_err(|e| e.into());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Err(last_err.unwrap_or(TunnelError::Shutdown))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn setup_sokcet2(
|
||||||
|
socket2_socket: &socket2::Socket,
|
||||||
|
bind_addr: &SocketAddr,
|
||||||
|
) -> Result<(), TunnelError> {
|
||||||
|
setup_sokcet2_ext(
|
||||||
|
socket2_socket,
|
||||||
|
bind_addr,
|
||||||
|
super::common::get_interface_name_by_ip(&bind_addr.ip()),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn reserve_buf(buf: &mut BytesMut, min_size: usize, max_size: usize) {
|
||||||
|
if buf.capacity() < min_size {
|
||||||
|
buf.reserve(max_size);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub mod tests {
|
||||||
|
use std::time::Instant;
|
||||||
|
|
||||||
|
use futures::{SinkExt, StreamExt, TryStreamExt};
|
||||||
|
use tokio_util::bytes::{BufMut, Bytes, BytesMut};
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
common::netns::NetNS,
|
||||||
|
tunnel::{packet_def::ZCPacket, TunnelConnector, TunnelListener},
|
||||||
|
};
|
||||||
|
|
||||||
|
pub async fn _tunnel_echo_server(tunnel: Box<dyn super::Tunnel>, once: bool) {
|
||||||
|
let (mut recv, mut send) = tunnel.split();
|
||||||
|
|
||||||
|
if !once {
|
||||||
|
recv.forward(send).await.unwrap();
|
||||||
|
} else {
|
||||||
|
let Some(ret) = recv.next().await else {
|
||||||
|
assert!(false, "recv error");
|
||||||
|
return;
|
||||||
|
};
|
||||||
|
|
||||||
|
if ret.is_err() {
|
||||||
|
tracing::debug!(?ret, "recv error");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
let res = ret.unwrap();
|
||||||
|
tracing::debug!(?res, "recv a msg, try echo back");
|
||||||
|
send.send(res).await.unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
tracing::warn!("echo server exit...");
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) async fn _tunnel_pingpong<L, C>(listener: L, connector: C)
|
||||||
|
where
|
||||||
|
L: TunnelListener + Send + Sync + 'static,
|
||||||
|
C: TunnelConnector + Send + Sync + 'static,
|
||||||
|
{
|
||||||
|
_tunnel_pingpong_netns(listener, connector, NetNS::new(None), NetNS::new(None)).await
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) async fn _tunnel_pingpong_netns<L, C>(
|
||||||
|
mut listener: L,
|
||||||
|
mut connector: C,
|
||||||
|
l_netns: NetNS,
|
||||||
|
c_netns: NetNS,
|
||||||
|
) where
|
||||||
|
L: TunnelListener + Send + Sync + 'static,
|
||||||
|
C: TunnelConnector + Send + Sync + 'static,
|
||||||
|
{
|
||||||
|
l_netns
|
||||||
|
.run_async(|| async {
|
||||||
|
listener.listen().await.unwrap();
|
||||||
|
})
|
||||||
|
.await;
|
||||||
|
|
||||||
|
let lis = tokio::spawn(async move {
|
||||||
|
let ret = listener.accept().await.unwrap();
|
||||||
|
assert_eq!(
|
||||||
|
ret.info().unwrap().local_addr,
|
||||||
|
listener.local_url().to_string()
|
||||||
|
);
|
||||||
|
_tunnel_echo_server(ret, false).await
|
||||||
|
});
|
||||||
|
|
||||||
|
let tunnel = c_netns.run_async(|| connector.connect()).await.unwrap();
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
tunnel.info().unwrap().remote_addr,
|
||||||
|
connector.remote_url().to_string()
|
||||||
|
);
|
||||||
|
|
||||||
|
let (mut recv, mut send) = tunnel.split();
|
||||||
|
|
||||||
|
send.send(ZCPacket::new_with_payload("12345678abcdefg".as_bytes()))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let ret = tokio::time::timeout(tokio::time::Duration::from_secs(1), recv.next())
|
||||||
|
.await
|
||||||
|
.unwrap()
|
||||||
|
.unwrap()
|
||||||
|
.unwrap();
|
||||||
|
println!("echo back: {:?}", ret);
|
||||||
|
assert_eq!(ret.payload(), Bytes::from("12345678abcdefg"));
|
||||||
|
|
||||||
|
drop(send);
|
||||||
|
|
||||||
|
if ["udp", "wg"].contains(&connector.remote_url().scheme()) {
|
||||||
|
lis.abort();
|
||||||
|
} else {
|
||||||
|
// lis should finish in 1 second
|
||||||
|
let ret = tokio::time::timeout(tokio::time::Duration::from_secs(1), lis).await;
|
||||||
|
assert!(ret.is_ok());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) async fn _tunnel_bench<L, C>(mut listener: L, mut connector: C)
|
||||||
|
where
|
||||||
|
L: TunnelListener + Send + Sync + 'static,
|
||||||
|
C: TunnelConnector + Send + Sync + 'static,
|
||||||
|
{
|
||||||
|
listener.listen().await.unwrap();
|
||||||
|
|
||||||
|
let lis = tokio::spawn(async move {
|
||||||
|
let ret = listener.accept().await.unwrap();
|
||||||
|
_tunnel_echo_server(ret, false).await
|
||||||
|
});
|
||||||
|
|
||||||
|
let tunnel = connector.connect().await.unwrap();
|
||||||
|
|
||||||
|
let (recv, mut send) = tunnel.split();
|
||||||
|
|
||||||
|
// prepare a 4k buffer with random data
|
||||||
|
let mut send_buf = BytesMut::new();
|
||||||
|
for _ in 0..64 {
|
||||||
|
send_buf.put_i128(rand::random::<i128>());
|
||||||
|
}
|
||||||
|
|
||||||
|
let r = tokio::spawn(async move {
|
||||||
|
let now = Instant::now();
|
||||||
|
let count = recv
|
||||||
|
.try_fold(0usize, |mut ret, _| async move {
|
||||||
|
ret += 1;
|
||||||
|
Ok(ret)
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
println!(
|
||||||
|
"bps: {}",
|
||||||
|
(count / 1024) * 4 / now.elapsed().as_secs() as usize
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
let now = Instant::now();
|
||||||
|
while now.elapsed().as_secs() < 10 {
|
||||||
|
// send.feed(item)
|
||||||
|
let item = ZCPacket::new_with_payload(send_buf.as_ref());
|
||||||
|
let _ = send.feed(item).await.unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
drop(send);
|
||||||
|
drop(connector);
|
||||||
|
drop(tunnel);
|
||||||
|
|
||||||
|
tracing::warn!("wait for recv to finish...");
|
||||||
|
|
||||||
|
let _ = tokio::join!(r);
|
||||||
|
|
||||||
|
lis.abort();
|
||||||
|
let _ = tokio::join!(lis);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn enable_log() {
|
||||||
|
let filter = tracing_subscriber::EnvFilter::builder()
|
||||||
|
.with_default_directive(tracing::level_filters::LevelFilter::TRACE.into())
|
||||||
|
.from_env()
|
||||||
|
.unwrap()
|
||||||
|
.add_directive("tarpc=error".parse().unwrap());
|
||||||
|
tracing_subscriber::fmt::fmt()
|
||||||
|
.pretty()
|
||||||
|
.with_env_filter(filter)
|
||||||
|
.init();
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,362 @@
|
|||||||
|
use std::{
|
||||||
|
sync::Arc,
|
||||||
|
task::{Context, Poll},
|
||||||
|
};
|
||||||
|
|
||||||
|
use crate::rpc::TunnelInfo;
|
||||||
|
use auto_impl::auto_impl;
|
||||||
|
use futures::{Sink, SinkExt, Stream, StreamExt};
|
||||||
|
|
||||||
|
use self::stats::Throughput;
|
||||||
|
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[auto_impl(Arc, Box)]
|
||||||
|
pub trait TunnelFilter: Send + Sync {
|
||||||
|
type FilterOutput;
|
||||||
|
|
||||||
|
fn before_send(&self, data: SinkItem) -> Option<SinkItem> {
|
||||||
|
Some(data)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn after_received(&self, data: StreamItem) -> Option<StreamItem> {
|
||||||
|
match data {
|
||||||
|
Ok(v) => Some(Ok(v)),
|
||||||
|
Err(e) => Some(Err(e)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn filter_output(&self) -> Self::FilterOutput;
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct TunnelFilterChain<A, B> {
|
||||||
|
a: A,
|
||||||
|
b: B,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<A, B, OA, OB> TunnelFilter for TunnelFilterChain<A, B>
|
||||||
|
where
|
||||||
|
A: TunnelFilter<FilterOutput = OA>,
|
||||||
|
B: TunnelFilter<FilterOutput = OB>,
|
||||||
|
{
|
||||||
|
type FilterOutput = (OA, OB);
|
||||||
|
fn before_send(&self, data: SinkItem) -> Option<SinkItem> {
|
||||||
|
let data = self.a.before_send(data)?;
|
||||||
|
self.b.before_send(data)
|
||||||
|
}
|
||||||
|
fn after_received(&self, data: StreamItem) -> Option<StreamItem> {
|
||||||
|
let data = self.b.after_received(data)?;
|
||||||
|
self.a.after_received(data)
|
||||||
|
}
|
||||||
|
fn filter_output(&self) -> Self::FilterOutput {
|
||||||
|
(self.a.filter_output(), self.b.filter_output())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<A, B> TunnelFilterChain<A, B> {
|
||||||
|
pub fn new(a: A, b: B) -> Self {
|
||||||
|
Self { a, b }
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn chain<T: TunnelFilter>(self, c: T) -> TunnelFilterChain<Self, T> {
|
||||||
|
TunnelFilterChain::new(self, c)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct EmptyFilter;
|
||||||
|
impl TunnelFilter for EmptyFilter {
|
||||||
|
type FilterOutput = ();
|
||||||
|
fn filter_output(&self) {}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub trait ToTunnelChain {
|
||||||
|
fn to_chain(self) -> TunnelFilterChain<EmptyFilter, Self>
|
||||||
|
where
|
||||||
|
Self: Sized,
|
||||||
|
{
|
||||||
|
TunnelFilterChain::new(EmptyFilter, self)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<O, T: TunnelFilter<FilterOutput = O>> ToTunnelChain for T {}
|
||||||
|
|
||||||
|
pub struct TunnelWithFilter<T, F> {
|
||||||
|
inner: T,
|
||||||
|
filter: Arc<F>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T, F> TunnelWithFilter<T, F>
|
||||||
|
where
|
||||||
|
T: Tunnel + Send + 'static,
|
||||||
|
F: TunnelFilter + Send + 'static,
|
||||||
|
{
|
||||||
|
pub fn new(inner: T, filter: F) -> Self {
|
||||||
|
Self {
|
||||||
|
inner,
|
||||||
|
filter: Arc::new(filter),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn wrap_sink<S: ZCPacketSink + Unpin + 'static>(&self, sink: S) -> impl ZCPacketSink {
|
||||||
|
struct SinkWrapper<F, S> {
|
||||||
|
sink: S,
|
||||||
|
filter: Arc<F>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<F, S> Sink<ZCPacket> for SinkWrapper<F, S>
|
||||||
|
where
|
||||||
|
F: TunnelFilter + 'static,
|
||||||
|
S: ZCPacketSink + 'static + Unpin,
|
||||||
|
{
|
||||||
|
type Error = SinkError;
|
||||||
|
|
||||||
|
fn poll_ready(
|
||||||
|
self: std::pin::Pin<&mut Self>,
|
||||||
|
cx: &mut Context<'_>,
|
||||||
|
) -> Poll<Result<(), Self::Error>> {
|
||||||
|
self.get_mut().sink.poll_ready_unpin(cx)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn start_send(
|
||||||
|
self: std::pin::Pin<&mut Self>,
|
||||||
|
item: ZCPacket,
|
||||||
|
) -> Result<(), Self::Error> {
|
||||||
|
let Some(item) = self.filter.before_send(item) else {
|
||||||
|
return Ok(());
|
||||||
|
};
|
||||||
|
self.get_mut().sink.start_send_unpin(item)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_flush(
|
||||||
|
self: std::pin::Pin<&mut Self>,
|
||||||
|
cx: &mut Context<'_>,
|
||||||
|
) -> Poll<Result<(), Self::Error>> {
|
||||||
|
self.get_mut().sink.poll_flush_unpin(cx)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_close(
|
||||||
|
self: std::pin::Pin<&mut Self>,
|
||||||
|
cx: &mut Context<'_>,
|
||||||
|
) -> Poll<Result<(), Self::Error>> {
|
||||||
|
self.get_mut().sink.poll_close_unpin(cx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
SinkWrapper {
|
||||||
|
sink,
|
||||||
|
filter: self.filter.clone(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn wrap_stream<S: ZCPacketStream + Unpin + 'static>(&self, stream: S) -> impl ZCPacketStream {
|
||||||
|
struct StreamWrapper<F, S> {
|
||||||
|
stream: S,
|
||||||
|
filter: Arc<F>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<F, S> Stream for StreamWrapper<F, S>
|
||||||
|
where
|
||||||
|
F: TunnelFilter + 'static,
|
||||||
|
S: ZCPacketStream + 'static + Unpin,
|
||||||
|
{
|
||||||
|
type Item = StreamItem;
|
||||||
|
|
||||||
|
fn poll_next(
|
||||||
|
self: std::pin::Pin<&mut Self>,
|
||||||
|
cx: &mut Context<'_>,
|
||||||
|
) -> Poll<Option<Self::Item>> {
|
||||||
|
let self_mut = self.get_mut();
|
||||||
|
loop {
|
||||||
|
match self_mut.stream.poll_next_unpin(cx) {
|
||||||
|
Poll::Ready(Some(ret)) => {
|
||||||
|
let Some(ret) = self_mut.filter.after_received(ret) else {
|
||||||
|
continue;
|
||||||
|
};
|
||||||
|
return Poll::Ready(Some(ret));
|
||||||
|
}
|
||||||
|
Poll::Ready(None) => {
|
||||||
|
return Poll::Ready(None);
|
||||||
|
}
|
||||||
|
Poll::Pending => {
|
||||||
|
return Poll::Pending;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
StreamWrapper {
|
||||||
|
stream,
|
||||||
|
filter: self.filter.clone(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T, F> Tunnel for TunnelWithFilter<T, F>
|
||||||
|
where
|
||||||
|
T: Tunnel + Send + 'static,
|
||||||
|
F: TunnelFilter + Send + 'static,
|
||||||
|
{
|
||||||
|
fn info(&self) -> Option<TunnelInfo> {
|
||||||
|
self.inner.info()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn split(&self) -> (Pin<Box<dyn ZCPacketStream>>, Pin<Box<dyn ZCPacketSink>>) {
|
||||||
|
let (stream, sink) = self.inner.split();
|
||||||
|
(
|
||||||
|
Box::pin(self.wrap_stream(stream)),
|
||||||
|
Box::pin(self.wrap_sink(sink)),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct PacketRecorderTunnelFilter {
|
||||||
|
pub received: Arc<std::sync::Mutex<Vec<ZCPacket>>>,
|
||||||
|
pub sent: Arc<std::sync::Mutex<Vec<ZCPacket>>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TunnelFilter for PacketRecorderTunnelFilter {
|
||||||
|
type FilterOutput = (Vec<ZCPacket>, Vec<ZCPacket>);
|
||||||
|
|
||||||
|
fn before_send(&self, data: SinkItem) -> Option<SinkItem> {
|
||||||
|
self.received.lock().unwrap().push(data.clone());
|
||||||
|
Some(data)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn after_received(&self, data: StreamItem) -> Option<StreamItem> {
|
||||||
|
match data {
|
||||||
|
Ok(v) => {
|
||||||
|
self.sent.lock().unwrap().push(v.clone().into());
|
||||||
|
Some(Ok(v))
|
||||||
|
}
|
||||||
|
Err(e) => Some(Err(e)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn filter_output(&self) -> Self::FilterOutput {
|
||||||
|
(
|
||||||
|
self.received.lock().unwrap().clone(),
|
||||||
|
self.sent.lock().unwrap().clone(),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PacketRecorderTunnelFilter {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
received: Arc::new(std::sync::Mutex::new(Vec::new())),
|
||||||
|
sent: Arc::new(std::sync::Mutex::new(Vec::new())),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct StatsRecorderTunnelFilter {
|
||||||
|
throughput: Arc<Throughput>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TunnelFilter for StatsRecorderTunnelFilter {
|
||||||
|
type FilterOutput = Arc<Throughput>;
|
||||||
|
|
||||||
|
fn before_send(&self, data: SinkItem) -> Option<SinkItem> {
|
||||||
|
self.throughput.record_tx_bytes(data.buf_len() as u64);
|
||||||
|
Some(data)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn after_received(&self, data: StreamItem) -> Option<StreamItem> {
|
||||||
|
match data {
|
||||||
|
Ok(v) => {
|
||||||
|
self.throughput.record_rx_bytes(v.buf_len() as u64);
|
||||||
|
Some(Ok(v))
|
||||||
|
}
|
||||||
|
Err(e) => Some(Err(e)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn filter_output(&self) -> Self::FilterOutput {
|
||||||
|
self.throughput.clone()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl StatsRecorderTunnelFilter {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
throughput: Arc::new(Throughput::new()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_throughput(&self) -> Arc<Throughput> {
|
||||||
|
self.throughput.clone()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
pub mod tests {
|
||||||
|
use std::sync::atomic::{AtomicU32, Ordering};
|
||||||
|
|
||||||
|
use filter::ring::create_ring_tunnel_pair;
|
||||||
|
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
pub struct DropSendTunnelFilter {
|
||||||
|
start: AtomicU32,
|
||||||
|
end: AtomicU32,
|
||||||
|
cur: AtomicU32,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TunnelFilter for DropSendTunnelFilter {
|
||||||
|
type FilterOutput = ();
|
||||||
|
|
||||||
|
fn before_send(&self, data: SinkItem) -> Option<SinkItem> {
|
||||||
|
self.cur.fetch_add(1, Ordering::SeqCst);
|
||||||
|
if self.cur.load(Ordering::SeqCst) >= self.start.load(Ordering::SeqCst)
|
||||||
|
&& self.cur.load(std::sync::atomic::Ordering::SeqCst)
|
||||||
|
< self.end.load(Ordering::SeqCst)
|
||||||
|
{
|
||||||
|
tracing::trace!("drop packet: {:?}", data);
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
Some(data)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn filter_output(&self) {}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl DropSendTunnelFilter {
|
||||||
|
pub fn new(start: u32, end: u32) -> Self {
|
||||||
|
Self {
|
||||||
|
start: AtomicU32::new(start),
|
||||||
|
end: AtomicU32::new(end),
|
||||||
|
cur: AtomicU32::new(0),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_nested_filter() {
|
||||||
|
let filter = Arc::new(
|
||||||
|
PacketRecorderTunnelFilter::new()
|
||||||
|
.to_chain()
|
||||||
|
.chain(PacketRecorderTunnelFilter::new())
|
||||||
|
.chain(PacketRecorderTunnelFilter::new())
|
||||||
|
.chain(PacketRecorderTunnelFilter::new()),
|
||||||
|
);
|
||||||
|
let (s, _b) = create_ring_tunnel_pair();
|
||||||
|
let tunnel = TunnelWithFilter::new(s, filter.clone());
|
||||||
|
|
||||||
|
let (_r, mut s) = tunnel.split();
|
||||||
|
s.send(ZCPacket::new_with_payload("ab".as_bytes()))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let out = filter.filter_output();
|
||||||
|
|
||||||
|
let a = out.0 .0 .0 .1;
|
||||||
|
let b = out.0 .0 .1;
|
||||||
|
let c = out.0 .1;
|
||||||
|
let _d = out.1;
|
||||||
|
|
||||||
|
assert_eq!(1, a.0.len());
|
||||||
|
assert_eq!(1, b.0.len());
|
||||||
|
assert_eq!(1, c.0.len());
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,196 @@
|
|||||||
|
use std::{net::SocketAddr, pin::Pin, sync::Arc};
|
||||||
|
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use futures::{Sink, Stream};
|
||||||
|
use std::fmt::Debug;
|
||||||
|
|
||||||
|
use tokio::time::error::Elapsed;
|
||||||
|
|
||||||
|
use crate::rpc::TunnelInfo;
|
||||||
|
|
||||||
|
use self::packet_def::ZCPacket;
|
||||||
|
|
||||||
|
pub mod buf;
|
||||||
|
pub mod common;
|
||||||
|
pub mod filter;
|
||||||
|
pub mod mpsc;
|
||||||
|
pub mod packet_def;
|
||||||
|
pub mod quic;
|
||||||
|
pub mod ring;
|
||||||
|
pub mod stats;
|
||||||
|
pub mod tcp;
|
||||||
|
pub mod udp;
|
||||||
|
pub mod wireguard;
|
||||||
|
|
||||||
|
#[derive(thiserror::Error, Debug)]
|
||||||
|
pub enum TunnelError {
|
||||||
|
#[error("io error")]
|
||||||
|
IOError(#[from] std::io::Error),
|
||||||
|
#[error("invalid packet. msg: {0}")]
|
||||||
|
InvalidPacket(String),
|
||||||
|
#[error("exceed max packet size. max: {0}, input: {1}")]
|
||||||
|
ExceedMaxPacketSize(usize, usize),
|
||||||
|
|
||||||
|
#[error("invalid protocol: {0}")]
|
||||||
|
InvalidProtocol(String),
|
||||||
|
#[error("invalid addr: {0}")]
|
||||||
|
InvalidAddr(String),
|
||||||
|
|
||||||
|
#[error("internal error {0}")]
|
||||||
|
InternalError(String),
|
||||||
|
|
||||||
|
#[error("conn id not match, expect: {0}, actual: {1}")]
|
||||||
|
ConnIdNotMatch(u32, u32),
|
||||||
|
#[error("buffer full")]
|
||||||
|
BufferFull,
|
||||||
|
|
||||||
|
#[error("timeout")]
|
||||||
|
Timeout(#[from] Elapsed),
|
||||||
|
|
||||||
|
#[error("anyhow error: {0}")]
|
||||||
|
Anyhow(#[from] anyhow::Error),
|
||||||
|
|
||||||
|
#[error("shutdown")]
|
||||||
|
Shutdown,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub type StreamT = packet_def::ZCPacket;
|
||||||
|
pub type StreamItem = Result<StreamT, TunnelError>;
|
||||||
|
pub type SinkItem = packet_def::ZCPacket;
|
||||||
|
pub type SinkError = TunnelError;
|
||||||
|
|
||||||
|
pub trait ZCPacketStream: Stream<Item = StreamItem> + Send {}
|
||||||
|
impl<T> ZCPacketStream for T where T: Stream<Item = StreamItem> + Send {}
|
||||||
|
pub trait ZCPacketSink: Sink<SinkItem, Error = SinkError> + Send {}
|
||||||
|
impl<T> ZCPacketSink for T where T: Sink<SinkItem, Error = SinkError> + Send {}
|
||||||
|
|
||||||
|
#[auto_impl::auto_impl(Box, Arc)]
|
||||||
|
pub trait Tunnel: Send {
|
||||||
|
fn split(&self) -> (Pin<Box<dyn ZCPacketStream>>, Pin<Box<dyn ZCPacketSink>>);
|
||||||
|
fn info(&self) -> Option<TunnelInfo>;
|
||||||
|
}
|
||||||
|
|
||||||
|
#[auto_impl::auto_impl(Arc)]
|
||||||
|
pub trait TunnelConnCounter: 'static + Send + Sync + Debug {
|
||||||
|
fn get(&self) -> u32;
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
#[auto_impl::auto_impl(Box)]
|
||||||
|
pub trait TunnelListener: Send {
|
||||||
|
async fn listen(&mut self) -> Result<(), TunnelError>;
|
||||||
|
async fn accept(&mut self) -> Result<Box<dyn Tunnel>, TunnelError>;
|
||||||
|
fn local_url(&self) -> url::Url;
|
||||||
|
fn get_conn_counter(&self) -> Arc<Box<dyn TunnelConnCounter>> {
|
||||||
|
#[derive(Debug)]
|
||||||
|
struct FakeTunnelConnCounter {}
|
||||||
|
impl TunnelConnCounter for FakeTunnelConnCounter {
|
||||||
|
fn get(&self) -> u32 {
|
||||||
|
0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Arc::new(Box::new(FakeTunnelConnCounter {}))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
#[auto_impl::auto_impl(Box)]
|
||||||
|
pub trait TunnelConnector: Send {
|
||||||
|
async fn connect(&mut self) -> Result<Box<dyn Tunnel>, TunnelError>;
|
||||||
|
fn remote_url(&self) -> url::Url;
|
||||||
|
fn set_bind_addrs(&mut self, _addrs: Vec<SocketAddr>) {}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn build_url_from_socket_addr(addr: &String, scheme: &str) -> url::Url {
|
||||||
|
url::Url::parse(format!("{}://{}", scheme, addr).as_str()).unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::fmt::Debug for dyn Tunnel {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
f.debug_struct("Tunnel")
|
||||||
|
.field("info", &self.info())
|
||||||
|
.finish()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::fmt::Debug for dyn TunnelConnector {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
f.debug_struct("TunnelConnector")
|
||||||
|
.field("remote_url", &self.remote_url())
|
||||||
|
.finish()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::fmt::Debug for dyn TunnelListener {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
f.debug_struct("TunnelListener")
|
||||||
|
.field("local_url", &self.local_url())
|
||||||
|
.finish()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) trait FromUrl {
|
||||||
|
fn from_url(url: url::Url) -> Result<Self, TunnelError>
|
||||||
|
where
|
||||||
|
Self: Sized;
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn check_scheme_and_get_socket_addr<T>(
|
||||||
|
url: &url::Url,
|
||||||
|
scheme: &str,
|
||||||
|
) -> Result<T, TunnelError>
|
||||||
|
where
|
||||||
|
T: FromUrl,
|
||||||
|
{
|
||||||
|
if url.scheme() != scheme {
|
||||||
|
return Err(TunnelError::InvalidProtocol(url.scheme().to_string()));
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(T::from_url(url.clone())?)
|
||||||
|
}
|
||||||
|
|
||||||
|
impl FromUrl for SocketAddr {
|
||||||
|
fn from_url(url: url::Url) -> Result<Self, TunnelError> {
|
||||||
|
Ok(url.socket_addrs(|| None)?.pop().unwrap())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl FromUrl for uuid::Uuid {
|
||||||
|
fn from_url(url: url::Url) -> Result<Self, TunnelError> {
|
||||||
|
let o = url.host_str().unwrap();
|
||||||
|
let o = uuid::Uuid::parse_str(o).map_err(|e| TunnelError::InvalidAddr(e.to_string()))?;
|
||||||
|
Ok(o)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct TunnelUrl {
|
||||||
|
inner: url::Url,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<url::Url> for TunnelUrl {
|
||||||
|
fn from(url: url::Url) -> Self {
|
||||||
|
TunnelUrl { inner: url }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<TunnelUrl> for url::Url {
|
||||||
|
fn from(url: TunnelUrl) -> Self {
|
||||||
|
url.into_inner()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TunnelUrl {
|
||||||
|
pub fn into_inner(self) -> url::Url {
|
||||||
|
self.inner
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn bind_dev(&self) -> Option<String> {
|
||||||
|
self.inner.path().strip_prefix("/").and_then(|s| {
|
||||||
|
if s.is_empty() {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
Some(String::from_utf8(percent_encoding::percent_decode_str(&s).collect()).unwrap())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,180 @@
|
|||||||
|
// this mod wrap tunnel to a mpsc tunnel, based on crossbeam_channel
|
||||||
|
|
||||||
|
use std::pin::Pin;
|
||||||
|
|
||||||
|
use anyhow::Context;
|
||||||
|
use tokio::task::JoinHandle;
|
||||||
|
|
||||||
|
use super::{packet_def::ZCPacket, Tunnel, TunnelError, ZCPacketSink, ZCPacketStream};
|
||||||
|
|
||||||
|
use tachyonix::{channel, Receiver, Sender};
|
||||||
|
|
||||||
|
use futures::SinkExt;
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct MpscTunnelSender(Sender<ZCPacket>);
|
||||||
|
|
||||||
|
impl MpscTunnelSender {
|
||||||
|
pub async fn send(&self, item: ZCPacket) -> Result<(), TunnelError> {
|
||||||
|
self.0.send(item).await.with_context(|| "send error")?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct MpscTunnel<T> {
|
||||||
|
tx: Sender<ZCPacket>,
|
||||||
|
|
||||||
|
tunnel: T,
|
||||||
|
stream: Option<Pin<Box<dyn ZCPacketStream>>>,
|
||||||
|
|
||||||
|
task: Option<JoinHandle<()>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: Tunnel> MpscTunnel<T> {
|
||||||
|
pub fn new(tunnel: T) -> Self {
|
||||||
|
let (tx, mut rx) = channel(32);
|
||||||
|
let (stream, mut sink) = tunnel.split();
|
||||||
|
|
||||||
|
let task = tokio::spawn(async move {
|
||||||
|
loop {
|
||||||
|
if let Err(e) = Self::forward_one_round(&mut rx, &mut sink).await {
|
||||||
|
tracing::error!(?e, "forward error");
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
Self {
|
||||||
|
tx,
|
||||||
|
tunnel,
|
||||||
|
stream: Some(stream),
|
||||||
|
task: Some(task),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn forward_one_round(
|
||||||
|
rx: &mut Receiver<ZCPacket>,
|
||||||
|
sink: &mut Pin<Box<dyn ZCPacketSink>>,
|
||||||
|
) -> Result<(), TunnelError> {
|
||||||
|
let item = rx.recv().await.with_context(|| "recv error")?;
|
||||||
|
sink.feed(item).await?;
|
||||||
|
while let Ok(item) = rx.try_recv() {
|
||||||
|
if let Err(e) = sink.feed(item).await {
|
||||||
|
tracing::error!(?e, "feed error");
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
sink.flush().await
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_stream(&mut self) -> Pin<Box<dyn ZCPacketStream>> {
|
||||||
|
self.stream.take().unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_sink(&self) -> MpscTunnelSender {
|
||||||
|
MpscTunnelSender(self.tx.clone())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: Tunnel> From<T> for MpscTunnel<T> {
|
||||||
|
fn from(tunnel: T) -> Self {
|
||||||
|
Self::new(tunnel)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use futures::StreamExt;
|
||||||
|
|
||||||
|
use crate::tunnel::{
|
||||||
|
tcp::{TcpTunnelConnector, TcpTunnelListener},
|
||||||
|
TunnelConnector, TunnelListener,
|
||||||
|
};
|
||||||
|
|
||||||
|
use super::*;
|
||||||
|
// test slow send lock in framed tunnel
|
||||||
|
#[tokio::test]
|
||||||
|
async fn mpsc_slow_receiver() {
|
||||||
|
let mut listener = TcpTunnelListener::new("tcp://127.0.0.1:11014".parse().unwrap());
|
||||||
|
let mut connector = TcpTunnelConnector::new("tcp://127.0.0.1:11014".parse().unwrap());
|
||||||
|
|
||||||
|
listener.listen().await.unwrap();
|
||||||
|
let t1 = tokio::spawn(async move {
|
||||||
|
let t = listener.accept().await.unwrap();
|
||||||
|
let (mut stream, _sink) = t.split();
|
||||||
|
let now = tokio::time::Instant::now();
|
||||||
|
|
||||||
|
let mut a_counter = 0;
|
||||||
|
let mut b_counter = 0;
|
||||||
|
|
||||||
|
while let Some(Ok(msg)) = stream.next().await {
|
||||||
|
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
|
||||||
|
if now.elapsed().as_secs() > 5 {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
if msg.payload() == "hello".as_bytes() {
|
||||||
|
a_counter += 1;
|
||||||
|
} else if msg.payload() == "hello2".as_bytes() {
|
||||||
|
b_counter += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
tracing::info!("t1 exit");
|
||||||
|
assert_ne!(a_counter, 0);
|
||||||
|
assert_ne!(b_counter, 0);
|
||||||
|
});
|
||||||
|
|
||||||
|
let tunnel = connector.connect().await.unwrap();
|
||||||
|
let mpsc_tunnel = MpscTunnel::from(tunnel);
|
||||||
|
|
||||||
|
let sink1 = mpsc_tunnel.get_sink();
|
||||||
|
let t2 = tokio::spawn(async move {
|
||||||
|
for i in 0..1000000 {
|
||||||
|
tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;
|
||||||
|
let a = sink1
|
||||||
|
.send(ZCPacket::new_with_payload("hello".as_bytes()))
|
||||||
|
.await;
|
||||||
|
if a.is_err() {
|
||||||
|
tracing::info!(?a, "t2 exit with err");
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
if i % 5000 == 0 {
|
||||||
|
tracing::info!(i, "send2 1000");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
tracing::info!("t2 exit");
|
||||||
|
});
|
||||||
|
|
||||||
|
let sink2 = mpsc_tunnel.get_sink();
|
||||||
|
let t3 = tokio::spawn(async move {
|
||||||
|
for i in 0..1000000 {
|
||||||
|
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
|
||||||
|
let a = sink2
|
||||||
|
.send(ZCPacket::new_with_payload("hello2".as_bytes()))
|
||||||
|
.await;
|
||||||
|
if a.is_err() {
|
||||||
|
tracing::info!(?a, "t3 exit with err");
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
if i % 5000 == 0 {
|
||||||
|
tracing::info!(i, "send2 1000");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
tracing::info!("t3 exit");
|
||||||
|
});
|
||||||
|
|
||||||
|
let t4 = tokio::spawn(async move {
|
||||||
|
tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
|
||||||
|
tracing::info!("closing");
|
||||||
|
drop(mpsc_tunnel);
|
||||||
|
tracing::info!("closed");
|
||||||
|
});
|
||||||
|
|
||||||
|
let _ = tokio::join!(t1, t2, t3, t4);
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,340 @@
|
|||||||
|
use bytes::Bytes;
|
||||||
|
use bytes::BytesMut;
|
||||||
|
use zerocopy::byteorder::*;
|
||||||
|
use zerocopy::AsBytes;
|
||||||
|
use zerocopy::FromBytes;
|
||||||
|
use zerocopy::FromZeroes;
|
||||||
|
|
||||||
|
type DefaultEndian = LittleEndian;
|
||||||
|
|
||||||
|
// TCP TunnelHeader
|
||||||
|
#[repr(C, packed)]
|
||||||
|
#[derive(AsBytes, FromBytes, FromZeroes, Clone, Debug, Default)]
|
||||||
|
pub struct TCPTunnelHeader {
|
||||||
|
pub len: U32<DefaultEndian>,
|
||||||
|
}
|
||||||
|
pub const TCP_TUNNEL_HEADER_SIZE: usize = std::mem::size_of::<TCPTunnelHeader>();
|
||||||
|
|
||||||
|
#[derive(AsBytes, FromZeroes, Clone, Debug)]
|
||||||
|
#[repr(u8)]
|
||||||
|
pub enum UdpPacketType {
|
||||||
|
Invalid = 0,
|
||||||
|
Syn = 1,
|
||||||
|
Sack = 2,
|
||||||
|
Data = 3,
|
||||||
|
Fin = 4,
|
||||||
|
HolePunch = 5,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[repr(C, packed)]
|
||||||
|
#[derive(AsBytes, FromBytes, FromZeroes, Clone, Debug, Default)]
|
||||||
|
pub struct UDPTunnelHeader {
|
||||||
|
pub conn_id: U32<DefaultEndian>,
|
||||||
|
pub msg_type: u8,
|
||||||
|
pub padding: u8,
|
||||||
|
pub len: U16<DefaultEndian>,
|
||||||
|
}
|
||||||
|
pub const UDP_TUNNEL_HEADER_SIZE: usize = std::mem::size_of::<UDPTunnelHeader>();
|
||||||
|
|
||||||
|
#[repr(C, packed)]
|
||||||
|
#[derive(AsBytes, FromBytes, FromZeroes, Clone, Debug, Default)]
|
||||||
|
pub struct WGTunnelHeader {
|
||||||
|
pub ipv4_header: [u8; 20],
|
||||||
|
}
|
||||||
|
pub const WG_TUNNEL_HEADER_SIZE: usize = std::mem::size_of::<WGTunnelHeader>();
|
||||||
|
|
||||||
|
#[derive(AsBytes, FromZeroes, Clone, Debug)]
|
||||||
|
#[repr(u8)]
|
||||||
|
pub enum PacketType {
|
||||||
|
Invalid = 0,
|
||||||
|
Data = 1,
|
||||||
|
HandShake = 2,
|
||||||
|
RoutePacket = 3,
|
||||||
|
Ping = 4,
|
||||||
|
Pong = 5,
|
||||||
|
TaRpc = 6,
|
||||||
|
Route = 7,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[repr(C, packed)]
|
||||||
|
#[derive(AsBytes, FromBytes, FromZeroes, Clone, Debug, Default)]
|
||||||
|
pub struct PeerManagerHeader {
|
||||||
|
pub from_peer_id: U32<DefaultEndian>,
|
||||||
|
pub to_peer_id: U32<DefaultEndian>,
|
||||||
|
pub packet_type: u8,
|
||||||
|
pub len: U32<DefaultEndian>,
|
||||||
|
}
|
||||||
|
pub const PEER_MANAGER_HEADER_SIZE: usize = std::mem::size_of::<PeerManagerHeader>();
|
||||||
|
|
||||||
|
const fn max(a: usize, b: usize) -> usize {
|
||||||
|
[a, b][(a < b) as usize]
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Default, Debug)]
|
||||||
|
pub struct ZCPacketOffsets {
|
||||||
|
pub payload_offset: usize,
|
||||||
|
pub peer_manager_header_offset: usize,
|
||||||
|
pub tcp_tunnel_header_offset: usize,
|
||||||
|
pub udp_tunnel_header_offset: usize,
|
||||||
|
pub wg_tunnel_header_offset: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy, PartialEq)]
|
||||||
|
pub enum ZCPacketType {
|
||||||
|
// received from peer tcp connection
|
||||||
|
TCP,
|
||||||
|
// received from peer udp connection
|
||||||
|
UDP,
|
||||||
|
// received from peer wireguard connection
|
||||||
|
WG,
|
||||||
|
// received from local tun device, should reserve header space for tcp or udp tunnel
|
||||||
|
NIC,
|
||||||
|
}
|
||||||
|
|
||||||
|
const PAYLOAD_OFFSET_FOR_NIC_PACKET: usize = max(
|
||||||
|
max(TCP_TUNNEL_HEADER_SIZE, UDP_TUNNEL_HEADER_SIZE),
|
||||||
|
WG_TUNNEL_HEADER_SIZE,
|
||||||
|
) + PEER_MANAGER_HEADER_SIZE;
|
||||||
|
|
||||||
|
impl ZCPacketType {
|
||||||
|
pub fn get_packet_offsets(&self) -> ZCPacketOffsets {
|
||||||
|
match self {
|
||||||
|
ZCPacketType::TCP => ZCPacketOffsets {
|
||||||
|
payload_offset: TCP_TUNNEL_HEADER_SIZE + PEER_MANAGER_HEADER_SIZE,
|
||||||
|
peer_manager_header_offset: TCP_TUNNEL_HEADER_SIZE,
|
||||||
|
..Default::default()
|
||||||
|
},
|
||||||
|
ZCPacketType::UDP => ZCPacketOffsets {
|
||||||
|
payload_offset: UDP_TUNNEL_HEADER_SIZE + PEER_MANAGER_HEADER_SIZE,
|
||||||
|
peer_manager_header_offset: UDP_TUNNEL_HEADER_SIZE,
|
||||||
|
..Default::default()
|
||||||
|
},
|
||||||
|
ZCPacketType::WG => ZCPacketOffsets {
|
||||||
|
payload_offset: WG_TUNNEL_HEADER_SIZE + PEER_MANAGER_HEADER_SIZE,
|
||||||
|
peer_manager_header_offset: WG_TUNNEL_HEADER_SIZE,
|
||||||
|
..Default::default()
|
||||||
|
},
|
||||||
|
ZCPacketType::NIC => ZCPacketOffsets {
|
||||||
|
payload_offset: PAYLOAD_OFFSET_FOR_NIC_PACKET,
|
||||||
|
peer_manager_header_offset: PAYLOAD_OFFSET_FOR_NIC_PACKET
|
||||||
|
- PEER_MANAGER_HEADER_SIZE,
|
||||||
|
tcp_tunnel_header_offset: PAYLOAD_OFFSET_FOR_NIC_PACKET
|
||||||
|
- PEER_MANAGER_HEADER_SIZE
|
||||||
|
- TCP_TUNNEL_HEADER_SIZE,
|
||||||
|
udp_tunnel_header_offset: PAYLOAD_OFFSET_FOR_NIC_PACKET
|
||||||
|
- PEER_MANAGER_HEADER_SIZE
|
||||||
|
- UDP_TUNNEL_HEADER_SIZE,
|
||||||
|
wg_tunnel_header_offset: PAYLOAD_OFFSET_FOR_NIC_PACKET
|
||||||
|
- PEER_MANAGER_HEADER_SIZE
|
||||||
|
- WG_TUNNEL_HEADER_SIZE,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct ZCPacket {
|
||||||
|
inner: BytesMut,
|
||||||
|
packet_type: ZCPacketType,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ZCPacket {
|
||||||
|
pub fn new_nic_packet() -> Self {
|
||||||
|
Self {
|
||||||
|
inner: BytesMut::new(),
|
||||||
|
packet_type: ZCPacketType::NIC,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn new_from_buf(buf: BytesMut, packet_type: ZCPacketType) -> Self {
|
||||||
|
Self {
|
||||||
|
inner: buf,
|
||||||
|
packet_type,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn new_with_payload(payload: &[u8]) -> Self {
|
||||||
|
let mut ret = Self::new_nic_packet();
|
||||||
|
let total_len = ret.packet_type.get_packet_offsets().payload_offset + payload.len();
|
||||||
|
ret.inner.resize(total_len, 0);
|
||||||
|
ret.mut_payload()[..payload.len()].copy_from_slice(&payload);
|
||||||
|
ret
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn packet_type(&self) -> ZCPacketType {
|
||||||
|
self.packet_type
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn mut_payload(&mut self) -> &mut [u8] {
|
||||||
|
&mut self.inner[self.packet_type.get_packet_offsets().payload_offset..]
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn mut_peer_manager_header(&mut self) -> Option<&mut PeerManagerHeader> {
|
||||||
|
PeerManagerHeader::mut_from_prefix(
|
||||||
|
&mut self.inner[self
|
||||||
|
.packet_type
|
||||||
|
.get_packet_offsets()
|
||||||
|
.peer_manager_header_offset..],
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn mut_tcp_tunnel_header(&mut self) -> Option<&mut TCPTunnelHeader> {
|
||||||
|
TCPTunnelHeader::mut_from_prefix(
|
||||||
|
&mut self.inner[self
|
||||||
|
.packet_type
|
||||||
|
.get_packet_offsets()
|
||||||
|
.tcp_tunnel_header_offset..],
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn mut_udp_tunnel_header(&mut self) -> Option<&mut UDPTunnelHeader> {
|
||||||
|
UDPTunnelHeader::mut_from_prefix(
|
||||||
|
&mut self.inner[self
|
||||||
|
.packet_type
|
||||||
|
.get_packet_offsets()
|
||||||
|
.udp_tunnel_header_offset..],
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn mut_wg_tunnel_header(&mut self) -> Option<&mut WGTunnelHeader> {
|
||||||
|
WGTunnelHeader::mut_from_prefix(
|
||||||
|
&mut self.inner[self
|
||||||
|
.packet_type
|
||||||
|
.get_packet_offsets()
|
||||||
|
.wg_tunnel_header_offset..],
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ref versions
|
||||||
|
pub fn payload(&self) -> &[u8] {
|
||||||
|
&self.inner[self.packet_type.get_packet_offsets().payload_offset..]
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn peer_manager_header(&self) -> Option<&PeerManagerHeader> {
|
||||||
|
PeerManagerHeader::ref_from_prefix(
|
||||||
|
&self.inner[self
|
||||||
|
.packet_type
|
||||||
|
.get_packet_offsets()
|
||||||
|
.peer_manager_header_offset..],
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn tcp_tunnel_header(&self) -> Option<&TCPTunnelHeader> {
|
||||||
|
TCPTunnelHeader::ref_from_prefix(
|
||||||
|
&self.inner[self
|
||||||
|
.packet_type
|
||||||
|
.get_packet_offsets()
|
||||||
|
.tcp_tunnel_header_offset..],
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn udp_tunnel_header(&self) -> Option<&UDPTunnelHeader> {
|
||||||
|
UDPTunnelHeader::ref_from_prefix(
|
||||||
|
&self.inner[self
|
||||||
|
.packet_type
|
||||||
|
.get_packet_offsets()
|
||||||
|
.udp_tunnel_header_offset..],
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn udp_payload(&self) -> &[u8] {
|
||||||
|
&self.inner[self
|
||||||
|
.packet_type
|
||||||
|
.get_packet_offsets()
|
||||||
|
.udp_tunnel_header_offset
|
||||||
|
+ UDP_TUNNEL_HEADER_SIZE..]
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn payload_len(&self) -> usize {
|
||||||
|
let payload_offset = self.packet_type.get_packet_offsets().payload_offset;
|
||||||
|
self.inner.len() - payload_offset
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn buf_len(&self) -> usize {
|
||||||
|
self.inner.len()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn fill_peer_manager_hdr(&mut self, from_peer_id: u32, to_peer_id: u32, packet_type: u8) {
|
||||||
|
let payload_len = self.payload_len();
|
||||||
|
let hdr = self.mut_peer_manager_header().unwrap();
|
||||||
|
hdr.from_peer_id.set(from_peer_id);
|
||||||
|
hdr.to_peer_id.set(to_peer_id);
|
||||||
|
hdr.packet_type = packet_type;
|
||||||
|
hdr.len.set(payload_len as u32);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn into_bytes(mut self, target_packet_type: ZCPacketType) -> Bytes {
|
||||||
|
if target_packet_type == self.packet_type {
|
||||||
|
return self.inner.freeze();
|
||||||
|
} else {
|
||||||
|
assert_eq!(
|
||||||
|
self.packet_type,
|
||||||
|
ZCPacketType::NIC,
|
||||||
|
"only support NIC, got {:?}",
|
||||||
|
self
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
match target_packet_type {
|
||||||
|
ZCPacketType::TCP => self
|
||||||
|
.inner
|
||||||
|
.split_off(
|
||||||
|
self.packet_type
|
||||||
|
.get_packet_offsets()
|
||||||
|
.tcp_tunnel_header_offset,
|
||||||
|
)
|
||||||
|
.freeze(),
|
||||||
|
ZCPacketType::UDP => self
|
||||||
|
.inner
|
||||||
|
.split_off(
|
||||||
|
self.packet_type
|
||||||
|
.get_packet_offsets()
|
||||||
|
.udp_tunnel_header_offset,
|
||||||
|
)
|
||||||
|
.freeze(),
|
||||||
|
ZCPacketType::WG => self
|
||||||
|
.inner
|
||||||
|
.split_off(
|
||||||
|
self.packet_type
|
||||||
|
.get_packet_offsets()
|
||||||
|
.wg_tunnel_header_offset,
|
||||||
|
)
|
||||||
|
.freeze(),
|
||||||
|
ZCPacketType::NIC => unreachable!(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn inner(self) -> BytesMut {
|
||||||
|
self.inner
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_zc_packet() {
|
||||||
|
let payload = b"hello world";
|
||||||
|
let mut packet = ZCPacket::new_with_payload(payload);
|
||||||
|
let peer_manager_header = packet.mut_peer_manager_header().unwrap();
|
||||||
|
peer_manager_header.packet_type = PacketType::Data as u8;
|
||||||
|
peer_manager_header.len.set(payload.len() as u32);
|
||||||
|
|
||||||
|
let tcp_tunnel_header = packet.mut_tcp_tunnel_header().unwrap();
|
||||||
|
tcp_tunnel_header.len.set(payload.len() as u32);
|
||||||
|
|
||||||
|
// let udp_tunnel_header = packet.mut_udp_tunnel_header().unwrap();
|
||||||
|
// udp_tunnel_header.conn_id = 1;
|
||||||
|
// udp_tunnel_header.msg_type = 2;
|
||||||
|
// udp_tunnel_header.len = payload.len() as u32;
|
||||||
|
|
||||||
|
assert_eq!(packet.payload(), b"hello world");
|
||||||
|
assert_eq!(packet.payload_len(), 11);
|
||||||
|
println!("{:?}", packet.inner);
|
||||||
|
|
||||||
|
let tcp_packet = packet.into_bytes(ZCPacketType::TCP);
|
||||||
|
assert_eq!(&tcp_packet[..1], b"\x0b");
|
||||||
|
println!("{:?}", tcp_packet);
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,226 @@
|
|||||||
|
//! This example demonstrates how to make a QUIC connection that ignores the server certificate.
|
||||||
|
//!
|
||||||
|
//! Checkout the `README.md` for guidance.
|
||||||
|
|
||||||
|
use std::{error::Error, net::SocketAddr, sync::Arc};
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
rpc::TunnelInfo,
|
||||||
|
tunnel::common::{FramedReader, FramedWriter, TunnelWrapper},
|
||||||
|
};
|
||||||
|
use anyhow::Context;
|
||||||
|
use quinn::{ClientConfig, Connection, Endpoint, ServerConfig};
|
||||||
|
|
||||||
|
use super::{
|
||||||
|
check_scheme_and_get_socket_addr, Tunnel, TunnelConnector, TunnelError, TunnelListener,
|
||||||
|
};
|
||||||
|
|
||||||
|
/// Dummy certificate verifier that treats any certificate as valid.
|
||||||
|
/// NOTE, such verification is vulnerable to MITM attacks, but convenient for testing.
|
||||||
|
struct SkipServerVerification;
|
||||||
|
|
||||||
|
impl SkipServerVerification {
|
||||||
|
fn new() -> Arc<Self> {
|
||||||
|
Arc::new(Self)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl rustls::client::ServerCertVerifier for SkipServerVerification {
|
||||||
|
fn verify_server_cert(
|
||||||
|
&self,
|
||||||
|
_end_entity: &rustls::Certificate,
|
||||||
|
_intermediates: &[rustls::Certificate],
|
||||||
|
_server_name: &rustls::ServerName,
|
||||||
|
_scts: &mut dyn Iterator<Item = &[u8]>,
|
||||||
|
_ocsp_response: &[u8],
|
||||||
|
_now: std::time::SystemTime,
|
||||||
|
) -> Result<rustls::client::ServerCertVerified, rustls::Error> {
|
||||||
|
Ok(rustls::client::ServerCertVerified::assertion())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn configure_client() -> ClientConfig {
|
||||||
|
let crypto = rustls::ClientConfig::builder()
|
||||||
|
.with_safe_defaults()
|
||||||
|
.with_custom_certificate_verifier(SkipServerVerification::new())
|
||||||
|
.with_no_client_auth();
|
||||||
|
|
||||||
|
ClientConfig::new(Arc::new(crypto))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Constructs a QUIC endpoint configured to listen for incoming connections on a certain address
|
||||||
|
/// and port.
|
||||||
|
///
|
||||||
|
/// ## Returns
|
||||||
|
///
|
||||||
|
/// - a stream of incoming QUIC connections
|
||||||
|
/// - server certificate serialized into DER format
|
||||||
|
#[allow(unused)]
|
||||||
|
pub fn make_server_endpoint(bind_addr: SocketAddr) -> Result<(Endpoint, Vec<u8>), Box<dyn Error>> {
|
||||||
|
let (server_config, server_cert) = configure_server()?;
|
||||||
|
let endpoint = Endpoint::server(server_config, bind_addr)?;
|
||||||
|
Ok((endpoint, server_cert))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns default server configuration along with its certificate.
|
||||||
|
fn configure_server() -> Result<(ServerConfig, Vec<u8>), Box<dyn Error>> {
|
||||||
|
let cert = rcgen::generate_simple_self_signed(vec!["localhost".into()]).unwrap();
|
||||||
|
let cert_der = cert.serialize_der().unwrap();
|
||||||
|
let priv_key = cert.serialize_private_key_der();
|
||||||
|
let priv_key = rustls::PrivateKey(priv_key);
|
||||||
|
let cert_chain = vec![rustls::Certificate(cert_der.clone())];
|
||||||
|
|
||||||
|
let mut server_config = ServerConfig::with_single_cert(cert_chain, priv_key)?;
|
||||||
|
let transport_config = Arc::get_mut(&mut server_config.transport).unwrap();
|
||||||
|
transport_config.max_concurrent_uni_streams(10_u8.into());
|
||||||
|
transport_config.max_concurrent_bidi_streams(10_u8.into());
|
||||||
|
|
||||||
|
Ok((server_config, cert_der))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[allow(unused)]
|
||||||
|
pub const ALPN_QUIC_HTTP: &[&[u8]] = &[b"hq-29"];
|
||||||
|
|
||||||
|
/// Runs a QUIC server bound to given address.
|
||||||
|
|
||||||
|
struct ConnWrapper {
|
||||||
|
conn: Connection,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Drop for ConnWrapper {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
self.conn.close(0u32.into(), b"done");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct QUICTunnelListener {
|
||||||
|
addr: url::Url,
|
||||||
|
endpoint: Option<Endpoint>,
|
||||||
|
server_cert: Option<Vec<u8>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl QUICTunnelListener {
|
||||||
|
pub fn new(addr: url::Url) -> Self {
|
||||||
|
QUICTunnelListener {
|
||||||
|
addr,
|
||||||
|
endpoint: None,
|
||||||
|
server_cert: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait::async_trait]
|
||||||
|
impl TunnelListener for QUICTunnelListener {
|
||||||
|
async fn listen(&mut self) -> Result<(), TunnelError> {
|
||||||
|
let addr = check_scheme_and_get_socket_addr::<SocketAddr>(&self.addr, "quic")?;
|
||||||
|
let (endpoint, server_cert) = make_server_endpoint(addr).unwrap();
|
||||||
|
self.endpoint = Some(endpoint);
|
||||||
|
self.server_cert = Some(server_cert);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn accept(&mut self) -> Result<Box<dyn Tunnel>, super::TunnelError> {
|
||||||
|
// accept a single connection
|
||||||
|
let incoming_conn = self.endpoint.as_ref().unwrap().accept().await.unwrap();
|
||||||
|
let conn = incoming_conn.await.unwrap();
|
||||||
|
println!(
|
||||||
|
"[server] connection accepted: addr={}",
|
||||||
|
conn.remote_address()
|
||||||
|
);
|
||||||
|
let remote_addr = conn.remote_address();
|
||||||
|
let (w, r) = conn.accept_bi().await.with_context(|| "accept_bi failed")?;
|
||||||
|
|
||||||
|
let arc_conn = Arc::new(ConnWrapper { conn });
|
||||||
|
|
||||||
|
let info = TunnelInfo {
|
||||||
|
tunnel_type: "quic".to_owned(),
|
||||||
|
local_addr: self.local_url().into(),
|
||||||
|
remote_addr: super::build_url_from_socket_addr(&remote_addr.to_string(), "quic").into(),
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(Box::new(TunnelWrapper::new(
|
||||||
|
FramedReader::new_with_associate_data(r, 4500, Some(Box::new(arc_conn.clone()))),
|
||||||
|
FramedWriter::new_with_associate_data(w, Some(Box::new(arc_conn))),
|
||||||
|
Some(info),
|
||||||
|
)))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn local_url(&self) -> url::Url {
|
||||||
|
self.addr.clone()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct QUICTunnelConnector {
|
||||||
|
addr: url::Url,
|
||||||
|
endpoint: Option<Endpoint>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl QUICTunnelConnector {
|
||||||
|
pub fn new(addr: url::Url) -> Self {
|
||||||
|
QUICTunnelConnector {
|
||||||
|
addr,
|
||||||
|
endpoint: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait::async_trait]
|
||||||
|
impl TunnelConnector for QUICTunnelConnector {
|
||||||
|
async fn connect(&mut self) -> Result<Box<dyn Tunnel>, super::TunnelError> {
|
||||||
|
let addr = check_scheme_and_get_socket_addr::<SocketAddr>(&self.addr, "quic")?;
|
||||||
|
|
||||||
|
let mut endpoint = Endpoint::client("127.0.0.1:0".parse().unwrap())?;
|
||||||
|
endpoint.set_default_client_config(configure_client());
|
||||||
|
|
||||||
|
// connect to server
|
||||||
|
let connection = endpoint.connect(addr, "localhost").unwrap().await.unwrap();
|
||||||
|
println!("[client] connected: addr={}", connection.remote_address());
|
||||||
|
|
||||||
|
let local_addr = endpoint.local_addr().unwrap();
|
||||||
|
|
||||||
|
self.endpoint = Some(endpoint);
|
||||||
|
|
||||||
|
let (w, r) = connection
|
||||||
|
.open_bi()
|
||||||
|
.await
|
||||||
|
.with_context(|| "open_bi failed")?;
|
||||||
|
|
||||||
|
let info = TunnelInfo {
|
||||||
|
tunnel_type: "quic".to_owned(),
|
||||||
|
local_addr: super::build_url_from_socket_addr(&local_addr.to_string(), "quic").into(),
|
||||||
|
remote_addr: self.addr.to_string(),
|
||||||
|
};
|
||||||
|
|
||||||
|
let arc_conn = Arc::new(ConnWrapper { conn: connection });
|
||||||
|
Ok(Box::new(TunnelWrapper::new(
|
||||||
|
FramedReader::new_with_associate_data(r, 4500, Some(Box::new(arc_conn.clone()))),
|
||||||
|
FramedWriter::new_with_associate_data(w, Some(Box::new(arc_conn))),
|
||||||
|
Some(info),
|
||||||
|
)))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn remote_url(&self) -> url::Url {
|
||||||
|
self.addr.clone()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use crate::tunnel::common::tests::{_tunnel_bench, _tunnel_pingpong};
|
||||||
|
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn quic_pingpong() {
|
||||||
|
let listener = QUICTunnelListener::new("quic://0.0.0.0:21011".parse().unwrap());
|
||||||
|
let connector = QUICTunnelConnector::new("quic://127.0.0.1:21011".parse().unwrap());
|
||||||
|
_tunnel_pingpong(listener, connector).await
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn quic_bench() {
|
||||||
|
let listener = QUICTunnelListener::new("quic://0.0.0.0:21012".parse().unwrap());
|
||||||
|
let connector = QUICTunnelConnector::new("quic://127.0.0.1:21012".parse().unwrap());
|
||||||
|
_tunnel_bench(listener, connector).await
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,427 @@
|
|||||||
|
use std::{
|
||||||
|
collections::HashMap,
|
||||||
|
sync::{
|
||||||
|
atomic::{AtomicBool, Ordering},
|
||||||
|
Arc,
|
||||||
|
},
|
||||||
|
task::{Poll, Waker},
|
||||||
|
};
|
||||||
|
|
||||||
|
use atomicbox::AtomicOptionBox;
|
||||||
|
use crossbeam_queue::ArrayQueue;
|
||||||
|
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use futures::{Sink, Stream};
|
||||||
|
use once_cell::sync::Lazy;
|
||||||
|
|
||||||
|
use tokio::sync::{
|
||||||
|
mpsc::{UnboundedReceiver, UnboundedSender},
|
||||||
|
Mutex,
|
||||||
|
};
|
||||||
|
|
||||||
|
use uuid::Uuid;
|
||||||
|
|
||||||
|
use crate::tunnel::{SinkError, SinkItem};
|
||||||
|
|
||||||
|
use super::{
|
||||||
|
build_url_from_socket_addr, check_scheme_and_get_socket_addr, common::TunnelWrapper,
|
||||||
|
StreamItem, Tunnel, TunnelConnector, TunnelError, TunnelInfo, TunnelListener,
|
||||||
|
};
|
||||||
|
|
||||||
|
static RING_TUNNEL_CAP: usize = 128;
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct RingTunnel {
|
||||||
|
id: Uuid,
|
||||||
|
ring: ArrayQueue<SinkItem>,
|
||||||
|
closed: AtomicBool,
|
||||||
|
|
||||||
|
wait_for_new_item: AtomicOptionBox<Waker>,
|
||||||
|
wait_for_empty_slot: AtomicOptionBox<Waker>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl RingTunnel {
|
||||||
|
fn wait_for_new_item<T>(&self, cx: &mut std::task::Context<'_>) -> Poll<T> {
|
||||||
|
let ret = self
|
||||||
|
.wait_for_new_item
|
||||||
|
.swap(Some(Box::new(cx.waker().clone())), Ordering::AcqRel);
|
||||||
|
if let Some(old_waker) = ret {
|
||||||
|
assert!(old_waker.will_wake(cx.waker()));
|
||||||
|
}
|
||||||
|
Poll::Pending
|
||||||
|
}
|
||||||
|
|
||||||
|
fn wait_for_empty_slot<T>(&self, cx: &mut std::task::Context<'_>) -> Poll<T> {
|
||||||
|
let ret = self
|
||||||
|
.wait_for_empty_slot
|
||||||
|
.swap(Some(Box::new(cx.waker().clone())), Ordering::AcqRel);
|
||||||
|
if let Some(old_waker) = ret {
|
||||||
|
assert!(old_waker.will_wake(cx.waker()));
|
||||||
|
}
|
||||||
|
Poll::Pending
|
||||||
|
}
|
||||||
|
|
||||||
|
fn notify_new_item(&self) {
|
||||||
|
if let Some(w) = self.wait_for_new_item.take(Ordering::AcqRel) {
|
||||||
|
tracing::trace!(?self.id, "notify new item");
|
||||||
|
w.wake();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn notify_empty_slot(&self) {
|
||||||
|
if let Some(w) = self.wait_for_empty_slot.take(Ordering::AcqRel) {
|
||||||
|
tracing::trace!(?self.id, "notify empty slot");
|
||||||
|
w.wake();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn id(&self) -> &Uuid {
|
||||||
|
&self.id
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn len(&self) -> usize {
|
||||||
|
self.ring.len()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn capacity(&self) -> usize {
|
||||||
|
self.ring.capacity()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn close(&self) {
|
||||||
|
tracing::info!("close ring tunnel {:?}", self.id);
|
||||||
|
self.closed
|
||||||
|
.store(true, std::sync::atomic::Ordering::Relaxed);
|
||||||
|
self.notify_new_item();
|
||||||
|
}
|
||||||
|
|
||||||
|
fn closed(&self) -> bool {
|
||||||
|
self.closed.load(std::sync::atomic::Ordering::Relaxed)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn new(cap: usize) -> Self {
|
||||||
|
let id = Uuid::new_v4();
|
||||||
|
Self {
|
||||||
|
id: id.clone(),
|
||||||
|
ring: ArrayQueue::new(cap),
|
||||||
|
closed: AtomicBool::new(false),
|
||||||
|
|
||||||
|
wait_for_new_item: AtomicOptionBox::new(None),
|
||||||
|
wait_for_empty_slot: AtomicOptionBox::new(None),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn new_with_id(id: Uuid, cap: usize) -> Self {
|
||||||
|
let mut ret = Self::new(cap);
|
||||||
|
ret.id = id;
|
||||||
|
ret
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct RingStream {
|
||||||
|
tunnel: Arc<RingTunnel>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl RingStream {
|
||||||
|
pub fn new(tunnel: Arc<RingTunnel>) -> Self {
|
||||||
|
Self { tunnel }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Stream for RingStream {
|
||||||
|
type Item = StreamItem;
|
||||||
|
|
||||||
|
fn poll_next(
|
||||||
|
self: std::pin::Pin<&mut Self>,
|
||||||
|
cx: &mut std::task::Context<'_>,
|
||||||
|
) -> Poll<Option<Self::Item>> {
|
||||||
|
let s = self.get_mut();
|
||||||
|
let ret = s.tunnel.ring.pop();
|
||||||
|
match ret {
|
||||||
|
Some(v) => {
|
||||||
|
s.tunnel.notify_empty_slot();
|
||||||
|
return Poll::Ready(Some(Ok(v)));
|
||||||
|
}
|
||||||
|
None => {
|
||||||
|
if s.tunnel.closed() {
|
||||||
|
tracing::warn!("ring recv tunnel {:?} closed", s.tunnel.id());
|
||||||
|
return Poll::Ready(None);
|
||||||
|
} else {
|
||||||
|
tracing::trace!("waiting recv buffer, id: {}", s.tunnel.id());
|
||||||
|
}
|
||||||
|
s.tunnel.wait_for_new_item(cx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct RingSink {
|
||||||
|
tunnel: Arc<RingTunnel>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Drop for RingSink {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
self.tunnel.close();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl RingSink {
|
||||||
|
pub fn new(tunnel: Arc<RingTunnel>) -> Self {
|
||||||
|
Self { tunnel }
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn push_no_check(&self, item: SinkItem) -> Result<(), TunnelError> {
|
||||||
|
if self.tunnel.closed() {
|
||||||
|
return Err(TunnelError::Shutdown);
|
||||||
|
}
|
||||||
|
|
||||||
|
log::trace!("id: {}, send buffer, buf: {:?}", self.tunnel.id(), &item);
|
||||||
|
self.tunnel.ring.push(item).unwrap();
|
||||||
|
self.tunnel.notify_new_item();
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn has_empty_slot(&self) -> bool {
|
||||||
|
self.tunnel.len() < self.tunnel.capacity()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Sink<SinkItem> for RingSink {
|
||||||
|
type Error = SinkError;
|
||||||
|
|
||||||
|
fn poll_ready(
|
||||||
|
self: std::pin::Pin<&mut Self>,
|
||||||
|
cx: &mut std::task::Context<'_>,
|
||||||
|
) -> std::task::Poll<Result<(), Self::Error>> {
|
||||||
|
let self_mut = self.get_mut();
|
||||||
|
if !self_mut.has_empty_slot() {
|
||||||
|
if self_mut.tunnel.closed() {
|
||||||
|
return Poll::Ready(Err(TunnelError::Shutdown));
|
||||||
|
}
|
||||||
|
self_mut.tunnel.wait_for_empty_slot(cx)
|
||||||
|
} else {
|
||||||
|
Poll::Ready(Ok(()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn start_send(self: std::pin::Pin<&mut Self>, item: SinkItem) -> Result<(), Self::Error> {
|
||||||
|
self.push_no_check(item)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_flush(
|
||||||
|
self: std::pin::Pin<&mut Self>,
|
||||||
|
_cx: &mut std::task::Context<'_>,
|
||||||
|
) -> std::task::Poll<Result<(), Self::Error>> {
|
||||||
|
if self.tunnel.closed() {
|
||||||
|
return Poll::Ready(Err(TunnelError::Shutdown));
|
||||||
|
}
|
||||||
|
Poll::Ready(Ok(()))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_close(
|
||||||
|
self: std::pin::Pin<&mut Self>,
|
||||||
|
_cx: &mut std::task::Context<'_>,
|
||||||
|
) -> std::task::Poll<Result<(), Self::Error>> {
|
||||||
|
self.tunnel.close();
|
||||||
|
Poll::Ready(Ok(()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct Connection {
|
||||||
|
client: Arc<RingTunnel>,
|
||||||
|
server: Arc<RingTunnel>,
|
||||||
|
}
|
||||||
|
|
||||||
|
static CONNECTION_MAP: Lazy<Arc<Mutex<HashMap<uuid::Uuid, UnboundedSender<Arc<Connection>>>>>> =
|
||||||
|
Lazy::new(|| Arc::new(Mutex::new(HashMap::new())));
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct RingTunnelListener {
|
||||||
|
listerner_addr: url::Url,
|
||||||
|
conn_sender: UnboundedSender<Arc<Connection>>,
|
||||||
|
conn_receiver: UnboundedReceiver<Arc<Connection>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl RingTunnelListener {
|
||||||
|
pub fn new(key: url::Url) -> Self {
|
||||||
|
let (conn_sender, conn_receiver) = tokio::sync::mpsc::unbounded_channel();
|
||||||
|
RingTunnelListener {
|
||||||
|
listerner_addr: key,
|
||||||
|
conn_sender,
|
||||||
|
conn_receiver,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_tunnel_for_client(conn: Arc<Connection>) -> impl Tunnel {
|
||||||
|
TunnelWrapper::new(
|
||||||
|
RingStream::new(conn.client.clone()),
|
||||||
|
RingSink::new(conn.server.clone()),
|
||||||
|
Some(TunnelInfo {
|
||||||
|
tunnel_type: "ring".to_owned(),
|
||||||
|
local_addr: build_url_from_socket_addr(&conn.client.id.into(), "ring").into(),
|
||||||
|
remote_addr: build_url_from_socket_addr(&conn.server.id.into(), "ring").into(),
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_tunnel_for_server(conn: Arc<Connection>) -> impl Tunnel {
|
||||||
|
TunnelWrapper::new(
|
||||||
|
RingStream::new(conn.server.clone()),
|
||||||
|
RingSink::new(conn.client.clone()),
|
||||||
|
Some(TunnelInfo {
|
||||||
|
tunnel_type: "ring".to_owned(),
|
||||||
|
local_addr: build_url_from_socket_addr(&conn.server.id.into(), "ring").into(),
|
||||||
|
remote_addr: build_url_from_socket_addr(&conn.client.id.into(), "ring").into(),
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
impl RingTunnelListener {
|
||||||
|
fn get_addr(&self) -> Result<uuid::Uuid, TunnelError> {
|
||||||
|
check_scheme_and_get_socket_addr::<Uuid>(&self.listerner_addr, "ring")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl TunnelListener for RingTunnelListener {
|
||||||
|
async fn listen(&mut self) -> Result<(), TunnelError> {
|
||||||
|
log::info!("listen new conn of key: {}", self.listerner_addr);
|
||||||
|
CONNECTION_MAP
|
||||||
|
.lock()
|
||||||
|
.await
|
||||||
|
.insert(self.get_addr()?, self.conn_sender.clone());
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn accept(&mut self) -> Result<Box<dyn Tunnel>, TunnelError> {
|
||||||
|
log::info!("waiting accept new conn of key: {}", self.listerner_addr);
|
||||||
|
let my_addr = self.get_addr()?;
|
||||||
|
if let Some(conn) = self.conn_receiver.recv().await {
|
||||||
|
if conn.server.id == my_addr {
|
||||||
|
log::info!("accept new conn of key: {}", self.listerner_addr);
|
||||||
|
return Ok(Box::new(get_tunnel_for_server(conn)));
|
||||||
|
} else {
|
||||||
|
tracing::error!(?conn.server.id, ?my_addr, "got new conn with wrong id");
|
||||||
|
return Err(TunnelError::InternalError(
|
||||||
|
"accept got wrong ring server id".to_owned(),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return Err(TunnelError::InternalError(
|
||||||
|
"conn receiver stopped".to_owned(),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
fn local_url(&self) -> url::Url {
|
||||||
|
self.listerner_addr.clone()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct RingTunnelConnector {
|
||||||
|
remote_addr: url::Url,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl RingTunnelConnector {
|
||||||
|
pub fn new(remote_addr: url::Url) -> Self {
|
||||||
|
RingTunnelConnector { remote_addr }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl TunnelConnector for RingTunnelConnector {
|
||||||
|
async fn connect(&mut self) -> Result<Box<dyn Tunnel>, super::TunnelError> {
|
||||||
|
let remote_addr = check_scheme_and_get_socket_addr::<Uuid>(&self.remote_addr, "ring")?;
|
||||||
|
let entry = CONNECTION_MAP
|
||||||
|
.lock()
|
||||||
|
.await
|
||||||
|
.get(&remote_addr)
|
||||||
|
.unwrap()
|
||||||
|
.clone();
|
||||||
|
log::info!("connecting");
|
||||||
|
let conn = Arc::new(Connection {
|
||||||
|
client: Arc::new(RingTunnel::new(RING_TUNNEL_CAP)),
|
||||||
|
server: Arc::new(RingTunnel::new_with_id(
|
||||||
|
remote_addr.clone(),
|
||||||
|
RING_TUNNEL_CAP,
|
||||||
|
)),
|
||||||
|
});
|
||||||
|
entry
|
||||||
|
.send(conn.clone())
|
||||||
|
.map_err(|_| TunnelError::InternalError("send conn to listner failed".to_owned()))?;
|
||||||
|
Ok(Box::new(get_tunnel_for_client(conn)))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn remote_url(&self) -> url::Url {
|
||||||
|
self.remote_addr.clone()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn create_ring_tunnel_pair() -> (Box<dyn Tunnel>, Box<dyn Tunnel>) {
|
||||||
|
let conn = Arc::new(Connection {
|
||||||
|
client: Arc::new(RingTunnel::new(RING_TUNNEL_CAP)),
|
||||||
|
server: Arc::new(RingTunnel::new(RING_TUNNEL_CAP)),
|
||||||
|
});
|
||||||
|
(
|
||||||
|
Box::new(get_tunnel_for_server(conn.clone())),
|
||||||
|
Box::new(get_tunnel_for_client(conn)),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use futures::StreamExt;
|
||||||
|
use tokio::time::timeout;
|
||||||
|
|
||||||
|
use crate::tunnel::common::tests::{_tunnel_bench, _tunnel_pingpong};
|
||||||
|
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn ring_pingpong() {
|
||||||
|
let id: url::Url = format!("ring://{}", Uuid::new_v4()).parse().unwrap();
|
||||||
|
let listener = RingTunnelListener::new(id.clone());
|
||||||
|
let connector = RingTunnelConnector::new(id.clone());
|
||||||
|
_tunnel_pingpong(listener, connector).await
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn ring_bench() {
|
||||||
|
let id: url::Url = format!("ring://{}", Uuid::new_v4()).parse().unwrap();
|
||||||
|
let listener = RingTunnelListener::new(id.clone());
|
||||||
|
let connector = RingTunnelConnector::new(id);
|
||||||
|
_tunnel_bench(listener, connector).await
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn ring_close() {
|
||||||
|
let (stunnel, ctunnel) = create_ring_tunnel_pair();
|
||||||
|
drop(stunnel);
|
||||||
|
|
||||||
|
let mut stream = ctunnel.split().0;
|
||||||
|
let ret = stream.next().await;
|
||||||
|
assert!(ret.as_ref().is_none(), "expect none, got {:?}", ret);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn abort_ring_stream() {
|
||||||
|
let (_stunnel, ctunnel) = create_ring_tunnel_pair();
|
||||||
|
let mut stream = ctunnel.split().0;
|
||||||
|
let task = tokio::spawn(async move {
|
||||||
|
let _ = stream.next().await;
|
||||||
|
});
|
||||||
|
tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
|
||||||
|
task.abort();
|
||||||
|
let _ = tokio::join!(task);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn ring_stream_recv_timeout() {
|
||||||
|
let (_stunnel, ctunnel) = create_ring_tunnel_pair();
|
||||||
|
let mut stream = ctunnel.split().0;
|
||||||
|
let _ = timeout(tokio::time::Duration::from_millis(10), stream.next()).await;
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,95 @@
|
|||||||
|
use std::sync::atomic::{AtomicU32, AtomicU64, Ordering::Relaxed};
|
||||||
|
|
||||||
|
pub struct WindowLatency {
|
||||||
|
latency_us_window: Vec<AtomicU32>,
|
||||||
|
latency_us_window_index: AtomicU32,
|
||||||
|
latency_us_window_size: u32,
|
||||||
|
|
||||||
|
sum: AtomicU32,
|
||||||
|
count: AtomicU32,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl WindowLatency {
|
||||||
|
pub fn new(window_size: u32) -> Self {
|
||||||
|
Self {
|
||||||
|
latency_us_window: (0..window_size).map(|_| AtomicU32::new(0)).collect(),
|
||||||
|
latency_us_window_index: AtomicU32::new(0),
|
||||||
|
latency_us_window_size: window_size,
|
||||||
|
|
||||||
|
sum: AtomicU32::new(0),
|
||||||
|
count: AtomicU32::new(0),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn record_latency(&self, latency_us: u32) {
|
||||||
|
let index = self.latency_us_window_index.fetch_add(1, Relaxed);
|
||||||
|
if self.count.load(Relaxed) < self.latency_us_window_size {
|
||||||
|
self.count.fetch_add(1, Relaxed);
|
||||||
|
}
|
||||||
|
|
||||||
|
let index = index % self.latency_us_window_size;
|
||||||
|
let old_lat = self.latency_us_window[index as usize].swap(latency_us, Relaxed);
|
||||||
|
|
||||||
|
if old_lat < latency_us {
|
||||||
|
self.sum.fetch_add(latency_us - old_lat, Relaxed);
|
||||||
|
} else {
|
||||||
|
self.sum.fetch_sub(old_lat - latency_us, Relaxed);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_latency_us<T: From<u32> + std::ops::Div<Output = T>>(&self) -> T {
|
||||||
|
let count = self.count.load(Relaxed);
|
||||||
|
let sum = self.sum.load(Relaxed);
|
||||||
|
if count == 0 {
|
||||||
|
0.into()
|
||||||
|
} else {
|
||||||
|
(T::from(sum)) / T::from(count)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct Throughput {
|
||||||
|
tx_bytes: AtomicU64,
|
||||||
|
rx_bytes: AtomicU64,
|
||||||
|
|
||||||
|
tx_packets: AtomicU64,
|
||||||
|
rx_packets: AtomicU64,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Throughput {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
tx_bytes: AtomicU64::new(0),
|
||||||
|
rx_bytes: AtomicU64::new(0),
|
||||||
|
|
||||||
|
tx_packets: AtomicU64::new(0),
|
||||||
|
rx_packets: AtomicU64::new(0),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn tx_bytes(&self) -> u64 {
|
||||||
|
self.tx_bytes.load(Relaxed)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn rx_bytes(&self) -> u64 {
|
||||||
|
self.rx_bytes.load(Relaxed)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn tx_packets(&self) -> u64 {
|
||||||
|
self.tx_packets.load(Relaxed)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn rx_packets(&self) -> u64 {
|
||||||
|
self.rx_packets.load(Relaxed)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn record_tx_bytes(&self, bytes: u64) {
|
||||||
|
self.tx_bytes.fetch_add(bytes, Relaxed);
|
||||||
|
self.tx_packets.fetch_add(1, Relaxed);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn record_rx_bytes(&self, bytes: u64) {
|
||||||
|
self.rx_bytes.fetch_add(bytes, Relaxed);
|
||||||
|
self.rx_packets.fetch_add(1, Relaxed);
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,200 @@
|
|||||||
|
use std::net::SocketAddr;
|
||||||
|
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use futures::stream::FuturesUnordered;
|
||||||
|
use tokio::net::{TcpListener, TcpSocket, TcpStream};
|
||||||
|
|
||||||
|
use crate::{rpc::TunnelInfo, tunnel::common::setup_sokcet2};
|
||||||
|
|
||||||
|
use super::{
|
||||||
|
check_scheme_and_get_socket_addr,
|
||||||
|
common::{wait_for_connect_futures, FramedReader, FramedWriter, TunnelWrapper},
|
||||||
|
Tunnel, TunnelError, TunnelListener,
|
||||||
|
};
|
||||||
|
|
||||||
|
const TCP_MTU_BYTES: usize = 64 * 1024;
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct TcpTunnelListener {
|
||||||
|
addr: url::Url,
|
||||||
|
listener: Option<TcpListener>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TcpTunnelListener {
|
||||||
|
pub fn new(addr: url::Url) -> Self {
|
||||||
|
TcpTunnelListener {
|
||||||
|
addr,
|
||||||
|
listener: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl TunnelListener for TcpTunnelListener {
|
||||||
|
async fn listen(&mut self) -> Result<(), TunnelError> {
|
||||||
|
let addr = check_scheme_and_get_socket_addr::<SocketAddr>(&self.addr, "tcp")?;
|
||||||
|
|
||||||
|
let socket = if addr.is_ipv4() {
|
||||||
|
TcpSocket::new_v4()?
|
||||||
|
} else {
|
||||||
|
TcpSocket::new_v6()?
|
||||||
|
};
|
||||||
|
|
||||||
|
socket.set_reuseaddr(true)?;
|
||||||
|
// #[cfg(all(unix, not(target_os = "solaris"), not(target_os = "illumos")))]
|
||||||
|
// socket.set_reuseport(true)?;
|
||||||
|
socket.bind(addr)?;
|
||||||
|
|
||||||
|
self.listener = Some(socket.listen(1024)?);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn accept(&mut self) -> Result<Box<dyn Tunnel>, super::TunnelError> {
|
||||||
|
let listener = self.listener.as_ref().unwrap();
|
||||||
|
let (stream, _) = listener.accept().await?;
|
||||||
|
stream.set_nodelay(true).unwrap();
|
||||||
|
let info = TunnelInfo {
|
||||||
|
tunnel_type: "tcp".to_owned(),
|
||||||
|
local_addr: self.local_url().into(),
|
||||||
|
remote_addr: super::build_url_from_socket_addr(&stream.peer_addr()?.to_string(), "tcp")
|
||||||
|
.into(),
|
||||||
|
};
|
||||||
|
|
||||||
|
let (r, w) = stream.into_split();
|
||||||
|
Ok(Box::new(TunnelWrapper::new(
|
||||||
|
FramedReader::new(r, TCP_MTU_BYTES),
|
||||||
|
FramedWriter::new(w),
|
||||||
|
Some(info),
|
||||||
|
)))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn local_url(&self) -> url::Url {
|
||||||
|
self.addr.clone()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_tunnel_with_tcp_stream(
|
||||||
|
stream: TcpStream,
|
||||||
|
remote_url: url::Url,
|
||||||
|
) -> Result<Box<dyn Tunnel>, super::TunnelError> {
|
||||||
|
stream.set_nodelay(true).unwrap();
|
||||||
|
|
||||||
|
let info = TunnelInfo {
|
||||||
|
tunnel_type: "tcp".to_owned(),
|
||||||
|
local_addr: super::build_url_from_socket_addr(&stream.local_addr()?.to_string(), "tcp")
|
||||||
|
.into(),
|
||||||
|
remote_addr: remote_url.into(),
|
||||||
|
};
|
||||||
|
|
||||||
|
let (r, w) = stream.into_split();
|
||||||
|
Ok(Box::new(TunnelWrapper::new(
|
||||||
|
FramedReader::new(r, TCP_MTU_BYTES),
|
||||||
|
FramedWriter::new(w),
|
||||||
|
Some(info),
|
||||||
|
)))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct TcpTunnelConnector {
|
||||||
|
addr: url::Url,
|
||||||
|
|
||||||
|
bind_addrs: Vec<SocketAddr>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TcpTunnelConnector {
|
||||||
|
pub fn new(addr: url::Url) -> Self {
|
||||||
|
TcpTunnelConnector {
|
||||||
|
addr,
|
||||||
|
bind_addrs: vec![],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn connect_with_default_bind(&mut self) -> Result<Box<dyn Tunnel>, super::TunnelError> {
|
||||||
|
tracing::info!(addr = ?self.addr, "connect tcp start");
|
||||||
|
let addr = check_scheme_and_get_socket_addr::<SocketAddr>(&self.addr, "tcp")?;
|
||||||
|
let stream = TcpStream::connect(addr).await?;
|
||||||
|
tracing::info!(addr = ?self.addr, "connect tcp succ");
|
||||||
|
return get_tunnel_with_tcp_stream(stream, self.addr.clone().into());
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn connect_with_custom_bind(&mut self) -> Result<Box<dyn Tunnel>, super::TunnelError> {
|
||||||
|
let futures = FuturesUnordered::new();
|
||||||
|
let dst_addr = check_scheme_and_get_socket_addr::<SocketAddr>(&self.addr, "tcp")?;
|
||||||
|
|
||||||
|
for bind_addr in self.bind_addrs.iter() {
|
||||||
|
tracing::info!(bind_addr = ?bind_addr, ?dst_addr, "bind addr");
|
||||||
|
|
||||||
|
let socket2_socket = socket2::Socket::new(
|
||||||
|
socket2::Domain::for_address(dst_addr),
|
||||||
|
socket2::Type::STREAM,
|
||||||
|
Some(socket2::Protocol::TCP),
|
||||||
|
)?;
|
||||||
|
setup_sokcet2(&socket2_socket, bind_addr)?;
|
||||||
|
|
||||||
|
let socket = TcpSocket::from_std_stream(socket2_socket.into());
|
||||||
|
futures.push(socket.connect(dst_addr.clone()));
|
||||||
|
}
|
||||||
|
|
||||||
|
let ret = wait_for_connect_futures(futures).await;
|
||||||
|
return get_tunnel_with_tcp_stream(ret?, self.addr.clone().into());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl super::TunnelConnector for TcpTunnelConnector {
|
||||||
|
async fn connect(&mut self) -> Result<Box<dyn Tunnel>, super::TunnelError> {
|
||||||
|
if self.bind_addrs.is_empty() {
|
||||||
|
self.connect_with_default_bind().await
|
||||||
|
} else {
|
||||||
|
self.connect_with_custom_bind().await
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn remote_url(&self) -> url::Url {
|
||||||
|
self.addr.clone()
|
||||||
|
}
|
||||||
|
fn set_bind_addrs(&mut self, addrs: Vec<SocketAddr>) {
|
||||||
|
self.bind_addrs = addrs;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use crate::tunnel::{
|
||||||
|
common::tests::{_tunnel_bench, _tunnel_pingpong},
|
||||||
|
TunnelConnector,
|
||||||
|
};
|
||||||
|
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn tcp_pingpong() {
|
||||||
|
let listener = TcpTunnelListener::new("tcp://0.0.0.0:31011".parse().unwrap());
|
||||||
|
let connector = TcpTunnelConnector::new("tcp://127.0.0.1:31011".parse().unwrap());
|
||||||
|
_tunnel_pingpong(listener, connector).await
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn tcp_bench() {
|
||||||
|
let listener = TcpTunnelListener::new("tcp://0.0.0.0:31012".parse().unwrap());
|
||||||
|
let connector = TcpTunnelConnector::new("tcp://127.0.0.1:31012".parse().unwrap());
|
||||||
|
_tunnel_bench(listener, connector).await
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn tcp_bench_with_bind() {
|
||||||
|
let listener = TcpTunnelListener::new("tcp://127.0.0.1:11013".parse().unwrap());
|
||||||
|
let mut connector = TcpTunnelConnector::new("tcp://127.0.0.1:11013".parse().unwrap());
|
||||||
|
connector.set_bind_addrs(vec!["127.0.0.1:0".parse().unwrap()]);
|
||||||
|
_tunnel_pingpong(listener, connector).await
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
#[should_panic]
|
||||||
|
async fn tcp_bench_with_bind_fail() {
|
||||||
|
let listener = TcpTunnelListener::new("tcp://127.0.0.1:11014".parse().unwrap());
|
||||||
|
let mut connector = TcpTunnelConnector::new("tcp://127.0.0.1:11014".parse().unwrap());
|
||||||
|
connector.set_bind_addrs(vec!["10.0.0.1:0".parse().unwrap()]);
|
||||||
|
_tunnel_pingpong(listener, connector).await
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,838 @@
|
|||||||
|
use std::{fmt::Debug, sync::Arc};
|
||||||
|
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use bytes::BytesMut;
|
||||||
|
use dashmap::DashMap;
|
||||||
|
use futures::{stream::FuturesUnordered, StreamExt};
|
||||||
|
use rand::{Rng, SeedableRng};
|
||||||
|
|
||||||
|
use std::net::SocketAddr;
|
||||||
|
use tokio::{
|
||||||
|
net::UdpSocket,
|
||||||
|
sync::mpsc::{Receiver, Sender, UnboundedReceiver, UnboundedSender},
|
||||||
|
task::{JoinHandle, JoinSet},
|
||||||
|
};
|
||||||
|
|
||||||
|
use tracing::{instrument, Instrument};
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
common::join_joinset_background,
|
||||||
|
rpc::TunnelInfo,
|
||||||
|
tunnel::{
|
||||||
|
common::{reserve_buf, TunnelWrapper},
|
||||||
|
packet_def::{UdpPacketType, ZCPacket, ZCPacketType},
|
||||||
|
ring::RingTunnel,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
use super::{
|
||||||
|
common::{setup_sokcet2, setup_sokcet2_ext, wait_for_connect_futures},
|
||||||
|
packet_def::{UDPTunnelHeader, UDP_TUNNEL_HEADER_SIZE},
|
||||||
|
ring::{RingSink, RingStream},
|
||||||
|
Tunnel, TunnelConnCounter, TunnelError, TunnelListener, TunnelUrl,
|
||||||
|
};
|
||||||
|
|
||||||
|
pub const UDP_DATA_MTU: usize = 65000;
|
||||||
|
|
||||||
|
type UdpCloseEventSender = UnboundedSender<Option<TunnelError>>;
|
||||||
|
type UdpCloseEventReceiver = UnboundedReceiver<Option<TunnelError>>;
|
||||||
|
|
||||||
|
fn new_udp_packet<F>(f: F, udp_body: Option<&mut [u8]>) -> ZCPacket
|
||||||
|
where
|
||||||
|
F: FnOnce(&mut UDPTunnelHeader),
|
||||||
|
{
|
||||||
|
let mut buf = BytesMut::new();
|
||||||
|
buf.resize(
|
||||||
|
UDP_TUNNEL_HEADER_SIZE + udp_body.as_ref().map(|v| v.len()).unwrap_or(0),
|
||||||
|
0,
|
||||||
|
);
|
||||||
|
buf[UDP_TUNNEL_HEADER_SIZE..].copy_from_slice(udp_body.unwrap());
|
||||||
|
|
||||||
|
let mut ret = ZCPacket::new_from_buf(buf, ZCPacketType::UDP);
|
||||||
|
let header = ret.mut_udp_tunnel_header().unwrap();
|
||||||
|
f(header);
|
||||||
|
ret
|
||||||
|
}
|
||||||
|
|
||||||
|
fn new_syn_packet(conn_id: u32, magic: u64) -> ZCPacket {
|
||||||
|
new_udp_packet(
|
||||||
|
|header| {
|
||||||
|
header.msg_type = UdpPacketType::Syn as u8;
|
||||||
|
header.conn_id.set(conn_id);
|
||||||
|
header.len.set(8);
|
||||||
|
},
|
||||||
|
Some(&mut magic.to_le_bytes()),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn new_sack_packet(conn_id: u32, magic: u64) -> ZCPacket {
|
||||||
|
new_udp_packet(
|
||||||
|
|header| {
|
||||||
|
header.msg_type = UdpPacketType::Sack as u8;
|
||||||
|
header.conn_id.set(conn_id);
|
||||||
|
header.len.set(8);
|
||||||
|
},
|
||||||
|
Some(&mut magic.to_le_bytes()),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn new_hole_punch_packet() -> ZCPacket {
|
||||||
|
// generate a 128 bytes vec with random data
|
||||||
|
let mut rng = rand::rngs::StdRng::from_entropy();
|
||||||
|
let mut buf = vec![0u8; 128];
|
||||||
|
rng.fill(&mut buf[..]);
|
||||||
|
new_udp_packet(
|
||||||
|
|header| {
|
||||||
|
header.msg_type = UdpPacketType::HolePunch as u8;
|
||||||
|
header.conn_id.set(0);
|
||||||
|
header.len.set(0);
|
||||||
|
},
|
||||||
|
Some(&mut buf),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_zcpacket_from_buf(buf: BytesMut) -> Result<ZCPacket, TunnelError> {
|
||||||
|
let dg_size = buf.len();
|
||||||
|
if dg_size < UDP_TUNNEL_HEADER_SIZE {
|
||||||
|
return Err(TunnelError::InvalidPacket(format!(
|
||||||
|
"udp packet size too small: {:?}, packet: {:?}",
|
||||||
|
dg_size, buf
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
|
||||||
|
let zc_packet = ZCPacket::new_from_buf(buf, ZCPacketType::UDP);
|
||||||
|
let header = zc_packet.udp_tunnel_header().unwrap();
|
||||||
|
let payload_len = header.len.get() as usize;
|
||||||
|
if payload_len != dg_size - UDP_TUNNEL_HEADER_SIZE {
|
||||||
|
return Err(TunnelError::InvalidPacket(format!(
|
||||||
|
"udp packet payload len not match: header len: {:?}, real len: {:?}",
|
||||||
|
payload_len, dg_size
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(zc_packet)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[instrument]
|
||||||
|
async fn forward_from_ring_to_udp(
|
||||||
|
mut ring_recv: RingStream,
|
||||||
|
socket: &Arc<UdpSocket>,
|
||||||
|
addr: &SocketAddr,
|
||||||
|
conn_id: u32,
|
||||||
|
) -> Option<TunnelError> {
|
||||||
|
tracing::debug!("udp forward from ring to udp");
|
||||||
|
loop {
|
||||||
|
let Some(buf) = ring_recv.next().await else {
|
||||||
|
return None;
|
||||||
|
};
|
||||||
|
let mut packet = match buf {
|
||||||
|
Ok(v) => v,
|
||||||
|
Err(e) => {
|
||||||
|
return Some(e);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let udp_payload_len = packet.udp_payload().len();
|
||||||
|
let header = packet.mut_udp_tunnel_header().unwrap();
|
||||||
|
header.conn_id.set(conn_id);
|
||||||
|
header.len.set(udp_payload_len as u16);
|
||||||
|
header.msg_type = UdpPacketType::Data as u8;
|
||||||
|
|
||||||
|
let buf = packet.into_bytes(ZCPacketType::UDP);
|
||||||
|
tracing::trace!(?udp_payload_len, ?buf, "udp forward from ring to udp");
|
||||||
|
let ret = socket.send_to(&buf, &addr).await;
|
||||||
|
if ret.is_err() {
|
||||||
|
return Some(TunnelError::IOError(ret.unwrap_err()));
|
||||||
|
} else if ret.unwrap() == 0 {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct UdpConnection {
|
||||||
|
socket: Arc<UdpSocket>,
|
||||||
|
conn_id: u32,
|
||||||
|
dst_addr: SocketAddr,
|
||||||
|
|
||||||
|
ring_sender: RingSink,
|
||||||
|
forward_task: JoinHandle<()>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl UdpConnection {
|
||||||
|
pub fn new(
|
||||||
|
socket: Arc<UdpSocket>,
|
||||||
|
conn_id: u32,
|
||||||
|
dst_addr: SocketAddr,
|
||||||
|
ring_sender: RingSink,
|
||||||
|
ring_recv: RingStream,
|
||||||
|
close_event_sender: UdpCloseEventSender,
|
||||||
|
) -> Self {
|
||||||
|
let s = socket.clone();
|
||||||
|
let forward_task = tokio::spawn(async move {
|
||||||
|
let close_event_sender = close_event_sender;
|
||||||
|
let err = forward_from_ring_to_udp(ring_recv, &s, &dst_addr, conn_id).await;
|
||||||
|
if let Err(e) = close_event_sender.send(err) {
|
||||||
|
tracing::error!(?e, "udp send close event error");
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
Self {
|
||||||
|
socket,
|
||||||
|
conn_id,
|
||||||
|
dst_addr,
|
||||||
|
ring_sender,
|
||||||
|
forward_task,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Drop for UdpConnection {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
self.forward_task.abort();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
struct UdpTunnelListenerData {
|
||||||
|
local_url: url::Url,
|
||||||
|
socket: Option<Arc<UdpSocket>>,
|
||||||
|
sock_map: Arc<DashMap<SocketAddr, UdpConnection>>,
|
||||||
|
conn_send: Sender<Box<dyn Tunnel>>,
|
||||||
|
close_event_sender: UdpCloseEventSender,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl UdpTunnelListenerData {
|
||||||
|
pub fn new(
|
||||||
|
local_url: url::Url,
|
||||||
|
conn_send: Sender<Box<dyn Tunnel>>,
|
||||||
|
close_event_sender: UdpCloseEventSender,
|
||||||
|
) -> Self {
|
||||||
|
Self {
|
||||||
|
local_url,
|
||||||
|
socket: None,
|
||||||
|
sock_map: Arc::new(DashMap::new()),
|
||||||
|
conn_send,
|
||||||
|
close_event_sender,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn handle_new_connect(self: Self, remote_addr: SocketAddr, zc_packet: ZCPacket) {
|
||||||
|
let udp_payload = zc_packet.udp_payload();
|
||||||
|
if udp_payload.len() != 8 {
|
||||||
|
tracing::warn!(
|
||||||
|
"udp syn packet payload len not match: {:?}, packet: {:?}",
|
||||||
|
udp_payload.len(),
|
||||||
|
zc_packet,
|
||||||
|
);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
let magic = u64::from_le_bytes(udp_payload[..8].try_into().unwrap());
|
||||||
|
let conn_id = zc_packet.udp_tunnel_header().unwrap().conn_id.get();
|
||||||
|
|
||||||
|
tracing::info!(?conn_id, ?remote_addr, "udp connection accept handling",);
|
||||||
|
let socket = self.socket.as_ref().unwrap().clone();
|
||||||
|
|
||||||
|
let sack_buf = new_sack_packet(conn_id, magic).into_bytes(ZCPacketType::UDP);
|
||||||
|
if let Err(e) = socket.send_to(&sack_buf, remote_addr).await {
|
||||||
|
tracing::error!(?e, "udp send sack packet error");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
let ring_for_send_udp = Arc::new(RingTunnel::new(128));
|
||||||
|
let ring_for_recv_udp = Arc::new(RingTunnel::new(128));
|
||||||
|
tracing::debug!(
|
||||||
|
?ring_for_send_udp,
|
||||||
|
?ring_for_recv_udp,
|
||||||
|
"udp build tunnel for listener"
|
||||||
|
);
|
||||||
|
|
||||||
|
let internal_conn = UdpConnection::new(
|
||||||
|
socket.clone(),
|
||||||
|
conn_id,
|
||||||
|
remote_addr,
|
||||||
|
RingSink::new(ring_for_recv_udp.clone()),
|
||||||
|
RingStream::new(ring_for_send_udp.clone()),
|
||||||
|
self.close_event_sender.clone(),
|
||||||
|
);
|
||||||
|
self.sock_map.insert(remote_addr, internal_conn);
|
||||||
|
|
||||||
|
let conn = Box::new(TunnelWrapper::new(
|
||||||
|
Box::new(RingStream::new(ring_for_recv_udp)),
|
||||||
|
Box::new(RingSink::new(ring_for_send_udp)),
|
||||||
|
Some(TunnelInfo {
|
||||||
|
tunnel_type: "udp".to_owned(),
|
||||||
|
local_addr: self.local_url.clone().into(),
|
||||||
|
remote_addr: url::Url::parse(&format!("udp://{}", remote_addr))
|
||||||
|
.unwrap()
|
||||||
|
.into(),
|
||||||
|
}),
|
||||||
|
));
|
||||||
|
|
||||||
|
if let Err(e) = self.conn_send.send(conn).await {
|
||||||
|
tracing::warn!(?e, "udp send conn to accept channel error");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn try_forward_packet(
|
||||||
|
self: &Self,
|
||||||
|
remote_addr: &SocketAddr,
|
||||||
|
conn_id: u32,
|
||||||
|
p: ZCPacket,
|
||||||
|
) -> Result<(), TunnelError> {
|
||||||
|
let Some(conn) = self.sock_map.get(remote_addr) else {
|
||||||
|
return Err(TunnelError::InternalError(
|
||||||
|
"udp connection not found".to_owned(),
|
||||||
|
));
|
||||||
|
};
|
||||||
|
|
||||||
|
if conn.conn_id != conn_id {
|
||||||
|
return Err(TunnelError::ConnIdNotMatch(conn.conn_id, conn_id));
|
||||||
|
}
|
||||||
|
|
||||||
|
if !conn.ring_sender.has_empty_slot() {
|
||||||
|
return Err(TunnelError::BufferFull);
|
||||||
|
}
|
||||||
|
|
||||||
|
conn.ring_sender.push_no_check(p)?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn process_forward_packet(&self, zc_packet: ZCPacket, addr: &SocketAddr) {
|
||||||
|
let header = zc_packet.udp_tunnel_header().unwrap();
|
||||||
|
if header.msg_type == UdpPacketType::Syn as u8 {
|
||||||
|
tokio::spawn(Self::handle_new_connect(self.clone(), *addr, zc_packet));
|
||||||
|
} else {
|
||||||
|
if let Err(e) = self
|
||||||
|
.try_forward_packet(addr, header.conn_id.get(), zc_packet)
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
tracing::trace!(?e, "udp forward packet error");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn do_forward_task(self: Self) {
|
||||||
|
let socket = self.socket.as_ref().unwrap().clone();
|
||||||
|
let mut buf = BytesMut::new();
|
||||||
|
loop {
|
||||||
|
reserve_buf(&mut buf, UDP_DATA_MTU, UDP_DATA_MTU * 128);
|
||||||
|
let (dg_size, addr) = socket.recv_buf_from(&mut buf).await.unwrap();
|
||||||
|
tracing::trace!(
|
||||||
|
"udp recv packet: {:?}, buf: {:?}, size: {}",
|
||||||
|
addr,
|
||||||
|
buf,
|
||||||
|
dg_size
|
||||||
|
);
|
||||||
|
|
||||||
|
let zc_packet = match get_zcpacket_from_buf(buf.split()) {
|
||||||
|
Ok(v) => v,
|
||||||
|
Err(e) => {
|
||||||
|
tracing::warn!(?e, "udp get zc packet from buf error");
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
self.process_forward_packet(zc_packet, &addr).await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct UdpTunnelListener {
|
||||||
|
addr: url::Url,
|
||||||
|
socket: Option<Arc<UdpSocket>>,
|
||||||
|
|
||||||
|
conn_recv: Receiver<Box<dyn Tunnel>>,
|
||||||
|
data: UdpTunnelListenerData,
|
||||||
|
forward_tasks: Arc<std::sync::Mutex<JoinSet<()>>>,
|
||||||
|
close_event_recv: UdpCloseEventReceiver,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl UdpTunnelListener {
|
||||||
|
pub fn new(addr: url::Url) -> Self {
|
||||||
|
let (close_event_send, close_event_recv) = tokio::sync::mpsc::unbounded_channel();
|
||||||
|
let (conn_send, conn_recv) = tokio::sync::mpsc::channel(100);
|
||||||
|
Self {
|
||||||
|
addr: addr.clone(),
|
||||||
|
socket: None,
|
||||||
|
conn_recv,
|
||||||
|
data: UdpTunnelListenerData::new(addr, conn_send, close_event_send),
|
||||||
|
forward_tasks: Arc::new(std::sync::Mutex::new(JoinSet::new())),
|
||||||
|
close_event_recv,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_socket(&self) -> Option<Arc<UdpSocket>> {
|
||||||
|
self.socket.clone()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl TunnelListener for UdpTunnelListener {
|
||||||
|
async fn listen(&mut self) -> Result<(), super::TunnelError> {
|
||||||
|
let addr = super::check_scheme_and_get_socket_addr::<SocketAddr>(&self.addr, "udp")?;
|
||||||
|
|
||||||
|
let socket2_socket = socket2::Socket::new(
|
||||||
|
socket2::Domain::for_address(addr),
|
||||||
|
socket2::Type::DGRAM,
|
||||||
|
Some(socket2::Protocol::UDP),
|
||||||
|
)?;
|
||||||
|
|
||||||
|
let tunnel_url: TunnelUrl = self.addr.clone().into();
|
||||||
|
if let Some(bind_dev) = tunnel_url.bind_dev() {
|
||||||
|
setup_sokcet2_ext(&socket2_socket, &addr, Some(bind_dev))?;
|
||||||
|
} else {
|
||||||
|
setup_sokcet2(&socket2_socket, &addr)?;
|
||||||
|
}
|
||||||
|
|
||||||
|
self.socket = Some(Arc::new(UdpSocket::from_std(socket2_socket.into())?));
|
||||||
|
self.data.socket = self.socket.clone();
|
||||||
|
|
||||||
|
self.forward_tasks
|
||||||
|
.lock()
|
||||||
|
.unwrap()
|
||||||
|
.spawn(self.data.clone().do_forward_task());
|
||||||
|
|
||||||
|
join_joinset_background(self.forward_tasks.clone(), "UdpTunnelListener".to_owned());
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn accept(&mut self) -> Result<Box<dyn super::Tunnel>, super::TunnelError> {
|
||||||
|
log::info!("start udp accept: {:?}", self.addr);
|
||||||
|
while let Some(conn) = self.conn_recv.recv().await {
|
||||||
|
return Ok(conn);
|
||||||
|
}
|
||||||
|
return Err(super::TunnelError::InternalError(
|
||||||
|
"udp accept error".to_owned(),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
fn local_url(&self) -> url::Url {
|
||||||
|
self.addr.clone()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_conn_counter(&self) -> Arc<Box<dyn TunnelConnCounter>> {
|
||||||
|
struct UdpTunnelConnCounter {
|
||||||
|
sock_map: Arc<DashMap<SocketAddr, UdpConnection>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Debug for UdpTunnelConnCounter {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
f.debug_struct("UdpTunnelConnCounter")
|
||||||
|
.field("sock_map_len", &self.sock_map.len())
|
||||||
|
.finish()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TunnelConnCounter for UdpTunnelConnCounter {
|
||||||
|
fn get(&self) -> u32 {
|
||||||
|
self.sock_map.len() as u32
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Arc::new(Box::new(UdpTunnelConnCounter {
|
||||||
|
sock_map: self.data.sock_map.clone(),
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct UdpTunnelConnector {
|
||||||
|
addr: url::Url,
|
||||||
|
bind_addrs: Vec<SocketAddr>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl UdpTunnelConnector {
|
||||||
|
pub fn new(addr: url::Url) -> Self {
|
||||||
|
Self {
|
||||||
|
addr,
|
||||||
|
bind_addrs: vec![],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn wait_sack(
|
||||||
|
socket: &UdpSocket,
|
||||||
|
addr: SocketAddr,
|
||||||
|
conn_id: u32,
|
||||||
|
magic: u64,
|
||||||
|
) -> Result<SocketAddr, TunnelError> {
|
||||||
|
let mut buf = BytesMut::new();
|
||||||
|
buf.reserve(UDP_DATA_MTU);
|
||||||
|
|
||||||
|
let (usize, recv_addr) = tokio::time::timeout(
|
||||||
|
tokio::time::Duration::from_secs(3),
|
||||||
|
socket.recv_buf_from(&mut buf),
|
||||||
|
)
|
||||||
|
.await??;
|
||||||
|
let zc_packet = get_zcpacket_from_buf(buf.split())?;
|
||||||
|
if recv_addr != addr {
|
||||||
|
tracing::warn!(?recv_addr, ?addr, ?usize, "udp wait sack addr not match");
|
||||||
|
}
|
||||||
|
|
||||||
|
let header = zc_packet.udp_tunnel_header().unwrap();
|
||||||
|
|
||||||
|
if header.conn_id.get() != conn_id {
|
||||||
|
return Err(super::TunnelError::ConnIdNotMatch(
|
||||||
|
header.conn_id.get(),
|
||||||
|
conn_id,
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
if header.msg_type != UdpPacketType::Sack as u8 {
|
||||||
|
return Err(TunnelError::InvalidPacket("not sack packet".to_owned()));
|
||||||
|
}
|
||||||
|
|
||||||
|
let payload = zc_packet.udp_payload();
|
||||||
|
if payload.len() != 8 {
|
||||||
|
return Err(TunnelError::InvalidPacket(
|
||||||
|
"udp sack packet payload len not match".to_owned(),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
let sack_magic = u64::from_le_bytes(payload[..8].try_into().unwrap());
|
||||||
|
if sack_magic != magic {
|
||||||
|
return Err(TunnelError::InvalidPacket(
|
||||||
|
"udp sack magic not match".to_owned(),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(recv_addr)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn wait_sack_loop(
|
||||||
|
socket: &UdpSocket,
|
||||||
|
addr: SocketAddr,
|
||||||
|
conn_id: u32,
|
||||||
|
magic: u64,
|
||||||
|
) -> Result<SocketAddr, super::TunnelError> {
|
||||||
|
loop {
|
||||||
|
let ret = Self::wait_sack(socket, addr, conn_id, magic).await;
|
||||||
|
if ret.is_err() {
|
||||||
|
tracing::debug!(?ret, "udp wait sack error");
|
||||||
|
continue;
|
||||||
|
} else {
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn build_tunnel(
|
||||||
|
&self,
|
||||||
|
socket: UdpSocket,
|
||||||
|
dst_addr: SocketAddr,
|
||||||
|
conn_id: u32,
|
||||||
|
) -> Result<Box<dyn super::Tunnel>, super::TunnelError> {
|
||||||
|
let socket = Arc::new(socket);
|
||||||
|
let ring_for_send_udp = Arc::new(RingTunnel::new(128));
|
||||||
|
let ring_for_recv_udp = Arc::new(RingTunnel::new(128));
|
||||||
|
tracing::debug!(
|
||||||
|
?ring_for_send_udp,
|
||||||
|
?ring_for_recv_udp,
|
||||||
|
"udp build tunnel for connector"
|
||||||
|
);
|
||||||
|
|
||||||
|
let (close_event_send, mut close_event_recv) = tokio::sync::mpsc::unbounded_channel();
|
||||||
|
|
||||||
|
// forward from ring to udp
|
||||||
|
let socket_sender = socket.clone();
|
||||||
|
let ring_recv = RingStream::new(ring_for_send_udp.clone());
|
||||||
|
tokio::spawn(async move {
|
||||||
|
let err = forward_from_ring_to_udp(ring_recv, &socket_sender, &dst_addr, conn_id).await;
|
||||||
|
tracing::debug!(?err, "udp forward from ring to udp done");
|
||||||
|
close_event_send.send(err).unwrap();
|
||||||
|
});
|
||||||
|
|
||||||
|
let socket_recv = socket.clone();
|
||||||
|
let ring_sender = RingSink::new(ring_for_recv_udp.clone());
|
||||||
|
tokio::spawn(async move {
|
||||||
|
let mut buf = BytesMut::new();
|
||||||
|
loop {
|
||||||
|
reserve_buf(&mut buf, UDP_DATA_MTU, UDP_DATA_MTU * 128);
|
||||||
|
let ret;
|
||||||
|
tokio::select! {
|
||||||
|
_ = close_event_recv.recv() => {
|
||||||
|
tracing::debug!("connector udp close event");
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
recv_res = socket_recv.recv_buf_from(&mut buf) => ret = Some(recv_res.unwrap()),
|
||||||
|
}
|
||||||
|
let (dg_size, addr) = ret.unwrap();
|
||||||
|
tracing::trace!(
|
||||||
|
"connector udp recv packet: {:?}, buf: {:?}, size: {}",
|
||||||
|
addr,
|
||||||
|
buf,
|
||||||
|
dg_size
|
||||||
|
);
|
||||||
|
|
||||||
|
let zc_packet = match get_zcpacket_from_buf(buf.split()) {
|
||||||
|
Ok(v) => v,
|
||||||
|
Err(e) => {
|
||||||
|
tracing::warn!(?e, "connector udp get zc packet from buf error");
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
let header = zc_packet.udp_tunnel_header().unwrap();
|
||||||
|
if header.conn_id.get() != conn_id {
|
||||||
|
tracing::trace!(
|
||||||
|
"connector udp conn id not match: {:?}, {:?}",
|
||||||
|
header.conn_id.get(),
|
||||||
|
conn_id
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
if header.msg_type == UdpPacketType::Data as u8 {
|
||||||
|
if let Err(e) = ring_sender.push_no_check(zc_packet) {
|
||||||
|
tracing::trace!(?e, "udp forward packet error");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}.instrument(tracing::info_span!("udp connector forward from udp to ring", ?ring_for_recv_udp)));
|
||||||
|
|
||||||
|
Ok(Box::new(TunnelWrapper::new(
|
||||||
|
Box::new(RingStream::new(ring_for_recv_udp)),
|
||||||
|
Box::new(RingSink::new(ring_for_send_udp)),
|
||||||
|
Some(TunnelInfo {
|
||||||
|
tunnel_type: "udp".to_owned(),
|
||||||
|
local_addr: url::Url::parse(&format!("udp://{}", socket.local_addr()?))
|
||||||
|
.unwrap()
|
||||||
|
.into(),
|
||||||
|
remote_addr: self.addr.clone().into(),
|
||||||
|
}),
|
||||||
|
)))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn try_connect_with_socket(
|
||||||
|
&self,
|
||||||
|
socket: UdpSocket,
|
||||||
|
) -> Result<Box<dyn super::Tunnel>, super::TunnelError> {
|
||||||
|
let addr = super::check_scheme_and_get_socket_addr::<SocketAddr>(&self.addr, "udp")?;
|
||||||
|
log::warn!("udp connect: {:?}", self.addr);
|
||||||
|
|
||||||
|
#[cfg(target_os = "windows")]
|
||||||
|
crate::arch::windows::disable_connection_reset(&socket)?;
|
||||||
|
|
||||||
|
// send syn
|
||||||
|
let conn_id = rand::random();
|
||||||
|
let magic = rand::random();
|
||||||
|
let udp_packet = new_syn_packet(conn_id, magic).into_bytes(ZCPacketType::UDP);
|
||||||
|
let ret = socket.send_to(&udp_packet, &addr).await?;
|
||||||
|
tracing::warn!(?udp_packet, ?ret, "udp send syn");
|
||||||
|
|
||||||
|
// wait sack
|
||||||
|
let recv_addr = tokio::time::timeout(
|
||||||
|
tokio::time::Duration::from_secs(3),
|
||||||
|
Self::wait_sack_loop(&socket, addr, conn_id, magic),
|
||||||
|
)
|
||||||
|
.await??;
|
||||||
|
|
||||||
|
socket.connect(recv_addr).await?;
|
||||||
|
self.build_tunnel(socket, addr, conn_id).await
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn connect_with_default_bind(&mut self) -> Result<Box<dyn Tunnel>, super::TunnelError> {
|
||||||
|
let socket = UdpSocket::bind("0.0.0.0:0").await?;
|
||||||
|
return self.try_connect_with_socket(socket).await;
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn connect_with_custom_bind(&mut self) -> Result<Box<dyn Tunnel>, super::TunnelError> {
|
||||||
|
let futures = FuturesUnordered::new();
|
||||||
|
|
||||||
|
for bind_addr in self.bind_addrs.iter() {
|
||||||
|
let socket2_socket = socket2::Socket::new(
|
||||||
|
socket2::Domain::for_address(*bind_addr),
|
||||||
|
socket2::Type::DGRAM,
|
||||||
|
Some(socket2::Protocol::UDP),
|
||||||
|
)?;
|
||||||
|
setup_sokcet2(&socket2_socket, &bind_addr)?;
|
||||||
|
let socket = UdpSocket::from_std(socket2_socket.into())?;
|
||||||
|
futures.push(self.try_connect_with_socket(socket));
|
||||||
|
}
|
||||||
|
wait_for_connect_futures(futures).await
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl super::TunnelConnector for UdpTunnelConnector {
|
||||||
|
async fn connect(&mut self) -> Result<Box<dyn super::Tunnel>, super::TunnelError> {
|
||||||
|
if self.bind_addrs.is_empty() {
|
||||||
|
self.connect_with_default_bind().await
|
||||||
|
} else {
|
||||||
|
self.connect_with_custom_bind().await
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn remote_url(&self) -> url::Url {
|
||||||
|
self.addr.clone()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn set_bind_addrs(&mut self, addrs: Vec<SocketAddr>) {
|
||||||
|
self.bind_addrs = addrs;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use std::time::Duration;
|
||||||
|
|
||||||
|
use futures::SinkExt;
|
||||||
|
use tokio::time::timeout;
|
||||||
|
|
||||||
|
use super::*;
|
||||||
|
use crate::{
|
||||||
|
common::global_ctx::tests::get_mock_global_ctx,
|
||||||
|
tunnel::{
|
||||||
|
check_scheme_and_get_socket_addr,
|
||||||
|
common::{
|
||||||
|
get_interface_name_by_ip,
|
||||||
|
tests::{_tunnel_bench, _tunnel_echo_server, _tunnel_pingpong},
|
||||||
|
},
|
||||||
|
TunnelConnector,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn udp_pingpong() {
|
||||||
|
let listener = UdpTunnelListener::new("udp://0.0.0.0:5556".parse().unwrap());
|
||||||
|
let connector = UdpTunnelConnector::new("udp://127.0.0.1:5556".parse().unwrap());
|
||||||
|
_tunnel_pingpong(listener, connector).await
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn udp_bench() {
|
||||||
|
let listener = UdpTunnelListener::new("udp://0.0.0.0:5555".parse().unwrap());
|
||||||
|
let connector = UdpTunnelConnector::new("udp://127.0.0.1:5555".parse().unwrap());
|
||||||
|
_tunnel_bench(listener, connector).await
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn udp_bench_with_bind() {
|
||||||
|
let listener = UdpTunnelListener::new("udp://127.0.0.1:5554".parse().unwrap());
|
||||||
|
let mut connector = UdpTunnelConnector::new("udp://127.0.0.1:5554".parse().unwrap());
|
||||||
|
connector.set_bind_addrs(vec!["127.0.0.1:0".parse().unwrap()]);
|
||||||
|
_tunnel_pingpong(listener, connector).await
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
#[should_panic]
|
||||||
|
async fn udp_bench_with_bind_fail() {
|
||||||
|
let listener = UdpTunnelListener::new("udp://127.0.0.1:5553".parse().unwrap());
|
||||||
|
let mut connector = UdpTunnelConnector::new("udp://127.0.0.1:5553".parse().unwrap());
|
||||||
|
connector.set_bind_addrs(vec!["10.0.0.1:0".parse().unwrap()]);
|
||||||
|
_tunnel_pingpong(listener, connector).await
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn send_random_data_to_socket(remote_url: url::Url) {
|
||||||
|
let socket = UdpSocket::bind("0.0.0.0:0").await.unwrap();
|
||||||
|
socket
|
||||||
|
.connect(format!(
|
||||||
|
"{}:{}",
|
||||||
|
remote_url.host().unwrap(),
|
||||||
|
remote_url.port().unwrap()
|
||||||
|
))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
// get a random 100-len buf
|
||||||
|
loop {
|
||||||
|
let mut buf = vec![0u8; 100];
|
||||||
|
rand::thread_rng().fill(&mut buf[..]);
|
||||||
|
socket.send(&buf).await.unwrap();
|
||||||
|
tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn udp_multiple_conns() {
|
||||||
|
let mut listener = UdpTunnelListener::new("udp://0.0.0.0:5557".parse().unwrap());
|
||||||
|
listener.listen().await.unwrap();
|
||||||
|
|
||||||
|
let _lis = tokio::spawn(async move {
|
||||||
|
loop {
|
||||||
|
let ret = listener.accept().await.unwrap();
|
||||||
|
assert_eq!(
|
||||||
|
ret.info().unwrap().local_addr,
|
||||||
|
listener.local_url().to_string()
|
||||||
|
);
|
||||||
|
tokio::spawn(async move { _tunnel_echo_server(ret, false).await });
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
let mut connector1 = UdpTunnelConnector::new("udp://127.0.0.1:5557".parse().unwrap());
|
||||||
|
let mut connector2 = UdpTunnelConnector::new("udp://127.0.0.1:5557".parse().unwrap());
|
||||||
|
|
||||||
|
let t1 = connector1.connect().await.unwrap();
|
||||||
|
let t2 = connector2.connect().await.unwrap();
|
||||||
|
|
||||||
|
tokio::spawn(timeout(
|
||||||
|
Duration::from_secs(2),
|
||||||
|
send_random_data_to_socket(t1.info().unwrap().local_addr.parse().unwrap()),
|
||||||
|
));
|
||||||
|
tokio::spawn(timeout(
|
||||||
|
Duration::from_secs(2),
|
||||||
|
send_random_data_to_socket(t1.info().unwrap().remote_addr.parse().unwrap()),
|
||||||
|
));
|
||||||
|
tokio::spawn(timeout(
|
||||||
|
Duration::from_secs(2),
|
||||||
|
send_random_data_to_socket(t2.info().unwrap().remote_addr.parse().unwrap()),
|
||||||
|
));
|
||||||
|
|
||||||
|
let sender1 = tokio::spawn(async move {
|
||||||
|
let (mut stream, mut sink) = t1.split();
|
||||||
|
|
||||||
|
for i in 0..10 {
|
||||||
|
sink.send(ZCPacket::new_with_payload("hello1".as_bytes()))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
let recv = stream.next().await.unwrap().unwrap();
|
||||||
|
println!("t1 recv: {:?}, {:?}", recv, i);
|
||||||
|
assert_eq!(recv.payload(), "hello1".as_bytes());
|
||||||
|
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
let sender2 = tokio::spawn(async move {
|
||||||
|
let (mut stream, mut sink) = t2.split();
|
||||||
|
|
||||||
|
for i in 0..10 {
|
||||||
|
sink.send(ZCPacket::new_with_payload("hello2".as_bytes()))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
let recv = stream.next().await.unwrap().unwrap();
|
||||||
|
println!("t2 recv: {:?}, {:?}", recv, i);
|
||||||
|
assert_eq!(recv.payload(), "hello2".as_bytes());
|
||||||
|
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
let _ = tokio::join!(sender1, sender2);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn bind_multi_ip_to_same_dev() {
|
||||||
|
let global_ctx = get_mock_global_ctx();
|
||||||
|
let ips = global_ctx
|
||||||
|
.get_ip_collector()
|
||||||
|
.collect_ip_addrs()
|
||||||
|
.await
|
||||||
|
.interface_ipv4s;
|
||||||
|
if ips.is_empty() {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
let bind_dev = get_interface_name_by_ip(&ips[0].parse().unwrap());
|
||||||
|
|
||||||
|
for ip in ips {
|
||||||
|
println!("bind to ip: {:?}, {:?}", ip, bind_dev);
|
||||||
|
let addr = check_scheme_and_get_socket_addr::<SocketAddr>(
|
||||||
|
&format!("udp://{}:11111", ip).parse().unwrap(),
|
||||||
|
"udp",
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
let socket2_socket = socket2::Socket::new(
|
||||||
|
socket2::Domain::for_address(addr),
|
||||||
|
socket2::Type::DGRAM,
|
||||||
|
Some(socket2::Protocol::UDP),
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
setup_sokcet2_ext(&socket2_socket, &addr, bind_dev.clone()).unwrap();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,827 @@
|
|||||||
|
use std::{
|
||||||
|
collections::hash_map::DefaultHasher,
|
||||||
|
fmt::{Debug, Formatter},
|
||||||
|
hash::Hasher,
|
||||||
|
net::SocketAddr,
|
||||||
|
pin::Pin,
|
||||||
|
sync::{atomic::AtomicBool, Arc},
|
||||||
|
time::Duration,
|
||||||
|
};
|
||||||
|
|
||||||
|
use anyhow::Context;
|
||||||
|
use async_recursion::async_recursion;
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use boringtun::{
|
||||||
|
noise::{errors::WireGuardError, Tunn, TunnResult},
|
||||||
|
x25519::{PublicKey, StaticSecret},
|
||||||
|
};
|
||||||
|
use bytes::BytesMut;
|
||||||
|
use dashmap::DashMap;
|
||||||
|
use futures::{stream::FuturesUnordered, SinkExt, StreamExt};
|
||||||
|
use rand::RngCore;
|
||||||
|
use tokio::{net::UdpSocket, sync::Mutex, task::JoinSet};
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
rpc::TunnelInfo,
|
||||||
|
tunnel::{
|
||||||
|
build_url_from_socket_addr,
|
||||||
|
common::TunnelWrapper,
|
||||||
|
packet_def::{ZCPacket, WG_TUNNEL_HEADER_SIZE},
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
use super::{
|
||||||
|
check_scheme_and_get_socket_addr,
|
||||||
|
common::{setup_sokcet2, setup_sokcet2_ext, wait_for_connect_futures},
|
||||||
|
packet_def::{ZCPacketType, PEER_MANAGER_HEADER_SIZE},
|
||||||
|
ring::create_ring_tunnel_pair,
|
||||||
|
Tunnel, TunnelError, TunnelListener, TunnelUrl, ZCPacketSink, ZCPacketStream,
|
||||||
|
};
|
||||||
|
|
||||||
|
const MAX_PACKET: usize = 65500;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
enum WgType {
|
||||||
|
// used by easytier peer, need remove/add ip header for in/out wg msg
|
||||||
|
InternalUse,
|
||||||
|
// used by wireguard peer, keep original ip header
|
||||||
|
ExternalUse,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct WgConfig {
|
||||||
|
my_secret_key: StaticSecret,
|
||||||
|
my_public_key: PublicKey,
|
||||||
|
|
||||||
|
peer_secret_key: StaticSecret,
|
||||||
|
peer_public_key: PublicKey,
|
||||||
|
|
||||||
|
wg_type: WgType,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl WgConfig {
|
||||||
|
pub fn new_from_network_identity(network_name: &str, network_secret: &str) -> Self {
|
||||||
|
let mut my_sec = [0u8; 32];
|
||||||
|
let mut hasher = DefaultHasher::new();
|
||||||
|
hasher.write(network_name.as_bytes());
|
||||||
|
hasher.write(network_secret.as_bytes());
|
||||||
|
my_sec[0..8].copy_from_slice(&hasher.finish().to_be_bytes());
|
||||||
|
hasher.write(&my_sec[0..8]);
|
||||||
|
my_sec[8..16].copy_from_slice(&hasher.finish().to_be_bytes());
|
||||||
|
hasher.write(&my_sec[0..16]);
|
||||||
|
my_sec[16..24].copy_from_slice(&hasher.finish().to_be_bytes());
|
||||||
|
hasher.write(&my_sec[0..24]);
|
||||||
|
my_sec[24..32].copy_from_slice(&hasher.finish().to_be_bytes());
|
||||||
|
|
||||||
|
let my_secret_key = StaticSecret::from(my_sec);
|
||||||
|
let my_public_key = PublicKey::from(&my_secret_key);
|
||||||
|
let peer_secret_key = StaticSecret::from(my_sec);
|
||||||
|
let peer_public_key = my_public_key.clone();
|
||||||
|
|
||||||
|
WgConfig {
|
||||||
|
my_secret_key,
|
||||||
|
my_public_key,
|
||||||
|
peer_secret_key,
|
||||||
|
peer_public_key,
|
||||||
|
|
||||||
|
wg_type: WgType::InternalUse,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn new_for_portal(server_key_seed: &str, client_key_seed: &str) -> Self {
|
||||||
|
let server_cfg = Self::new_from_network_identity("server", server_key_seed);
|
||||||
|
let client_cfg = Self::new_from_network_identity("client", client_key_seed);
|
||||||
|
Self {
|
||||||
|
my_secret_key: server_cfg.my_secret_key,
|
||||||
|
my_public_key: server_cfg.my_public_key,
|
||||||
|
peer_secret_key: client_cfg.my_secret_key,
|
||||||
|
peer_public_key: client_cfg.my_public_key,
|
||||||
|
|
||||||
|
wg_type: WgType::ExternalUse,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn my_secret_key(&self) -> &[u8] {
|
||||||
|
self.my_secret_key.as_bytes()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn peer_secret_key(&self) -> &[u8] {
|
||||||
|
self.peer_secret_key.as_bytes()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn my_public_key(&self) -> &[u8] {
|
||||||
|
self.my_public_key.as_bytes()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn peer_public_key(&self) -> &[u8] {
|
||||||
|
self.peer_public_key.as_bytes()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
struct WgPeerData {
|
||||||
|
udp: Arc<UdpSocket>, // only for send
|
||||||
|
endpoint: SocketAddr,
|
||||||
|
tunn: Arc<Mutex<Tunn>>,
|
||||||
|
wg_type: WgType,
|
||||||
|
stopped: Arc<AtomicBool>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Debug for WgPeerData {
|
||||||
|
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
|
||||||
|
f.debug_struct("WgPeerData")
|
||||||
|
.field("endpoint", &self.endpoint)
|
||||||
|
.field("local", &self.udp.local_addr())
|
||||||
|
.finish()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl WgPeerData {
|
||||||
|
#[tracing::instrument]
|
||||||
|
async fn handle_one_packet_from_me(
|
||||||
|
&self,
|
||||||
|
mut zc_packet: ZCPacket,
|
||||||
|
) -> Result<(), anyhow::Error> {
|
||||||
|
let mut send_buf = vec![0u8; MAX_PACKET];
|
||||||
|
|
||||||
|
let packet = if matches!(self.wg_type, WgType::InternalUse) {
|
||||||
|
Self::fill_ip_header(&mut zc_packet);
|
||||||
|
zc_packet.into_bytes(ZCPacketType::WG)
|
||||||
|
} else {
|
||||||
|
zc_packet.into_bytes(ZCPacketType::WG)
|
||||||
|
};
|
||||||
|
tracing::trace!(?packet, "Sending packet to peer");
|
||||||
|
|
||||||
|
let encapsulate_result = {
|
||||||
|
let mut peer = self.tunn.lock().await;
|
||||||
|
peer.encapsulate(&packet, &mut send_buf)
|
||||||
|
};
|
||||||
|
|
||||||
|
tracing::trace!(
|
||||||
|
?encapsulate_result,
|
||||||
|
"Received {} bytes from me",
|
||||||
|
packet.len()
|
||||||
|
);
|
||||||
|
|
||||||
|
match encapsulate_result {
|
||||||
|
TunnResult::WriteToNetwork(packet) => {
|
||||||
|
self.udp
|
||||||
|
.send_to(packet, self.endpoint)
|
||||||
|
.await
|
||||||
|
.context("Failed to send encrypted IP packet to WireGuard endpoint.")?;
|
||||||
|
tracing::debug!(
|
||||||
|
"Sent {} bytes to WireGuard endpoint (encrypted IP packet)",
|
||||||
|
packet.len()
|
||||||
|
);
|
||||||
|
}
|
||||||
|
TunnResult::Err(e) => {
|
||||||
|
tracing::error!("Failed to encapsulate IP packet: {:?}", e);
|
||||||
|
}
|
||||||
|
TunnResult::Done => {
|
||||||
|
// Ignored
|
||||||
|
}
|
||||||
|
other => {
|
||||||
|
tracing::error!(
|
||||||
|
"Unexpected WireGuard state during encapsulation: {:?}",
|
||||||
|
other
|
||||||
|
);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// WireGuard consumption task. Receives encrypted packets from the WireGuard endpoint,
|
||||||
|
/// decapsulates them, and dispatches newly received IP packets.
|
||||||
|
#[tracing::instrument(skip(sink))]
|
||||||
|
pub async fn handle_one_packet_from_peer<S: ZCPacketSink + Unpin>(
|
||||||
|
&self,
|
||||||
|
mut sink: S,
|
||||||
|
recv_buf: &[u8],
|
||||||
|
) {
|
||||||
|
let mut send_buf = vec![0u8; MAX_PACKET];
|
||||||
|
let data = &recv_buf[..];
|
||||||
|
let decapsulate_result = {
|
||||||
|
let mut peer = self.tunn.lock().await;
|
||||||
|
peer.decapsulate(None, data, &mut send_buf)
|
||||||
|
};
|
||||||
|
|
||||||
|
tracing::debug!("Decapsulation result: {:?}", decapsulate_result);
|
||||||
|
|
||||||
|
match decapsulate_result {
|
||||||
|
TunnResult::WriteToNetwork(packet) => {
|
||||||
|
match self.udp.send_to(packet, self.endpoint).await {
|
||||||
|
Ok(_) => {}
|
||||||
|
Err(e) => {
|
||||||
|
tracing::error!("Failed to send decapsulation-instructed packet to WireGuard endpoint: {:?}", e);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
let mut peer = self.tunn.lock().await;
|
||||||
|
loop {
|
||||||
|
let mut send_buf = vec![0u8; MAX_PACKET];
|
||||||
|
match peer.decapsulate(None, &[], &mut send_buf) {
|
||||||
|
TunnResult::WriteToNetwork(packet) => {
|
||||||
|
match self.udp.send_to(packet, self.endpoint).await {
|
||||||
|
Ok(_) => {}
|
||||||
|
Err(e) => {
|
||||||
|
tracing::error!("Failed to send decapsulation-instructed packet to WireGuard endpoint: {:?}", e);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
TunnResult::WriteToTunnelV4(packet, _) | TunnResult::WriteToTunnelV6(packet, _) => {
|
||||||
|
tracing::debug!(
|
||||||
|
?packet,
|
||||||
|
"receive IP packet from peer: {} bytes",
|
||||||
|
packet.len()
|
||||||
|
);
|
||||||
|
let mut b = BytesMut::new();
|
||||||
|
if matches!(self.wg_type, WgType::InternalUse) {
|
||||||
|
b.resize(WG_TUNNEL_HEADER_SIZE, 0);
|
||||||
|
b.extend_from_slice(self.remove_ip_header(packet, packet[0] >> 4 == 4));
|
||||||
|
} else {
|
||||||
|
b.extend_from_slice(packet);
|
||||||
|
};
|
||||||
|
let zc_packet = ZCPacket::new_from_buf(b, ZCPacketType::WG);
|
||||||
|
tracing::trace!(?zc_packet, "forward zc_packet to sink");
|
||||||
|
let ret = sink.send(zc_packet).await;
|
||||||
|
if ret.is_err() {
|
||||||
|
tracing::error!("Failed to send packet to tunnel: {:?}", ret);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
tracing::warn!(
|
||||||
|
"Unexpected WireGuard state during decapsulation: {:?}",
|
||||||
|
decapsulate_result
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tracing::instrument]
|
||||||
|
#[async_recursion]
|
||||||
|
async fn handle_routine_tun_result<'a: 'async_recursion>(&self, result: TunnResult<'a>) -> () {
|
||||||
|
match result {
|
||||||
|
TunnResult::WriteToNetwork(packet) => {
|
||||||
|
tracing::debug!(
|
||||||
|
"Sending routine packet of {} bytes to WireGuard endpoint",
|
||||||
|
packet.len()
|
||||||
|
);
|
||||||
|
match self.udp.send_to(packet, self.endpoint).await {
|
||||||
|
Ok(_) => {}
|
||||||
|
Err(e) => {
|
||||||
|
tracing::error!(
|
||||||
|
"Failed to send routine packet to WireGuard endpoint: {:?}",
|
||||||
|
e
|
||||||
|
);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
TunnResult::Err(WireGuardError::ConnectionExpired) => {
|
||||||
|
tracing::warn!("Wireguard handshake has expired!");
|
||||||
|
|
||||||
|
let mut buf = vec![0u8; MAX_PACKET];
|
||||||
|
let result = self
|
||||||
|
.tunn
|
||||||
|
.lock()
|
||||||
|
.await
|
||||||
|
.format_handshake_initiation(&mut buf[..], false);
|
||||||
|
|
||||||
|
self.handle_routine_tun_result(result).await
|
||||||
|
}
|
||||||
|
TunnResult::Err(e) => {
|
||||||
|
tracing::error!(
|
||||||
|
"Failed to prepare routine packet for WireGuard endpoint: {:?}",
|
||||||
|
e
|
||||||
|
);
|
||||||
|
}
|
||||||
|
TunnResult::Done => {
|
||||||
|
// Sleep for a bit
|
||||||
|
tokio::time::sleep(Duration::from_millis(250)).await;
|
||||||
|
}
|
||||||
|
other => {
|
||||||
|
tracing::warn!("Unexpected WireGuard routine task state: {:?}", other);
|
||||||
|
tokio::time::sleep(Duration::from_millis(250)).await;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
/// WireGuard Routine task. Handles Handshake, keep-alive, etc.
|
||||||
|
pub async fn routine_task(self) {
|
||||||
|
loop {
|
||||||
|
let mut send_buf = vec![0u8; MAX_PACKET];
|
||||||
|
let tun_result = { self.tunn.lock().await.update_timers(&mut send_buf) };
|
||||||
|
self.handle_routine_tun_result(tun_result).await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn fill_ip_header(zc_packet: &mut ZCPacket) {
|
||||||
|
let len = zc_packet.payload_len() + PEER_MANAGER_HEADER_SIZE;
|
||||||
|
let ip_header = &mut zc_packet.mut_wg_tunnel_header().unwrap().ipv4_header;
|
||||||
|
ip_header[0] = 0x45;
|
||||||
|
ip_header[1] = 0;
|
||||||
|
ip_header[2..4].copy_from_slice(&((len + 20) as u16).to_be_bytes());
|
||||||
|
ip_header[4..6].copy_from_slice(&0u16.to_be_bytes());
|
||||||
|
ip_header[6..8].copy_from_slice(&0u16.to_be_bytes());
|
||||||
|
ip_header[8] = 64;
|
||||||
|
ip_header[9] = 0;
|
||||||
|
ip_header[10..12].copy_from_slice(&0u16.to_be_bytes());
|
||||||
|
ip_header[12..16].copy_from_slice(&0u32.to_be_bytes());
|
||||||
|
ip_header[16..20].copy_from_slice(&0u32.to_be_bytes());
|
||||||
|
}
|
||||||
|
|
||||||
|
fn remove_ip_header<'a>(&self, packet: &'a [u8], is_v4: bool) -> &'a [u8] {
|
||||||
|
if is_v4 {
|
||||||
|
return &packet[20..];
|
||||||
|
} else {
|
||||||
|
return &packet[40..];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct WgPeer {
|
||||||
|
udp: Arc<UdpSocket>, // only for send
|
||||||
|
config: WgConfig,
|
||||||
|
endpoint: SocketAddr,
|
||||||
|
|
||||||
|
sink: std::sync::Mutex<Option<Pin<Box<dyn ZCPacketSink>>>>,
|
||||||
|
|
||||||
|
data: Option<WgPeerData>,
|
||||||
|
tasks: JoinSet<()>,
|
||||||
|
|
||||||
|
access_time: std::time::Instant,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl WgPeer {
|
||||||
|
fn new(udp: Arc<UdpSocket>, config: WgConfig, endpoint: SocketAddr) -> Self {
|
||||||
|
WgPeer {
|
||||||
|
udp,
|
||||||
|
config,
|
||||||
|
endpoint,
|
||||||
|
|
||||||
|
sink: std::sync::Mutex::new(None),
|
||||||
|
|
||||||
|
data: None,
|
||||||
|
tasks: JoinSet::new(),
|
||||||
|
|
||||||
|
access_time: std::time::Instant::now(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn handle_packet_from_me<S: ZCPacketStream + Unpin>(mut stream: S, data: WgPeerData) {
|
||||||
|
while let Some(Ok(packet)) = stream.next().await {
|
||||||
|
let ret = data.handle_one_packet_from_me(packet).await;
|
||||||
|
if let Err(e) = ret {
|
||||||
|
tracing::error!("Failed to handle packet from me: {}", e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
data.stopped
|
||||||
|
.store(true, std::sync::atomic::Ordering::Relaxed);
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn handle_packet_from_peer(&mut self, packet: &[u8]) {
|
||||||
|
self.access_time = std::time::Instant::now();
|
||||||
|
tracing::trace!("Received {} bytes from peer", packet.len());
|
||||||
|
let data = self.data.as_ref().unwrap();
|
||||||
|
// TODO: improve this
|
||||||
|
let mut sink = self.sink.lock().unwrap().take().unwrap();
|
||||||
|
data.handle_one_packet_from_peer(&mut sink, packet).await;
|
||||||
|
self.sink.lock().unwrap().replace(sink);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn start_and_get_tunnel(&mut self) -> Box<dyn Tunnel> {
|
||||||
|
let (stunnel, ctunnel) = create_ring_tunnel_pair();
|
||||||
|
|
||||||
|
let (stream, sink) = stunnel.split();
|
||||||
|
|
||||||
|
let data = WgPeerData {
|
||||||
|
udp: self.udp.clone(),
|
||||||
|
endpoint: self.endpoint,
|
||||||
|
tunn: Arc::new(Mutex::new(
|
||||||
|
Tunn::new(
|
||||||
|
self.config.my_secret_key.clone(),
|
||||||
|
self.config.peer_public_key.clone(),
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
rand::thread_rng().next_u32(),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
.unwrap(),
|
||||||
|
)),
|
||||||
|
wg_type: self.config.wg_type.clone(),
|
||||||
|
stopped: Arc::new(AtomicBool::new(false)),
|
||||||
|
};
|
||||||
|
|
||||||
|
self.data = Some(data.clone());
|
||||||
|
self.sink.lock().unwrap().replace(sink);
|
||||||
|
|
||||||
|
self.tasks
|
||||||
|
.spawn(Self::handle_packet_from_me(stream, data.clone()));
|
||||||
|
self.tasks.spawn(data.routine_task());
|
||||||
|
|
||||||
|
ctunnel
|
||||||
|
}
|
||||||
|
|
||||||
|
fn stopped(&self) -> bool {
|
||||||
|
self.data
|
||||||
|
.as_ref()
|
||||||
|
.unwrap()
|
||||||
|
.stopped
|
||||||
|
.load(std::sync::atomic::Ordering::Relaxed)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type ConnSender = tokio::sync::mpsc::UnboundedSender<Box<dyn Tunnel>>;
|
||||||
|
type ConnReceiver = tokio::sync::mpsc::UnboundedReceiver<Box<dyn Tunnel>>;
|
||||||
|
|
||||||
|
pub struct WgTunnelListener {
|
||||||
|
addr: url::Url,
|
||||||
|
config: WgConfig,
|
||||||
|
|
||||||
|
udp: Option<Arc<UdpSocket>>,
|
||||||
|
conn_recv: ConnReceiver,
|
||||||
|
conn_send: Option<ConnSender>,
|
||||||
|
|
||||||
|
wg_peer_map: Arc<DashMap<SocketAddr, WgPeer>>,
|
||||||
|
|
||||||
|
tasks: JoinSet<()>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl WgTunnelListener {
|
||||||
|
pub fn new(addr: url::Url, config: WgConfig) -> Self {
|
||||||
|
let (conn_send, conn_recv) = tokio::sync::mpsc::unbounded_channel();
|
||||||
|
WgTunnelListener {
|
||||||
|
addr,
|
||||||
|
config,
|
||||||
|
|
||||||
|
udp: None,
|
||||||
|
conn_recv,
|
||||||
|
conn_send: Some(conn_send),
|
||||||
|
|
||||||
|
wg_peer_map: Arc::new(DashMap::new()),
|
||||||
|
|
||||||
|
tasks: JoinSet::new(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_udp_socket(&self) -> Arc<UdpSocket> {
|
||||||
|
self.udp.as_ref().unwrap().clone()
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn handle_udp_incoming(
|
||||||
|
socket: Arc<UdpSocket>,
|
||||||
|
config: WgConfig,
|
||||||
|
conn_sender: ConnSender,
|
||||||
|
peer_map: Arc<DashMap<SocketAddr, WgPeer>>,
|
||||||
|
) {
|
||||||
|
let mut tasks = JoinSet::new();
|
||||||
|
|
||||||
|
let peer_map_clone = peer_map.clone();
|
||||||
|
tasks.spawn(async move {
|
||||||
|
loop {
|
||||||
|
peer_map_clone
|
||||||
|
.retain(|_, peer| peer.access_time.elapsed().as_secs() < 61 && !peer.stopped());
|
||||||
|
tokio::time::sleep(Duration::from_secs(1)).await;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
let mut buf = vec![0u8; MAX_PACKET];
|
||||||
|
loop {
|
||||||
|
let Ok((n, addr)) = socket.recv_from(&mut buf).await else {
|
||||||
|
tracing::error!("Failed to receive from UDP socket");
|
||||||
|
break;
|
||||||
|
};
|
||||||
|
|
||||||
|
let data = &buf[..n];
|
||||||
|
tracing::trace!(?n, ?addr, "Received bytes from peer");
|
||||||
|
|
||||||
|
if !peer_map.contains_key(&addr) {
|
||||||
|
tracing::info!("New peer: {}", addr);
|
||||||
|
let mut wg = WgPeer::new(socket.clone(), config.clone(), addr.clone());
|
||||||
|
let (stream, sink) = wg.start_and_get_tunnel().split();
|
||||||
|
let tunnel = Box::new(TunnelWrapper::new(
|
||||||
|
stream,
|
||||||
|
sink,
|
||||||
|
Some(TunnelInfo {
|
||||||
|
tunnel_type: "wg".to_owned(),
|
||||||
|
local_addr: build_url_from_socket_addr(
|
||||||
|
&socket.local_addr().unwrap().to_string(),
|
||||||
|
"wg",
|
||||||
|
)
|
||||||
|
.into(),
|
||||||
|
remote_addr: build_url_from_socket_addr(&addr.to_string(), "wg").into(),
|
||||||
|
}),
|
||||||
|
));
|
||||||
|
if let Err(e) = conn_sender.send(tunnel) {
|
||||||
|
tracing::error!("Failed to send tunnel to conn_sender: {}", e);
|
||||||
|
}
|
||||||
|
peer_map.insert(addr, wg);
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut peer = peer_map.get_mut(&addr).unwrap();
|
||||||
|
peer.handle_packet_from_peer(data).await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl TunnelListener for WgTunnelListener {
|
||||||
|
async fn listen(&mut self) -> Result<(), super::TunnelError> {
|
||||||
|
let addr = check_scheme_and_get_socket_addr::<SocketAddr>(&self.addr, "wg")?;
|
||||||
|
let socket2_socket = socket2::Socket::new(
|
||||||
|
socket2::Domain::for_address(addr),
|
||||||
|
socket2::Type::DGRAM,
|
||||||
|
Some(socket2::Protocol::UDP),
|
||||||
|
)?;
|
||||||
|
|
||||||
|
let tunnel_url: TunnelUrl = self.addr.clone().into();
|
||||||
|
if let Some(bind_dev) = tunnel_url.bind_dev() {
|
||||||
|
setup_sokcet2_ext(&socket2_socket, &addr, Some(bind_dev))?;
|
||||||
|
} else {
|
||||||
|
setup_sokcet2(&socket2_socket, &addr)?;
|
||||||
|
}
|
||||||
|
|
||||||
|
self.udp = Some(Arc::new(UdpSocket::from_std(socket2_socket.into())?));
|
||||||
|
self.tasks.spawn(Self::handle_udp_incoming(
|
||||||
|
self.get_udp_socket(),
|
||||||
|
self.config.clone(),
|
||||||
|
self.conn_send.take().unwrap(),
|
||||||
|
self.wg_peer_map.clone(),
|
||||||
|
));
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn accept(&mut self) -> Result<Box<dyn Tunnel>, super::TunnelError> {
|
||||||
|
while let Some(tunnel) = self.conn_recv.recv().await {
|
||||||
|
tracing::info!(?tunnel, "Accepted tunnel");
|
||||||
|
return Ok(tunnel);
|
||||||
|
}
|
||||||
|
Err(TunnelError::Shutdown)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn local_url(&self) -> url::Url {
|
||||||
|
self.addr.clone()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct WgTunnelConnector {
|
||||||
|
addr: url::Url,
|
||||||
|
config: WgConfig,
|
||||||
|
udp: Option<Arc<UdpSocket>>,
|
||||||
|
|
||||||
|
bind_addrs: Vec<SocketAddr>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Debug for WgTunnelConnector {
|
||||||
|
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
|
||||||
|
f.debug_struct("WgTunnelConnector")
|
||||||
|
.field("addr", &self.addr)
|
||||||
|
.field("udp", &self.udp)
|
||||||
|
.finish()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl WgTunnelConnector {
|
||||||
|
pub fn new(addr: url::Url, config: WgConfig) -> Self {
|
||||||
|
WgTunnelConnector {
|
||||||
|
addr,
|
||||||
|
config,
|
||||||
|
udp: None,
|
||||||
|
bind_addrs: vec![],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn create_handshake_init(tun: &mut Tunn) -> Vec<u8> {
|
||||||
|
let mut dst = vec![0u8; 2048];
|
||||||
|
let handshake_init = tun.format_handshake_initiation(&mut dst, false);
|
||||||
|
assert!(matches!(handshake_init, TunnResult::WriteToNetwork(_)));
|
||||||
|
let handshake_init = if let TunnResult::WriteToNetwork(sent) = handshake_init {
|
||||||
|
sent
|
||||||
|
} else {
|
||||||
|
unreachable!();
|
||||||
|
};
|
||||||
|
|
||||||
|
handshake_init.into()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parse_handshake_resp(tun: &mut Tunn, handshake_resp: &[u8]) -> Vec<u8> {
|
||||||
|
let mut dst = vec![0u8; 2048];
|
||||||
|
let keepalive = tun.decapsulate(None, handshake_resp, &mut dst);
|
||||||
|
assert!(
|
||||||
|
matches!(keepalive, TunnResult::WriteToNetwork(_)),
|
||||||
|
"Failed to parse handshake response, {:?}",
|
||||||
|
keepalive
|
||||||
|
);
|
||||||
|
|
||||||
|
let keepalive = if let TunnResult::WriteToNetwork(sent) = keepalive {
|
||||||
|
sent
|
||||||
|
} else {
|
||||||
|
unreachable!();
|
||||||
|
};
|
||||||
|
|
||||||
|
keepalive.into()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tracing::instrument(skip(config))]
|
||||||
|
async fn connect_with_socket(
|
||||||
|
addr_url: url::Url,
|
||||||
|
config: WgConfig,
|
||||||
|
udp: UdpSocket,
|
||||||
|
) -> Result<Box<dyn super::Tunnel>, super::TunnelError> {
|
||||||
|
let addr = super::check_scheme_and_get_socket_addr::<SocketAddr>(&addr_url, "wg")?;
|
||||||
|
tracing::warn!("wg connect: {:?}", addr);
|
||||||
|
let local_addr = udp.local_addr().unwrap().to_string();
|
||||||
|
|
||||||
|
let mut wg_peer = WgPeer::new(Arc::new(udp), config.clone(), addr);
|
||||||
|
let tunnel = wg_peer.start_and_get_tunnel();
|
||||||
|
|
||||||
|
let data = wg_peer.data.as_ref().unwrap().clone();
|
||||||
|
let mut sink = wg_peer.sink.lock().unwrap().take().unwrap();
|
||||||
|
wg_peer.tasks.spawn(async move {
|
||||||
|
loop {
|
||||||
|
let mut buf = vec![0u8; MAX_PACKET];
|
||||||
|
let (n, recv_addr) = data.udp.recv_from(&mut buf).await.unwrap();
|
||||||
|
if recv_addr != addr {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
data.handle_one_packet_from_peer(&mut sink, &buf[..n]).await;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
let (stream, sink) = tunnel.split();
|
||||||
|
let ret = Box::new(TunnelWrapper::new_with_associate_data(
|
||||||
|
stream,
|
||||||
|
sink,
|
||||||
|
Some(TunnelInfo {
|
||||||
|
tunnel_type: "wg".to_owned(),
|
||||||
|
local_addr: super::build_url_from_socket_addr(&local_addr, "wg").into(),
|
||||||
|
remote_addr: addr_url.to_string(),
|
||||||
|
}),
|
||||||
|
Some(Box::new(wg_peer)),
|
||||||
|
));
|
||||||
|
|
||||||
|
Ok(ret)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl super::TunnelConnector for WgTunnelConnector {
|
||||||
|
#[tracing::instrument]
|
||||||
|
async fn connect(&mut self) -> Result<Box<dyn super::Tunnel>, super::TunnelError> {
|
||||||
|
let bind_addrs = if self.bind_addrs.is_empty() {
|
||||||
|
vec!["0.0.0.0:0".parse().unwrap()]
|
||||||
|
} else {
|
||||||
|
self.bind_addrs.clone()
|
||||||
|
};
|
||||||
|
let futures = FuturesUnordered::new();
|
||||||
|
|
||||||
|
for bind_addr in bind_addrs.into_iter() {
|
||||||
|
let socket2_socket = socket2::Socket::new(
|
||||||
|
socket2::Domain::for_address(bind_addr),
|
||||||
|
socket2::Type::DGRAM,
|
||||||
|
Some(socket2::Protocol::UDP),
|
||||||
|
)?;
|
||||||
|
setup_sokcet2(&socket2_socket, &bind_addr)?;
|
||||||
|
let socket = UdpSocket::from_std(socket2_socket.into())?;
|
||||||
|
tracing::info!(?bind_addr, ?self.addr, "prepare wg connect task");
|
||||||
|
futures.push(Self::connect_with_socket(
|
||||||
|
self.addr.clone(),
|
||||||
|
self.config.clone(),
|
||||||
|
socket,
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
wait_for_connect_futures(futures).await
|
||||||
|
}
|
||||||
|
|
||||||
|
fn remote_url(&self) -> url::Url {
|
||||||
|
self.addr.clone()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn set_bind_addrs(&mut self, addrs: Vec<SocketAddr>) {
|
||||||
|
self.bind_addrs = addrs;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
pub mod tests {
|
||||||
|
use super::*;
|
||||||
|
use crate::tunnel::{
|
||||||
|
common::tests::{_tunnel_bench, _tunnel_pingpong},
|
||||||
|
TunnelConnector,
|
||||||
|
};
|
||||||
|
use boringtun::*;
|
||||||
|
|
||||||
|
pub fn create_wg_config() -> (WgConfig, WgConfig) {
|
||||||
|
let my_secret_key = x25519::StaticSecret::random_from_rng(rand::thread_rng());
|
||||||
|
let my_public_key = x25519::PublicKey::from(&my_secret_key);
|
||||||
|
|
||||||
|
let their_secret_key = x25519::StaticSecret::random_from_rng(rand::thread_rng());
|
||||||
|
let their_public_key = x25519::PublicKey::from(&their_secret_key);
|
||||||
|
|
||||||
|
let server_cfg = WgConfig {
|
||||||
|
my_secret_key: my_secret_key.clone(),
|
||||||
|
my_public_key,
|
||||||
|
peer_secret_key: their_secret_key.clone(),
|
||||||
|
peer_public_key: their_public_key.clone(),
|
||||||
|
wg_type: WgType::InternalUse,
|
||||||
|
};
|
||||||
|
|
||||||
|
let client_cfg = WgConfig {
|
||||||
|
my_secret_key: their_secret_key,
|
||||||
|
my_public_key: their_public_key,
|
||||||
|
peer_secret_key: my_secret_key,
|
||||||
|
peer_public_key: my_public_key,
|
||||||
|
wg_type: WgType::InternalUse,
|
||||||
|
};
|
||||||
|
|
||||||
|
(server_cfg, client_cfg)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn wg_pingpong() {
|
||||||
|
let (server_cfg, client_cfg) = create_wg_config();
|
||||||
|
let listener = WgTunnelListener::new("wg://0.0.0.0:5599".parse().unwrap(), server_cfg);
|
||||||
|
let connector = WgTunnelConnector::new("wg://127.0.0.1:5599".parse().unwrap(), client_cfg);
|
||||||
|
_tunnel_pingpong(listener, connector).await
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn wg_bench() {
|
||||||
|
let (server_cfg, client_cfg) = create_wg_config();
|
||||||
|
let listener = WgTunnelListener::new("wg://0.0.0.0:5598".parse().unwrap(), server_cfg);
|
||||||
|
let connector = WgTunnelConnector::new("wg://127.0.0.1:5598".parse().unwrap(), client_cfg);
|
||||||
|
_tunnel_bench(listener, connector).await
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn wg_bench_with_bind() {
|
||||||
|
let (server_cfg, client_cfg) = create_wg_config();
|
||||||
|
let listener = WgTunnelListener::new("wg://127.0.0.1:5597".parse().unwrap(), server_cfg);
|
||||||
|
let mut connector =
|
||||||
|
WgTunnelConnector::new("wg://127.0.0.1:5597".parse().unwrap(), client_cfg);
|
||||||
|
connector.set_bind_addrs(vec!["127.0.0.1:0".parse().unwrap()]);
|
||||||
|
_tunnel_pingpong(listener, connector).await
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
#[should_panic]
|
||||||
|
async fn wg_bench_with_bind_fail() {
|
||||||
|
let (server_cfg, client_cfg) = create_wg_config();
|
||||||
|
let listener = WgTunnelListener::new("wg://127.0.0.1:5596".parse().unwrap(), server_cfg);
|
||||||
|
let mut connector =
|
||||||
|
WgTunnelConnector::new("wg://127.0.0.1:5596".parse().unwrap(), client_cfg);
|
||||||
|
connector.set_bind_addrs(vec!["10.0.0.1:0".parse().unwrap()]);
|
||||||
|
_tunnel_pingpong(listener, connector).await
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn wg_server_erase_from_map_after_close() {
|
||||||
|
let (server_cfg, client_cfg) = create_wg_config();
|
||||||
|
let mut listener =
|
||||||
|
WgTunnelListener::new("wg://127.0.0.1:5595".parse().unwrap(), server_cfg);
|
||||||
|
listener.listen().await.unwrap();
|
||||||
|
|
||||||
|
const CONN_COUNT: usize = 10;
|
||||||
|
|
||||||
|
tokio::spawn(async move {
|
||||||
|
let mut tunnels = vec![];
|
||||||
|
for _ in 0..CONN_COUNT {
|
||||||
|
let mut connector = WgTunnelConnector::new(
|
||||||
|
"wg://127.0.0.1:5595".parse().unwrap(),
|
||||||
|
client_cfg.clone(),
|
||||||
|
);
|
||||||
|
let ret = connector.connect().await;
|
||||||
|
assert!(ret.is_ok());
|
||||||
|
let t = ret.unwrap();
|
||||||
|
let (_stream, mut sink) = t.split();
|
||||||
|
sink.send(ZCPacket::new_with_payload("payload".as_bytes()))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
tunnels.push(t);
|
||||||
|
}
|
||||||
|
tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
|
||||||
|
});
|
||||||
|
|
||||||
|
for _ in 0..CONN_COUNT {
|
||||||
|
println!("accepting");
|
||||||
|
let conn = listener.accept().await;
|
||||||
|
let (mut stream, _sink) = conn.unwrap().split();
|
||||||
|
let packet = stream.next().await.unwrap().unwrap();
|
||||||
|
assert_eq!("payload".as_bytes(), packet.payload());
|
||||||
|
println!("accepting drop");
|
||||||
|
}
|
||||||
|
|
||||||
|
tokio::time::sleep(tokio::time::Duration::from_secs(2)).await;
|
||||||
|
|
||||||
|
assert_eq!(0, listener.wg_peer_map.len());
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,11 +1,11 @@
|
|||||||
pub mod codec;
|
pub mod codec;
|
||||||
pub mod common;
|
pub mod common;
|
||||||
pub mod ring_tunnel;
|
// pub mod ring_tunnel;
|
||||||
pub mod stats;
|
// pub mod stats;
|
||||||
pub mod tcp_tunnel;
|
// pub mod tcp_tunnel;
|
||||||
pub mod tunnel_filter;
|
// pub mod tunnel_filter;
|
||||||
pub mod udp_tunnel;
|
// pub mod udp_tunnel;
|
||||||
pub mod wireguard;
|
// pub mod wireguard;
|
||||||
|
|
||||||
use std::{fmt::Debug, net::SocketAddr, pin::Pin, sync::Arc};
|
use std::{fmt::Debug, net::SocketAddr, pin::Pin, sync::Arc};
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
use std::{
|
use std::{
|
||||||
net::{Ipv4Addr, SocketAddr},
|
net::{Ipv4Addr, SocketAddr},
|
||||||
pin::Pin,
|
|
||||||
sync::Arc,
|
sync::Arc,
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -8,24 +7,22 @@ use anyhow::Context;
|
|||||||
use base64::{prelude::BASE64_STANDARD, Engine};
|
use base64::{prelude::BASE64_STANDARD, Engine};
|
||||||
use cidr::Ipv4Inet;
|
use cidr::Ipv4Inet;
|
||||||
use dashmap::DashMap;
|
use dashmap::DashMap;
|
||||||
use futures::{SinkExt, StreamExt};
|
use futures::StreamExt;
|
||||||
use pnet::packet::ipv4::Ipv4Packet;
|
use pnet::packet::ipv4::Ipv4Packet;
|
||||||
use tokio::{sync::Mutex, task::JoinSet};
|
use tokio::task::JoinSet;
|
||||||
use tokio_util::bytes::Bytes;
|
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
common::{
|
common::{
|
||||||
|
config::NetworkIdentity,
|
||||||
global_ctx::{ArcGlobalCtx, GlobalCtxEvent},
|
global_ctx::{ArcGlobalCtx, GlobalCtxEvent},
|
||||||
join_joinset_background,
|
join_joinset_background,
|
||||||
},
|
},
|
||||||
peers::{
|
peers::{peer_manager::PeerManager, PeerPacketFilter},
|
||||||
packet::{self, ArchivedPacket},
|
tunnel::{
|
||||||
peer_manager::PeerManager,
|
mpsc::{MpscTunnel, MpscTunnelSender},
|
||||||
PeerPacketFilter,
|
packet_def::{PacketType, ZCPacket, ZCPacketType},
|
||||||
},
|
|
||||||
tunnels::{
|
|
||||||
wireguard::{WgConfig, WgTunnelListener},
|
wireguard::{WgConfig, WgTunnelListener},
|
||||||
DatagramSink, Tunnel, TunnelListener,
|
Tunnel, TunnelListener,
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -33,9 +30,14 @@ use super::VpnPortal;
|
|||||||
|
|
||||||
type WgPeerIpTable = Arc<DashMap<Ipv4Addr, Arc<ClientEntry>>>;
|
type WgPeerIpTable = Arc<DashMap<Ipv4Addr, Arc<ClientEntry>>>;
|
||||||
|
|
||||||
|
pub(crate) fn get_wg_config_for_portal(nid: &NetworkIdentity) -> WgConfig {
|
||||||
|
let key_seed = format!("{}{}", nid.network_name, nid.network_secret);
|
||||||
|
WgConfig::new_for_portal(&key_seed, &key_seed)
|
||||||
|
}
|
||||||
|
|
||||||
struct ClientEntry {
|
struct ClientEntry {
|
||||||
endpoint_addr: Option<url::Url>,
|
endpoint_addr: Option<url::Url>,
|
||||||
sink: Mutex<Pin<Box<dyn DatagramSink + 'static>>>,
|
sink: MpscTunnelSender,
|
||||||
}
|
}
|
||||||
|
|
||||||
struct WireGuardImpl {
|
struct WireGuardImpl {
|
||||||
@@ -52,8 +54,7 @@ struct WireGuardImpl {
|
|||||||
impl WireGuardImpl {
|
impl WireGuardImpl {
|
||||||
fn new(global_ctx: ArcGlobalCtx, peer_mgr: Arc<PeerManager>) -> Self {
|
fn new(global_ctx: ArcGlobalCtx, peer_mgr: Arc<PeerManager>) -> Self {
|
||||||
let nid = global_ctx.get_network_identity();
|
let nid = global_ctx.get_network_identity();
|
||||||
let key_seed = format!("{}{}", nid.network_name, nid.network_secret);
|
let wg_config = get_wg_config_for_portal(&nid);
|
||||||
let wg_config = WgConfig::new_for_portal(&key_seed, &key_seed);
|
|
||||||
|
|
||||||
let vpn_cfg = global_ctx.config.get_vpn_portal_config().unwrap();
|
let vpn_cfg = global_ctx.config.get_vpn_portal_config().unwrap();
|
||||||
let listenr_addr = vpn_cfg.wireguard_listen;
|
let listenr_addr = vpn_cfg.wireguard_listen;
|
||||||
@@ -73,38 +74,41 @@ impl WireGuardImpl {
|
|||||||
peer_mgr: Arc<PeerManager>,
|
peer_mgr: Arc<PeerManager>,
|
||||||
wg_peer_ip_table: WgPeerIpTable,
|
wg_peer_ip_table: WgPeerIpTable,
|
||||||
) {
|
) {
|
||||||
let mut s = t.pin_stream();
|
let info = t.info().unwrap_or_default();
|
||||||
|
let mut mpsc_tunnel = MpscTunnel::new(t);
|
||||||
|
let mut stream = mpsc_tunnel.get_stream();
|
||||||
let mut ip_registered = false;
|
let mut ip_registered = false;
|
||||||
|
|
||||||
let info = t.info().unwrap_or_default();
|
|
||||||
let remote_addr = info.remote_addr.clone();
|
let remote_addr = info.remote_addr.clone();
|
||||||
peer_mgr
|
peer_mgr
|
||||||
.get_global_ctx()
|
.get_global_ctx()
|
||||||
.issue_event(GlobalCtxEvent::VpnPortalClientConnected(
|
.issue_event(GlobalCtxEvent::VpnPortalClientConnected(
|
||||||
info.local_addr,
|
info.local_addr.clone(),
|
||||||
info.remote_addr,
|
info.remote_addr.clone(),
|
||||||
));
|
));
|
||||||
|
|
||||||
while let Some(Ok(msg)) = s.next().await {
|
while let Some(Ok(msg)) = stream.next().await {
|
||||||
let Some(i) = Ipv4Packet::new(&msg) else {
|
assert_eq!(msg.packet_type(), ZCPacketType::WG);
|
||||||
tracing::error!(?msg, "Failed to parse ipv4 packet");
|
let inner = msg.inner();
|
||||||
|
let Some(i) = Ipv4Packet::new(&inner) else {
|
||||||
|
tracing::error!(?inner, "Failed to parse ipv4 packet");
|
||||||
continue;
|
continue;
|
||||||
};
|
};
|
||||||
if !ip_registered {
|
if !ip_registered {
|
||||||
let client_entry = Arc::new(ClientEntry {
|
let client_entry = Arc::new(ClientEntry {
|
||||||
endpoint_addr: remote_addr.parse().ok(),
|
endpoint_addr: remote_addr.parse().ok(),
|
||||||
sink: Mutex::new(t.pin_sink()),
|
sink: mpsc_tunnel.get_sink(),
|
||||||
});
|
});
|
||||||
wg_peer_ip_table.insert(i.get_source(), client_entry.clone());
|
wg_peer_ip_table.insert(i.get_source(), client_entry.clone());
|
||||||
ip_registered = true;
|
ip_registered = true;
|
||||||
}
|
}
|
||||||
tracing::trace!(?i, "Received from wg client");
|
tracing::trace!(?i, "Received from wg client");
|
||||||
|
let dst = i.get_destination();
|
||||||
let _ = peer_mgr
|
let _ = peer_mgr
|
||||||
.send_msg_ipv4(msg.clone(), i.get_destination())
|
.send_msg_ipv4(ZCPacket::new_with_payload(inner.as_ref()), dst)
|
||||||
.await;
|
.await;
|
||||||
}
|
}
|
||||||
|
|
||||||
let info = t.info().unwrap_or_default();
|
|
||||||
peer_mgr
|
peer_mgr
|
||||||
.get_global_ctx()
|
.get_global_ctx()
|
||||||
.issue_event(GlobalCtxEvent::VpnPortalClientDisconnected(
|
.issue_event(GlobalCtxEvent::VpnPortalClientDisconnected(
|
||||||
@@ -120,34 +124,38 @@ impl WireGuardImpl {
|
|||||||
|
|
||||||
#[async_trait::async_trait]
|
#[async_trait::async_trait]
|
||||||
impl PeerPacketFilter for PeerPacketFilterForVpnPortal {
|
impl PeerPacketFilter for PeerPacketFilterForVpnPortal {
|
||||||
async fn try_process_packet_from_peer(
|
async fn try_process_packet_from_peer(&self, packet: ZCPacket) -> Option<ZCPacket> {
|
||||||
&self,
|
let hdr = packet.peer_manager_header().unwrap();
|
||||||
packet: &ArchivedPacket,
|
if hdr.packet_type != PacketType::Data as u8 {
|
||||||
_: &Bytes,
|
return Some(packet);
|
||||||
) -> Option<()> {
|
|
||||||
if packet.packet_type != packet::PacketType::Data {
|
|
||||||
return None;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
let payload_bytes = packet.payload.as_bytes();
|
let payload_bytes = packet.payload();
|
||||||
|
|
||||||
let ipv4 = Ipv4Packet::new(payload_bytes)?;
|
let ipv4 = Ipv4Packet::new(payload_bytes)?;
|
||||||
if ipv4.get_version() != 4 {
|
if ipv4.get_version() != 4 {
|
||||||
return None;
|
return Some(packet);
|
||||||
}
|
}
|
||||||
|
|
||||||
let entry = self.wg_peer_ip_table.get(&ipv4.get_destination())?.clone();
|
let Some(entry) = self
|
||||||
|
.wg_peer_ip_table
|
||||||
|
.get(&ipv4.get_destination())
|
||||||
|
.map(|f| f.clone())
|
||||||
|
else {
|
||||||
|
return Some(packet);
|
||||||
|
};
|
||||||
|
|
||||||
tracing::trace!(?ipv4, "Packet filter for vpn portal");
|
tracing::trace!(?ipv4, "Packet filter for vpn portal");
|
||||||
|
|
||||||
let ret = entry
|
let payload_offset = packet.packet_type().get_packet_offsets().payload_offset;
|
||||||
.sink
|
let packet = ZCPacket::new_from_buf(
|
||||||
.lock()
|
packet.inner().split_off(payload_offset),
|
||||||
.await
|
ZCPacketType::WG,
|
||||||
.send(Bytes::copy_from_slice(payload_bytes))
|
);
|
||||||
.await;
|
|
||||||
|
|
||||||
ret.ok()
|
if let Err(ret) = entry.sink.send(packet).await {
|
||||||
|
tracing::debug!(?ret, "Failed to send packet to wg client");
|
||||||
|
}
|
||||||
|
None
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -164,9 +172,14 @@ impl WireGuardImpl {
|
|||||||
self.wg_config.clone(),
|
self.wg_config.clone(),
|
||||||
);
|
);
|
||||||
|
|
||||||
l.listen()
|
tracing::info!("Wireguard VPN Portal Starting");
|
||||||
.await
|
|
||||||
.with_context(|| "Failed to start wireguard listener for vpn portal")?;
|
{
|
||||||
|
let _g = self.global_ctx.net_ns.guard();
|
||||||
|
l.listen()
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to start wireguard listener for vpn portal")?;
|
||||||
|
}
|
||||||
|
|
||||||
join_joinset_background(self.tasks.clone(), "wireguard".to_string());
|
join_joinset_background(self.tasks.clone(), "wireguard".to_string());
|
||||||
|
|
||||||
@@ -296,62 +309,3 @@ Endpoint = {listenr_addr} # should be the public ip of the vpn server
|
|||||||
.collect()
|
.collect()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
|
||||||
mod tests {
|
|
||||||
use std::sync::Arc;
|
|
||||||
|
|
||||||
use super::*;
|
|
||||||
|
|
||||||
use crate::{
|
|
||||||
common::{
|
|
||||||
config::{NetworkIdentity, VpnPortalConfig},
|
|
||||||
global_ctx::tests::get_mock_global_ctx_with_network,
|
|
||||||
},
|
|
||||||
connector::udp_hole_punch::tests::replace_stun_info_collector,
|
|
||||||
peers::{
|
|
||||||
peer_manager::{PeerManager, RouteAlgoType},
|
|
||||||
tests::wait_for_condition,
|
|
||||||
},
|
|
||||||
rpc::NatType,
|
|
||||||
tunnels::{tcp_tunnel::TcpTunnelConnector, TunnelConnector},
|
|
||||||
};
|
|
||||||
|
|
||||||
async fn portal_test() {
|
|
||||||
let (s, _r) = tokio::sync::mpsc::channel(1000);
|
|
||||||
let peer_mgr = Arc::new(PeerManager::new(
|
|
||||||
RouteAlgoType::Ospf,
|
|
||||||
get_mock_global_ctx_with_network(Some(NetworkIdentity {
|
|
||||||
network_name: "sijie".to_string(),
|
|
||||||
network_secret: "1919119".to_string(),
|
|
||||||
})),
|
|
||||||
s,
|
|
||||||
));
|
|
||||||
replace_stun_info_collector(peer_mgr.clone(), NatType::Unknown);
|
|
||||||
peer_mgr
|
|
||||||
.get_global_ctx()
|
|
||||||
.config
|
|
||||||
.set_vpn_portal_config(VpnPortalConfig {
|
|
||||||
wireguard_listen: "0.0.0.0:11021".parse().unwrap(),
|
|
||||||
client_cidr: "10.14.14.0/24".parse().unwrap(),
|
|
||||||
});
|
|
||||||
peer_mgr.run().await.unwrap();
|
|
||||||
let mut pmgr_conn = TcpTunnelConnector::new("tcp://127.0.0.1:11010".parse().unwrap());
|
|
||||||
let tunnel = pmgr_conn.connect().await;
|
|
||||||
peer_mgr.add_client_tunnel(tunnel.unwrap()).await.unwrap();
|
|
||||||
wait_for_condition(
|
|
||||||
|| async {
|
|
||||||
let routes = peer_mgr.list_routes().await;
|
|
||||||
println!("Routes: {:?}", routes);
|
|
||||||
routes.len() != 0
|
|
||||||
},
|
|
||||||
std::time::Duration::from_secs(10),
|
|
||||||
)
|
|
||||||
.await;
|
|
||||||
|
|
||||||
let mut wg = WireGuard::default();
|
|
||||||
wg.start(peer_mgr.get_global_ctx(), peer_mgr.clone())
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
Reference in New Issue
Block a user