diff --git a/Cargo.lock b/Cargo.lock index cd0657b1..3c6d7572 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -275,7 +275,7 @@ dependencies = [ "bitflags 1.3.2", "bytes", "futures-util", - "http", + "http 0.2.11", "http-body", "hyper", "itoa 1.0.10", @@ -301,7 +301,7 @@ dependencies = [ "async-trait", "bytes", "futures-util", - "http", + "http 0.2.11", "http-body", "mime", "rustversion", @@ -1292,6 +1292,7 @@ dependencies = [ "derivative", "futures", "gethostname", + "http 1.1.0", "humansize", "indexmap 1.9.3", "log", @@ -1325,8 +1326,10 @@ dependencies = [ "time", "timedmap", "tokio", + "tokio-rustls", "tokio-stream", "tokio-util", + "tokio-websockets", "toml 0.8.12", "tonic", "tonic-build", @@ -1965,7 +1968,7 @@ dependencies = [ "futures-core", "futures-sink", "futures-util", - "http", + "http 0.2.11", "indexmap 2.2.6", "slab", "tokio", @@ -2084,6 +2087,17 @@ dependencies = [ "itoa 1.0.10", ] +[[package]] +name = "http" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "21b9ddb458710bc376481b842f5da65cdf31522de232c1ca8146abce2a358258" +dependencies = [ + "bytes", + "fnv", + "itoa 1.0.10", +] + [[package]] name = "http-body" version = "0.4.6" @@ -2091,7 +2105,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7ceab25649e9960c0311ea418d17bee82c0dcec1bd053b5f9a66e265a693bed2" dependencies = [ "bytes", - "http", + "http 0.2.11", "pin-project-lite", ] @@ -2139,7 +2153,7 @@ dependencies = [ "futures-core", "futures-util", "h2", - "http", + "http 0.2.11", "http-body", "httparse", "httpdate", @@ -2424,6 +2438,20 @@ dependencies = [ "system-deps 5.0.0", ] +[[package]] +name = "jni" +version = "0.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c6df18c2e3db7e453d3c6ac5b3e9d5182664d28788126d39b91f2d1e22b017ec" +dependencies = [ + "cesu8", + "combine", + "jni-sys", + "log", + "thiserror", + "walkdir", +] + [[package]] name = "jni" version = "0.20.0" @@ -2919,6 +2947,25 @@ dependencies = [ "winapi", ] +[[package]] +name = "num-bigint" +version = "0.4.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c165a9ab64cf766f73521c0dd2cfdff64f488b8f0b3e621face3462d3db536d7" +dependencies = [ + "num-integer", + "num-traits", +] + +[[package]] +name = "num-integer" +version = "0.1.46" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f" +dependencies = [ + "num-traits", +] + [[package]] name = "num-traits" version = "0.2.18" @@ -3716,7 +3763,7 @@ dependencies = [ "dns-lookup", "futures-core", "futures-util", - "http", + "http 0.2.11", "hyper", "hyper-system-resolver", "pin-project-lite", @@ -3739,9 +3786,9 @@ dependencies = [ [[package]] name = "quinn" -version = "0.10.2" +version = "0.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8cc2c5017e4b43d5995dcea317bc46c1e09404c0a9664d2908f7f02dfe943d75" +checksum = "4bb80dc034523335a9fcc34271931dd97e9132d1fb078695db500339eb72e712" dependencies = [ "bytes", "pin-project-lite", @@ -3756,16 +3803,16 @@ dependencies = [ [[package]] name = "quinn-proto" -version = "0.10.6" +version = "0.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "141bf7dfde2fbc246bfd3fe12f2455aa24b0fbd9af535d8c86c7bd1381ff2b1a" +checksum = "a063a47a1aaee4b3b1c2dd44edb7867c10107a2ef171f3543ac40ec5e9092002" dependencies = [ "bytes", "rand 0.8.5", - "ring 0.16.20", + "ring 0.17.8", "rustc-hash", "rustls", - "rustls-native-certs", + "rustls-platform-verifier", "slab", "thiserror", "tinyvec", @@ -3774,15 +3821,15 @@ dependencies = [ [[package]] name = "quinn-udp" -version = "0.4.1" +version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "055b4e778e8feb9f93c4e439f71dc2156ef13360b432b799e179a8c4cdf0b1d7" +checksum = "cb7ad7bc932e4968523fa7d9c320ee135ff779de720e9350fee8728838551764" dependencies = [ - "bytes", "libc", + "once_cell", "socket2 0.5.5", "tracing", - "windows-sys 0.48.0", + "windows-sys 0.52.0", ] [[package]] @@ -3985,7 +4032,7 @@ dependencies = [ "futures-core", "futures-util", "h2", - "http", + "http 0.2.11", "http-body", "hyper", "hyper-tls", @@ -3997,7 +4044,7 @@ dependencies = [ "once_cell", "percent-encoding", "pin-project-lite", - "rustls-pemfile", + "rustls-pemfile 1.0.4", "serde", "serde_json", "serde_urlencoded", @@ -4108,24 +4155,27 @@ dependencies = [ [[package]] name = "rustls" -version = "0.21.11" +version = "0.23.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7fecbfb7b1444f477b345853b1fce097a2c6fb637b2bfb87e6bc5db0f043fae4" +checksum = "afabcee0551bd1aa3e18e5adbf2c0544722014b899adb31bd186ec638d3da97e" dependencies = [ - "log", + "once_cell", "ring 0.17.8", + "rustls-pki-types", "rustls-webpki", - "sct", + "subtle", + "zeroize", ] [[package]] name = "rustls-native-certs" -version = "0.6.3" +version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a9aace74cb666635c918e9c12bc0d348266037aa8eb599b5cba565709a8dff00" +checksum = "8f1fb85efa936c42c6d5fc28d2629bb51e4b2f4b8a5211e297d599cc5a093792" dependencies = [ "openssl-probe", - "rustls-pemfile", + "rustls-pemfile 2.1.2", + "rustls-pki-types", "schannel", "security-framework", ] @@ -4140,12 +4190,56 @@ dependencies = [ ] [[package]] -name = "rustls-webpki" -version = "0.101.7" +name = "rustls-pemfile" +version = "2.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8b6275d1ee7a1cd780b64aca7726599a1dbc893b1e64144529e55c3c2f745765" +checksum = "29993a25686778eb88d4189742cd713c9bce943bc54251a33509dc63cbacf73d" +dependencies = [ + "base64 0.22.0", + "rustls-pki-types", +] + +[[package]] +name = "rustls-pki-types" +version = "1.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ecd36cc4259e3e4514335c4a138c6b43171a8d61d8f5c9348f9fc7529416f247" + +[[package]] +name = "rustls-platform-verifier" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b5f0d26fa1ce3c790f9590868f0109289a044acb954525f933e2aa3b871c157d" +dependencies = [ + "core-foundation", + "core-foundation-sys", + "jni 0.19.0", + "log", + "once_cell", + "rustls", + "rustls-native-certs", + "rustls-platform-verifier-android", + "rustls-webpki", + "security-framework", + "security-framework-sys", + "webpki-roots", + "winapi", +] + +[[package]] +name = "rustls-platform-verifier-android" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "84e217e7fdc8466b5b35d30f8c0a30febd29173df4a3a0c2115d306b9c4117ad" + +[[package]] +name = "rustls-webpki" +version = "0.102.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f3bce581c0dd41bce533ce695a1437fa16a7ab5ac3ccfa99fe1a620a7885eabf" dependencies = [ "ring 0.17.8", + "rustls-pki-types", "untrusted 0.9.0", ] @@ -4191,34 +4285,25 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" -[[package]] -name = "sct" -version = "0.7.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "da046153aa2352493d6cb7da4b6e5c0c057d8a1d0a9aa8560baffdd945acd414" -dependencies = [ - "ring 0.17.8", - "untrusted 0.9.0", -] - [[package]] name = "security-framework" -version = "2.9.2" +version = "2.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "05b64fb303737d99b81884b2c63433e9ae28abebe5eb5045dcdd175dc2ecf4de" +checksum = "c627723fd09706bacdb5cf41499e95098555af3c3c29d014dc3c458ef6be11c0" dependencies = [ - "bitflags 1.3.2", + "bitflags 2.5.0", "core-foundation", "core-foundation-sys", "libc", + "num-bigint", "security-framework-sys", ] [[package]] name = "security-framework-sys" -version = "2.9.1" +version = "2.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e932934257d3b408ed8f30db49d85ea163bfe74961f017f405b025af298f0c7a" +checksum = "317936bbbd05227752583946b9e66d7ce3b489f84e11a94a510b4437fef407d7" dependencies = [ "core-foundation-sys", "libc", @@ -4752,7 +4837,7 @@ dependencies = [ "gtk", "image", "instant", - "jni", + "jni 0.20.0", "lazy_static", "libappindicator", "libc", @@ -4855,7 +4940,7 @@ dependencies = [ "glob", "gtk", "heck 0.4.1", - "http", + "http 0.2.11", "ignore", "objc", "once_cell", @@ -4951,7 +5036,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cf2d0652aa2891ff3e9caa2401405257ea29ab8372cce01f186a5825f1bd0e76" dependencies = [ "gtk", - "http", + "http 0.2.11", "http-range", "rand 0.8.5", "raw-window-handle", @@ -5197,6 +5282,17 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-rustls" +version = "0.26.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c7bc40d0e5a97695bb96e27995cd3a08538541b0a846f65bba7a359f36700d4" +dependencies = [ + "rustls", + "rustls-pki-types", + "tokio", +] + [[package]] name = "tokio-stream" version = "0.1.14" @@ -5223,6 +5319,27 @@ dependencies = [ "tracing", ] +[[package]] +name = "tokio-websockets" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b37fdd357781d7336924ff59e916d67384f867312ef83c58f0242d22fa31651b" +dependencies = [ + "base64 0.22.0", + "bytes", + "fastrand", + "futures-core", + "futures-sink", + "http 1.1.0", + "httparse", + "ring 0.17.8", + "rustls-pki-types", + "tokio", + "tokio-rustls", + "tokio-util", + "webpki-roots", +] + [[package]] name = "toml" version = "0.5.11" @@ -5303,7 +5420,7 @@ dependencies = [ "base64 0.21.7", "bytes", "h2", - "http", + "http 0.2.11", "http-body", "hyper", "hyper-timeout", @@ -5897,6 +6014,15 @@ dependencies = [ "system-deps 6.2.2", ] +[[package]] +name = "webpki-roots" +version = "0.26.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b3de34ae270483955a94f4b21bdaaeb83d508bb84a01435f393818edb0012009" +dependencies = [ + "rustls-pki-types", +] + [[package]] name = "webview2-com" version = "0.19.1" @@ -6357,7 +6483,7 @@ dependencies = [ "glib", "gtk", "html5ever", - "http", + "http 0.2.11", "kuchikiki", "libc", "log", diff --git a/easytier/Cargo.toml b/easytier/Cargo.toml index 938ffc3c..15875a71 100644 --- a/easytier/Cargo.toml +++ b/easytier/Cargo.toml @@ -67,12 +67,25 @@ pin-project-lite = "0.2.13" atomicbox = "0.4.0" tachyonix = "0.2.1" -quinn = { version = "0.10.2", optional = true } -rustls = { version = "0.21.0", features = [ - "dangerous_configuration", -], optional = true } +quinn = { version = "0.11.0", optional = true, features = ["ring"] } +rustls = { version = "0.23.0", features = [ + "ring", +], default-features = false, optional = true } rcgen = { version = "0.11.1", optional = true } +# for websocket +tokio-websockets = { version = "0.8.2", optional = true, features = [ + "rustls-webpki-roots", + "client", + "server", + "fastrand", + "ring", +] } +http = { version = "1", default-features = false, features = [ + "std", +], optional = true } +tokio-rustls = { version = "0.26", default-features = false, optional = true } + # for tap device tun = { version = "0.6.1", features = ["async"] } # for net ns @@ -169,9 +182,10 @@ defguard_wireguard_rs = "0.4.2" [features] -default = ["wireguard", "quic", "mimalloc"] +default = ["wireguard", "quic", "mimalloc", "websocket"] mips = ["aes-gcm", "mimalloc", "wireguard"] wireguard = ["dep:boringtun", "dep:ring"] quic = ["dep:quinn", "dep:rustls", "dep:rcgen"] mimalloc = ["dep:mimalloc-rust"] aes-gcm = ["dep:aes-gcm"] +websocket = ["dep:tokio-websockets", "dep:http", "dep:tokio-rustls"] diff --git a/easytier/src/connector/mod.rs b/easytier/src/connector/mod.rs index db43c1cc..24396bf5 100644 --- a/easytier/src/connector/mod.rs +++ b/easytier/src/connector/mod.rs @@ -105,6 +105,11 @@ pub async fn create_connector_by_url( .await; return Ok(Box::new(connector)); } + #[cfg(feature = "websocket")] + "ws" | "wss" => { + let connector = crate::tunnel::websocket::WSTunnelConnector::new(url); + return Ok(Box::new(connector)); + } _ => { return Err(Error::InvalidUrl(url.into())); } diff --git a/easytier/src/instance/listeners.rs b/easytier/src/instance/listeners.rs index e74697c7..00dc688d 100644 --- a/easytier/src/instance/listeners.rs +++ b/easytier/src/instance/listeners.rs @@ -39,6 +39,11 @@ pub fn get_listener_by_url( } #[cfg(feature = "quic")] "quic" => Box::new(QUICTunnelListener::new(l.clone())), + #[cfg(feature = "websocket")] + "ws" | "wss" => { + use crate::tunnel::websocket::WSTunnelListener; + Box::new(WSTunnelListener::new(l.clone())) + } _ => { unreachable!("unsupported listener uri"); } @@ -154,6 +159,7 @@ impl ListenerManage } }); } + tracing::warn!("listener exit"); } pub async fn run(&mut self) -> Result<(), Error> { diff --git a/easytier/src/tests/three_node.rs b/easytier/src/tests/three_node.rs index b9f37c34..1ecdee33 100644 --- a/easytier/src/tests/three_node.rs +++ b/easytier/src/tests/three_node.rs @@ -53,6 +53,8 @@ pub fn get_inst_config(inst_name: &str, ns: Option<&str>, ipv4: &str) -> TomlCon "tcp://0.0.0.0:11010".parse().unwrap(), "udp://0.0.0.0:11010".parse().unwrap(), "wg://0.0.0.0:11011".parse().unwrap(), + "ws://0.0.0.0:11011".parse().unwrap(), + "wss://0.0.0.0:11012".parse().unwrap(), ]); config } @@ -96,6 +98,20 @@ pub async fn init_three_node(proto: &str) -> Vec { .unwrap_or_default(), ), )); + } else if proto == "ws" { + #[cfg(feature = "websocket")] + inst2 + .get_conn_manager() + .add_connector(crate::tunnel::websocket::WSTunnelConnector::new( + "ws://10.1.1.1:11011".parse().unwrap(), + )); + } else if proto == "wss" { + #[cfg(feature = "websocket")] + inst2 + .get_conn_manager() + .add_connector(crate::tunnel::websocket::WSTunnelConnector::new( + "wss://10.1.1.1:11012".parse().unwrap(), + )); } inst2 @@ -105,16 +121,17 @@ pub async fn init_three_node(proto: &str) -> Vec { )); // wait inst2 have two route. - let now = std::time::Instant::now(); - loop { - if inst2.get_peer_manager().list_routes().await.len() == 2 { - break; - } - if now.elapsed().as_secs() > 5 { - panic!("wait inst2 have two route timeout"); - } - tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; - } + wait_for_condition( + || async { inst2.get_peer_manager().list_routes().await.len() == 2 }, + Duration::from_secs(5000), + ) + .await; + + wait_for_condition( + || async { inst1.get_peer_manager().list_routes().await.len() == 2 }, + Duration::from_secs(5000), + ) + .await; vec![inst1, inst2, inst3] } @@ -142,7 +159,7 @@ async fn ping_test(from_netns: &str, target_ip: &str) -> bool { #[rstest::rstest] #[tokio::test] #[serial_test::serial] -pub async fn basic_three_node_test(#[values("tcp", "udp", "wg")] proto: &str) { +pub async fn basic_three_node_test(#[values("tcp", "udp", "wg", "ws", "wss")] proto: &str) { let insts = init_three_node(proto).await; check_route( diff --git a/easytier/src/tunnel/common.rs b/easytier/src/tunnel/common.rs index 04d170a8..04632d46 100644 --- a/easytier/src/tunnel/common.rs +++ b/easytier/src/tunnel/common.rs @@ -431,7 +431,14 @@ pub mod tests { let (mut recv, mut send) = tunnel.split(); if !once { - recv.forward(send).await.unwrap(); + while let Some(item) = recv.next().await { + let Ok(msg) = item else { + continue; + }; + if let Err(_) = send.send(msg).await { + break; + } + } } else { let Some(ret) = recv.next().await else { assert!(false, "recv error"); @@ -447,6 +454,8 @@ pub mod tests { tracing::debug!(?res, "recv a msg, try echo back"); send.send(res).await.unwrap(); } + let _ = send.flush().await; + let _ = send.close().await; tracing::warn!("echo server exit..."); } @@ -506,7 +515,7 @@ pub mod tests { println!("echo back: {:?}", ret); assert_eq!(ret.payload(), Bytes::from("12345678abcdefg")); - drop(send); + send.close().await.unwrap(); if ["udp", "wg"].contains(&connector.remote_url().scheme()) { lis.abort(); @@ -562,6 +571,7 @@ pub mod tests { let _ = send.feed(item).await.unwrap(); } + send.close().await.unwrap(); drop(send); drop(connector); drop(tunnel); @@ -576,7 +586,7 @@ pub mod tests { pub fn enable_log() { let filter = tracing_subscriber::EnvFilter::builder() - .with_default_directive(tracing::level_filters::LevelFilter::TRACE.into()) + .with_default_directive(tracing::level_filters::LevelFilter::DEBUG.into()) .from_env() .unwrap() .add_directive("tarpc=error".parse().unwrap()); diff --git a/easytier/src/tunnel/insecure_tls.rs b/easytier/src/tunnel/insecure_tls.rs new file mode 100644 index 00000000..f829ee28 --- /dev/null +++ b/easytier/src/tunnel/insecure_tls.rs @@ -0,0 +1,86 @@ +use std::sync::Arc; + +use rustls::pki_types::{CertificateDer, PrivateKeyDer, ServerName, UnixTime}; + +/// Dummy certificate verifier that treats any certificate as valid. +/// NOTE, such verification is vulnerable to MITM attacks, but convenient for testing. +#[derive(Debug)] +struct SkipServerVerification(Arc); + +impl SkipServerVerification { + fn new(provider: Arc) -> Arc { + Arc::new(Self(provider)) + } +} + +impl rustls::client::danger::ServerCertVerifier for SkipServerVerification { + fn verify_server_cert( + &self, + _end_entity: &CertificateDer<'_>, + _intermediates: &[CertificateDer<'_>], + _server_name: &ServerName<'_>, + _ocsp: &[u8], + _now: UnixTime, + ) -> Result { + Ok(rustls::client::danger::ServerCertVerified::assertion()) + } + + fn verify_tls12_signature( + &self, + message: &[u8], + cert: &CertificateDer<'_>, + dss: &rustls::DigitallySignedStruct, + ) -> Result { + rustls::crypto::verify_tls12_signature( + message, + cert, + dss, + &self.0.signature_verification_algorithms, + ) + } + + fn verify_tls13_signature( + &self, + message: &[u8], + cert: &CertificateDer<'_>, + dss: &rustls::DigitallySignedStruct, + ) -> Result { + rustls::crypto::verify_tls13_signature( + message, + cert, + dss, + &self.0.signature_verification_algorithms, + ) + } + + fn supported_verify_schemes(&self) -> Vec { + self.0.signature_verification_algorithms.supported_schemes() + } +} + +pub fn init_crypto_provider() { + let _ = + rustls::crypto::CryptoProvider::install_default(rustls::crypto::ring::default_provider()); +} + +pub fn get_insecure_tls_client_config() -> rustls::ClientConfig { + init_crypto_provider(); + let provider = rustls::crypto::CryptoProvider::get_default().unwrap(); + let mut config = rustls::ClientConfig::builder() + .dangerous() + .with_custom_certificate_verifier(SkipServerVerification::new(provider.clone())) + .with_no_client_auth(); + config.enable_sni = false; + config.enable_early_data = false; + config +} + +pub fn get_insecure_tls_cert<'a>() -> (Vec>, PrivateKeyDer<'a>) { + let cert = rcgen::generate_simple_self_signed(vec!["localhost".into()]).unwrap(); + let cert_der = cert.serialize_der().unwrap(); + let priv_key = cert.serialize_private_key_der(); + let priv_key = rustls::pki_types::PrivatePkcs8KeyDer::from(priv_key); + let cert_chain = vec![cert_der.clone().into()]; + + (cert_chain, priv_key.into()) +} diff --git a/easytier/src/tunnel/mod.rs b/easytier/src/tunnel/mod.rs index 07b1d3eb..cad6568f 100644 --- a/easytier/src/tunnel/mod.rs +++ b/easytier/src/tunnel/mod.rs @@ -28,6 +28,12 @@ pub mod wireguard; #[cfg(feature = "quic")] pub mod quic; +#[cfg(feature = "websocket")] +pub mod websocket; + +#[cfg(any(feature = "quic", feature = "websocket"))] +pub mod insecure_tls; + #[derive(thiserror::Error, Debug)] pub enum TunnelError { #[error("io error")] @@ -62,6 +68,10 @@ pub enum TunnelError { #[error("no dns record found")] NoDnsRecordFound(IpVersion), + #[cfg(feature = "websocket")] + #[error("websocket error: {0}")] + WebSocketError(#[from] tokio_websockets::Error), + #[error("tunnel error: {0}")] TunError(String), } diff --git a/easytier/src/tunnel/mpsc.rs b/easytier/src/tunnel/mpsc.rs index e587b729..b49e2ac4 100644 --- a/easytier/src/tunnel/mpsc.rs +++ b/easytier/src/tunnel/mpsc.rs @@ -1,9 +1,9 @@ // this mod wrap tunnel to a mpsc tunnel, based on crossbeam_channel -use std::pin::Pin; +use std::{pin::Pin, time::Duration}; use anyhow::Context; -use tokio::task::JoinHandle; +use tokio::{task::JoinHandle, time::timeout}; use super::{packet_def::ZCPacket, Tunnel, TunnelError, ZCPacketSink, ZCPacketStream}; @@ -42,6 +42,8 @@ impl MpscTunnel { break; } } + let close_ret = timeout(Duration::from_secs(5), sink.close()).await; + tracing::warn!(?close_ret, "mpsc close sink"); }); Self { diff --git a/easytier/src/tunnel/packet_def.rs b/easytier/src/tunnel/packet_def.rs index dec77872..036b40d8 100644 --- a/easytier/src/tunnel/packet_def.rs +++ b/easytier/src/tunnel/packet_def.rs @@ -114,6 +114,7 @@ pub struct ZCPacketOffsets { pub tcp_tunnel_header_offset: usize, pub udp_tunnel_header_offset: usize, pub wg_tunnel_header_offset: usize, + pub dummy_tunnel_header_offset: usize, } #[derive(Debug, Clone, Copy, PartialEq)] @@ -126,6 +127,8 @@ pub enum ZCPacketType { WG, // received from local tun device, should reserve header space for tcp or udp tunnel NIC, + // tunnel without header + DummyTunnel, } const PAYLOAD_OFFSET_FOR_NIC_PACKET: usize = max( @@ -158,6 +161,7 @@ impl ZCPacketType { TCP_TUNNEL_HEADER_SIZE, WG_TUNNEL_HEADER_SIZE, ), + dummy_tunnel_header_offset: get_converted_offset(TCP_TUNNEL_HEADER_SIZE, 0), }, ZCPacketType::UDP => ZCPacketOffsets { payload_offset: UDP_TUNNEL_HEADER_SIZE + PEER_MANAGER_HEADER_SIZE, @@ -171,6 +175,7 @@ impl ZCPacketType { UDP_TUNNEL_HEADER_SIZE, WG_TUNNEL_HEADER_SIZE, ), + dummy_tunnel_header_offset: get_converted_offset(UDP_TUNNEL_HEADER_SIZE, 0), }, ZCPacketType::WG => ZCPacketOffsets { payload_offset: WG_TUNNEL_HEADER_SIZE + PEER_MANAGER_HEADER_SIZE, @@ -184,6 +189,7 @@ impl ZCPacketType { UDP_TUNNEL_HEADER_SIZE, ), wg_tunnel_header_offset: 0, + dummy_tunnel_header_offset: get_converted_offset(WG_TUNNEL_HEADER_SIZE, 0), }, ZCPacketType::NIC => ZCPacketOffsets { payload_offset: PAYLOAD_OFFSET_FOR_NIC_PACKET, @@ -198,6 +204,16 @@ impl ZCPacketType { wg_tunnel_header_offset: PAYLOAD_OFFSET_FOR_NIC_PACKET - PEER_MANAGER_HEADER_SIZE - WG_TUNNEL_HEADER_SIZE, + dummy_tunnel_header_offset: PAYLOAD_OFFSET_FOR_NIC_PACKET + - PEER_MANAGER_HEADER_SIZE, + }, + ZCPacketType::DummyTunnel => ZCPacketOffsets { + payload_offset: PEER_MANAGER_HEADER_SIZE, + peer_manager_header_offset: 0, + tcp_tunnel_header_offset: get_converted_offset(0, TCP_TUNNEL_HEADER_SIZE), + udp_tunnel_header_offset: get_converted_offset(0, UDP_TUNNEL_HEADER_SIZE), + wg_tunnel_header_offset: get_converted_offset(0, WG_TUNNEL_HEADER_SIZE), + dummy_tunnel_header_offset: 0, }, } } @@ -349,13 +365,21 @@ impl ZCPacket { hdr.len.set(payload_len as u32); } - fn tunnel_payload(&self) -> &[u8] { + pub fn tunnel_payload(&self) -> &[u8] { &self.inner[self .packet_type .get_packet_offsets() .peer_manager_header_offset..] } + pub fn tunnel_payload_bytes(mut self) -> BytesMut { + self.inner.split_off( + self.packet_type + .get_packet_offsets() + .peer_manager_header_offset, + ) + } + pub fn convert_type(mut self, target_packet_type: ZCPacketType) -> Self { if target_packet_type == self.packet_type { return self; @@ -377,6 +401,11 @@ impl ZCPacket { .get_packet_offsets() .wg_tunnel_header_offset } + ZCPacketType::DummyTunnel => { + self.packet_type + .get_packet_offsets() + .dummy_tunnel_header_offset + } ZCPacketType::NIC => unreachable!(), }; diff --git a/easytier/src/tunnel/quic.rs b/easytier/src/tunnel/quic.rs index 9ae7cd5d..497fea75 100644 --- a/easytier/src/tunnel/quic.rs +++ b/easytier/src/tunnel/quic.rs @@ -12,44 +12,18 @@ use crate::{ }, }; use anyhow::Context; -use quinn::{ClientConfig, Connection, Endpoint, ServerConfig}; +use quinn::{crypto::rustls::QuicClientConfig, ClientConfig, Connection, Endpoint, ServerConfig}; use super::{ - check_scheme_and_get_socket_addr, IpVersion, Tunnel, TunnelConnector, TunnelError, - TunnelListener, + check_scheme_and_get_socket_addr, + insecure_tls::{get_insecure_tls_cert, get_insecure_tls_client_config}, + IpVersion, Tunnel, TunnelConnector, TunnelError, TunnelListener, }; -/// Dummy certificate verifier that treats any certificate as valid. -/// NOTE, such verification is vulnerable to MITM attacks, but convenient for testing. -struct SkipServerVerification; - -impl SkipServerVerification { - fn new() -> Arc { - Arc::new(Self) - } -} - -impl rustls::client::ServerCertVerifier for SkipServerVerification { - fn verify_server_cert( - &self, - _end_entity: &rustls::Certificate, - _intermediates: &[rustls::Certificate], - _server_name: &rustls::ServerName, - _scts: &mut dyn Iterator, - _ocsp_response: &[u8], - _now: std::time::SystemTime, - ) -> Result { - Ok(rustls::client::ServerCertVerified::assertion()) - } -} - fn configure_client() -> ClientConfig { - let crypto = rustls::ClientConfig::builder() - .with_safe_defaults() - .with_custom_certificate_verifier(SkipServerVerification::new()) - .with_no_client_auth(); - - ClientConfig::new(Arc::new(crypto)) + ClientConfig::new(Arc::new( + QuicClientConfig::try_from(get_insecure_tls_client_config()).unwrap(), + )) } /// Constructs a QUIC endpoint configured to listen for incoming connections on a certain address @@ -68,18 +42,14 @@ pub fn make_server_endpoint(bind_addr: SocketAddr) -> Result<(Endpoint, Vec) /// Returns default server configuration along with its certificate. fn configure_server() -> Result<(ServerConfig, Vec), Box> { - let cert = rcgen::generate_simple_self_signed(vec!["localhost".into()]).unwrap(); - let cert_der = cert.serialize_der().unwrap(); - let priv_key = cert.serialize_private_key_der(); - let priv_key = rustls::PrivateKey(priv_key); - let cert_chain = vec![rustls::Certificate(cert_der.clone())]; + let (certs, key) = get_insecure_tls_cert(); - let mut server_config = ServerConfig::with_single_cert(cert_chain, priv_key)?; + let mut server_config = ServerConfig::with_single_cert(certs.clone(), key.into())?; let transport_config = Arc::get_mut(&mut server_config.transport).unwrap(); transport_config.max_concurrent_uni_streams(10_u8.into()); transport_config.max_concurrent_bidi_streams(10_u8.into()); - Ok((server_config, cert_der)) + Ok((server_config, certs[0].to_vec())) } #[allow(unused)] diff --git a/easytier/src/tunnel/websocket.rs b/easytier/src/tunnel/websocket.rs new file mode 100644 index 00000000..45005406 --- /dev/null +++ b/easytier/src/tunnel/websocket.rs @@ -0,0 +1,262 @@ +use std::{net::SocketAddr, sync::Arc}; + +use anyhow::Context; +use bytes::BytesMut; +use futures::{SinkExt, StreamExt}; +use tokio::net::{TcpListener, TcpSocket, TcpStream}; +use tokio_rustls::TlsAcceptor; +use tokio_websockets::{ClientBuilder, Limits, Message}; +use zerocopy::AsBytes; + +use crate::{rpc::TunnelInfo, tunnel::insecure_tls::get_insecure_tls_client_config}; + +use super::{ + common::{setup_sokcet2, TunnelWrapper}, + insecure_tls::{get_insecure_tls_cert, init_crypto_provider}, + packet_def::{ZCPacket, ZCPacketType}, + FromUrl, IpVersion, Tunnel, TunnelConnector, TunnelError, TunnelListener, +}; + +fn is_wss(addr: &url::Url) -> Result { + match addr.scheme() { + "ws" => Ok(false), + "wss" => Ok(true), + _ => Err(TunnelError::InvalidProtocol(addr.scheme().to_string())), + } +} + +async fn sink_from_zc_packet(msg: ZCPacket) -> Result { + Ok(Message::binary(msg.tunnel_payload_bytes().freeze())) +} + +async fn map_from_ws_message( + msg: Result, +) -> Option> { + if msg.is_err() { + tracing::error!(?msg, "recv from websocket error"); + return Some(Err(TunnelError::WebSocketError(msg.unwrap_err()))); + } + + let msg = msg.unwrap(); + if msg.is_close() { + tracing::warn!("recv close message from websocket"); + return None; + } + + if !msg.is_binary() { + let msg = format!("{:?}", msg); + tracing::error!(?msg, "Invalid packet"); + return Some(Err(TunnelError::InvalidPacket(msg))); + } + + Some(Ok(ZCPacket::new_from_buf( + BytesMut::from(msg.into_payload().as_bytes()), + ZCPacketType::DummyTunnel, + ))) +} + +#[derive(Debug)] +pub struct WSTunnelListener { + addr: url::Url, + listener: Option, +} + +impl WSTunnelListener { + pub fn new(addr: url::Url) -> Self { + WSTunnelListener { + addr, + listener: None, + } + } + + async fn try_accept(&mut self, stream: TcpStream) -> Result, TunnelError> { + let info = TunnelInfo { + tunnel_type: self.addr.scheme().to_owned(), + local_addr: self.local_url().into(), + remote_addr: super::build_url_from_socket_addr( + &stream.peer_addr()?.to_string(), + self.addr.scheme().to_string().as_str(), + ) + .into(), + }; + + let server_bulder = tokio_websockets::ServerBuilder::new().limits(Limits::unlimited()); + + let ret: Box = if is_wss(&self.addr)? { + init_crypto_provider(); + let (certs, key) = get_insecure_tls_cert(); + let config = rustls::ServerConfig::builder() + .with_no_client_auth() + .with_single_cert(certs, key) + .with_context(|| "Failed to create server config")?; + let acceptor = TlsAcceptor::from(Arc::new(config)); + + let stream = acceptor.accept(stream).await?; + let (write, read) = server_bulder.accept(stream).await?.split(); + + Box::new(TunnelWrapper::new( + read.filter_map(move |msg| map_from_ws_message(msg)), + write.with(move |msg| sink_from_zc_packet(msg)), + Some(info), + )) + } else { + let (write, read) = server_bulder.accept(stream).await?.split(); + Box::new(TunnelWrapper::new( + read.filter_map(move |msg| map_from_ws_message(msg)), + write.with(move |msg| sink_from_zc_packet(msg)), + Some(info), + )) + }; + + Ok(ret) + } +} + +#[async_trait::async_trait] +impl TunnelListener for WSTunnelListener { + async fn listen(&mut self) -> Result<(), TunnelError> { + let addr = SocketAddr::from_url(self.addr.clone(), IpVersion::Both)?; + let socket2_socket = socket2::Socket::new( + socket2::Domain::for_address(addr), + socket2::Type::STREAM, + Some(socket2::Protocol::TCP), + )?; + setup_sokcet2(&socket2_socket, &addr)?; + let socket = TcpSocket::from_std_stream(socket2_socket.into()); + + self.addr + .set_port(Some(socket.local_addr()?.port())) + .unwrap(); + + self.listener = Some(socket.listen(1024)?); + Ok(()) + } + + async fn accept(&mut self) -> Result, super::TunnelError> { + loop { + let listener = self.listener.as_ref().unwrap(); + // only fail on tcp accept error + let (stream, _) = listener.accept().await?; + stream.set_nodelay(true).unwrap(); + match self.try_accept(stream).await { + Ok(tunnel) => return Ok(tunnel), + Err(e) => { + tracing::error!(?e, ?self, "Failed to accept ws/wss tunnel"); + continue; + } + } + } + } + + fn local_url(&self) -> url::Url { + self.addr.clone() + } +} + +pub struct WSTunnelConnector { + addr: url::Url, + ip_version: IpVersion, +} + +impl WSTunnelConnector { + pub fn new(addr: url::Url) -> Self { + WSTunnelConnector { + addr, + ip_version: IpVersion::Both, + } + } +} + +#[async_trait::async_trait] +impl TunnelConnector for WSTunnelConnector { + async fn connect(&mut self) -> Result, super::TunnelError> { + let is_wss = is_wss(&self.addr)?; + let addr = SocketAddr::from_url(self.addr.clone(), self.ip_version)?; + let local_addr = if addr.is_ipv4() { + "0.0.0.0:0" + } else { + "[::]:0" + }; + + let info = TunnelInfo { + tunnel_type: self.addr.scheme().to_owned(), + local_addr: super::build_url_from_socket_addr( + &local_addr.to_string(), + self.addr.scheme().to_string().as_str(), + ) + .into(), + remote_addr: self.addr.to_string(), + }; + + let connector = + tokio_websockets::Connector::Rustls(Arc::new(get_insecure_tls_client_config()).into()); + let mut client_builder = + ClientBuilder::from_uri(http::Uri::try_from(self.addr.to_string()).unwrap()); + if is_wss { + init_crypto_provider(); + client_builder = client_builder.connector(&connector); + } + + let (client, _) = client_builder.connect().await?; + + let (write, read) = client.split(); + let read = read.filter_map(move |msg| map_from_ws_message(msg)); + let write = write.with(move |msg| sink_from_zc_packet(msg)); + + Ok(Box::new(TunnelWrapper::new(read, write, Some(info)))) + } + + fn remote_url(&self) -> url::Url { + self.addr.clone() + } + + fn set_ip_version(&mut self, ip_version: IpVersion) { + self.ip_version = ip_version; + } +} + +#[cfg(test)] +pub mod tests { + use crate::tunnel::common::tests::_tunnel_pingpong; + use crate::tunnel::websocket::{WSTunnelConnector, WSTunnelListener}; + use crate::tunnel::{TunnelConnector, TunnelListener}; + + #[rstest::rstest] + #[tokio::test] + #[serial_test::serial] + async fn ws_pingpong(#[values("ws", "wss")] proto: &str) { + let listener = WSTunnelListener::new(format!("{}://0.0.0.0:25556", proto).parse().unwrap()); + let connector = + WSTunnelConnector::new(format!("{}://127.0.0.1:25556", proto).parse().unwrap()); + _tunnel_pingpong(listener, connector).await + } + + // TODO: tokio-websockets cannot correctly handle close, benchmark case is disabled + // #[rstest::rstest] + // #[tokio::test] + // #[serial_test::serial] + // async fn ws_bench(#[values("ws", "wss")] proto: &str) { + // enable_log(); + // let listener = WSTunnelListener::new(format!("{}://0.0.0.0:25557", proto).parse().unwrap()); + // let connector = + // WSTunnelConnector::new(format!("{}://127.0.0.1:25557", proto).parse().unwrap()); + // _tunnel_bench(listener, connector).await + // } + + #[tokio::test] + async fn ws_accept_wss() { + let mut listener = WSTunnelListener::new("wss://0.0.0.0:25558".parse().unwrap()); + listener.listen().await.unwrap(); + let j = tokio::spawn(async move { + let _ = listener.accept().await; + }); + + let mut connector = WSTunnelConnector::new("ws://127.0.0.1:25558".parse().unwrap()); + connector.connect().await.unwrap_err(); + + let mut connector = WSTunnelConnector::new("wss://127.0.0.1:25558".parse().unwrap()); + connector.connect().await.unwrap(); + + j.abort(); + } +}