diff --git a/easytier-core/Cargo.toml b/easytier-core/Cargo.toml index 267c4943..f8248c6f 100644 --- a/easytier-core/Cargo.toml +++ b/easytier-core/Cargo.toml @@ -90,11 +90,8 @@ cidr = "0.2.2" socket2 = "0.5.5" # for hole punching -stun-format = { git = "https://github.com/KKRainbow/stun-format.git", features = [ - "fmt", - "rfc3489", - "iana", -] } +stun_codec = "0.3.4" +bytecodec = "0.4.15" rand = "0.8.5" serde = { version = "1.0", features = ["derive"] } diff --git a/easytier-core/src/common/mod.rs b/easytier-core/src/common/mod.rs index b6b25596..6647f586 100644 --- a/easytier-core/src/common/mod.rs +++ b/easytier-core/src/common/mod.rs @@ -7,6 +7,7 @@ pub mod netns; pub mod network; pub mod rkyv_util; pub mod stun; +pub mod stun_codec_ext; pub fn get_logger_timer( format: F, diff --git a/easytier-core/src/common/stun.rs b/easytier-core/src/common/stun.rs index a8cb5db6..c5c0a0aa 100644 --- a/easytier-core/src/common/stun.rs +++ b/easytier-core/src/common/stun.rs @@ -1,17 +1,24 @@ -use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6}; +use std::net::SocketAddr; use std::sync::Arc; use std::time::Duration; use crate::rpc::{NatType, StunInfo}; +use anyhow::Context; use crossbeam::atomic::AtomicCell; -use stun_format::Attr; use tokio::net::{lookup_host, UdpSocket}; use tokio::sync::RwLock; use tokio::task::JoinSet; use tracing::Level; +use bytecodec::{DecodeExt, EncodeExt}; +use stun_codec::rfc5389::methods::BINDING; +use stun_codec::rfc5780::attributes::ChangeRequest; +use stun_codec::{Message, MessageClass, MessageDecoder, MessageEncoder}; + use crate::common::error::Error; +use super::stun_codec_ext::*; + struct HostResolverIter { hostnames: Vec, ips: Vec, @@ -51,6 +58,8 @@ impl HostResolverIter { #[derive(Debug, Clone, Copy)] struct BindRequestResponse { source_addr: SocketAddr, + send_to_addr: SocketAddr, + recv_from_addr: SocketAddr, mapped_socket_addr: Option, changed_socket_addr: Option, @@ -78,7 +87,7 @@ impl Stun { pub fn new(stun_server: SocketAddr) -> Self { Self { stun_server, - req_repeat: 3, + req_repeat: 5, resp_timeout: Duration::from_millis(3000), } } @@ -92,7 +101,7 @@ impl Stun { expected_ip_changed: bool, expected_port_changed: bool, stun_host: &SocketAddr, - ) -> Result<(stun_format::Msg<'a>, SocketAddr), Error> { + ) -> Result<(Message, SocketAddr), Error> { let mut now = tokio::time::Instant::now(); let deadline = now + self.resp_timeout; @@ -110,17 +119,19 @@ impl Stun { // TODO:: we cannot borrow `buf` directly in udp recv_from, so we copy it here unsafe { std::ptr::copy(udp_buf.as_ptr(), buf.as_ptr() as *mut u8, len) }; - let msg = stun_format::Msg::<'a>::from(&buf[..]); - tracing::info!(b = ?&udp_buf[..len], ?msg, ?tids, ?remote_addr, ?stun_host, "recv stun response"); - - if msg.typ().is_none() || msg.tid().is_none() { + let mut decoder = MessageDecoder::::new(); + let Ok(msg) = decoder + .decode_from_bytes(&buf[..len]) + .with_context(|| format!("decode stun msg {:?}", buf))? + else { continue; - } + }; - if !matches!( - msg.typ().as_ref().unwrap(), - stun_format::MsgType::BindingResponse - ) || !tids.contains(msg.tid().as_ref().unwrap()) + tracing::info!(b = ?&udp_buf[..len], ?tids, ?remote_addr, ?stun_host, "recv stun response, msg: {:#?}", msg); + + if msg.class() != MessageClass::SuccessResponse + || msg.method() != BINDING + || !tids.contains(&tid_to_u128(&msg.transaction_id())) { continue; } @@ -143,29 +154,18 @@ impl Stun { Err(Error::Unknown) } - fn stun_addr(addr: stun_format::SocketAddr) -> SocketAddr { - match addr { - stun_format::SocketAddr::V4(ip, port) => { - SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::from(ip), port)) - } - stun_format::SocketAddr::V6(ip, port) => { - SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::from(ip), port, 0, 0)) - } - } - } - - fn extrace_mapped_addr(msg: &stun_format::Msg) -> Option { + fn extrace_mapped_addr(msg: &Message) -> Option { let mut mapped_addr = None; - for x in msg.attrs_iter() { + for x in msg.attributes() { match x { - Attr::MappedAddress(addr) => { + Attribute::MappedAddress(addr) => { if mapped_addr.is_none() { - let _ = mapped_addr.insert(Self::stun_addr(addr)); + let _ = mapped_addr.insert(addr.address()); } } - Attr::XorMappedAddress(addr) => { + Attribute::XorMappedAddress(addr) => { if mapped_addr.is_none() { - let _ = mapped_addr.insert(Self::stun_addr(addr)); + let _ = mapped_addr.insert(addr.address()); } } _ => {} @@ -174,13 +174,18 @@ impl Stun { mapped_addr } - fn extract_changed_addr(msg: &stun_format::Msg) -> Option { + fn extract_changed_addr(msg: &Message) -> Option { let mut changed_addr = None; - for x in msg.attrs_iter() { + for x in msg.attributes() { match x { - Attr::ChangedAddress(addr) => { + Attribute::OtherAddress(m) => { if changed_addr.is_none() { - let _ = changed_addr.insert(Self::stun_addr(addr)); + let _ = changed_addr.insert(m.address()); + } + } + Attribute::ChangedAddress(m) => { + if changed_addr.is_none() { + let _ = changed_addr.insert(m.address()); } } _ => {} @@ -202,24 +207,24 @@ impl Stun { // repeat req in case of packet loss let mut tids = vec![]; for _ in 0..self.req_repeat { + let tid = rand::random::(); let mut buf = [0u8; 28]; // memset buf unsafe { std::ptr::write_bytes(buf.as_mut_ptr(), 0, buf.len()) }; - let mut msg = stun_format::MsgBuilder::from(buf.as_mut_slice()); - msg.typ(stun_format::MsgType::BindingRequest).unwrap(); - let tid = rand::random::(); - msg.tid(tid as u128).unwrap(); - if change_ip || change_port { - msg.add_attr(Attr::ChangeRequest { - change_ip, - change_port, - }) - .unwrap(); - } + + let mut message = + Message::::new(MessageClass::Request, BINDING, u128_to_tid(tid as u128)); + message.add_attribute(ChangeRequest::new(change_ip, change_port)); + + // Encodes the message + let mut encoder = MessageEncoder::new(); + let msg = encoder + .encode_into_bytes(message.clone()) + .with_context(|| "encode stun message")?; tids.push(tid as u128); - tracing::trace!(b = ?msg.as_bytes(), tid, "send stun request"); - udp.send_to(msg.as_bytes(), &stun_host).await?; + tracing::trace!(?message, ?msg, tid, "send stun request"); + udp.send_to(msg.as_slice().into(), &stun_host).await?; } tracing::trace!("waiting stun response"); @@ -234,6 +239,8 @@ impl Stun { let resp = BindRequestResponse { source_addr: udp.local_addr()?, + send_to_addr: stun_host, + recv_from_addr: recv_addr, mapped_socket_addr: Self::extrace_mapped_addr(&msg), changed_socket_addr, ip_changed: change_ip, @@ -489,6 +496,18 @@ impl StunInfoCollector { mod tests { use super::*; + pub fn enable_log() { + let filter = tracing_subscriber::EnvFilter::builder() + .with_default_directive(tracing::level_filters::LevelFilter::TRACE.into()) + .from_env() + .unwrap() + .add_directive("tarpc=error".parse().unwrap()); + tracing_subscriber::fmt::fmt() + .pretty() + .with_env_filter(filter) + .init(); + } + #[tokio::test] async fn test_stun_bind_request() { // miwifi / qq seems not correctly responde to change_ip and change_port, they always try to change the src ip and port. diff --git a/easytier-core/src/common/stun_codec_ext.rs b/easytier-core/src/common/stun_codec_ext.rs new file mode 100644 index 00000000..ca7750b9 --- /dev/null +++ b/easytier-core/src/common/stun_codec_ext.rs @@ -0,0 +1,229 @@ +use std::net::SocketAddr; + +use stun_codec::net::{socket_addr_xor, SocketAddrDecoder, SocketAddrEncoder}; + +use stun_codec::rfc5389::attributes::{ + MappedAddress, Software, XorMappedAddress, XorMappedAddress2, +}; +use stun_codec::rfc5780::attributes::{ChangeRequest, OtherAddress, ResponseOrigin}; +use stun_codec::{define_attribute_enums, AttributeType, Message, TransactionId}; + +use bytecodec::{ByteCount, Decode, Encode, Eos, Result, SizedEncode, TryTaggedDecode}; + +use stun_codec::macros::track; + +macro_rules! impl_decode { + ($decoder:ty, $item:ident, $and_then:expr) => { + impl Decode for $decoder { + type Item = $item; + + fn decode(&mut self, buf: &[u8], eos: Eos) -> Result { + track!(self.0.decode(buf, eos)) + } + + fn finish_decoding(&mut self) -> Result { + track!(self.0.finish_decoding()).and_then($and_then) + } + + fn requiring_bytes(&self) -> ByteCount { + self.0.requiring_bytes() + } + + fn is_idle(&self) -> bool { + self.0.is_idle() + } + } + impl TryTaggedDecode for $decoder { + type Tag = AttributeType; + + fn try_start_decoding(&mut self, attr_type: Self::Tag) -> Result { + Ok(attr_type.as_u16() == $item::CODEPOINT) + } + } + }; +} + +macro_rules! impl_encode { + ($encoder:ty, $item:ty, $map_from:expr) => { + impl Encode for $encoder { + type Item = $item; + + fn encode(&mut self, buf: &mut [u8], eos: Eos) -> Result { + track!(self.0.encode(buf, eos)) + } + + #[allow(clippy::redundant_closure_call)] + fn start_encoding(&mut self, item: Self::Item) -> Result<()> { + track!(self.0.start_encoding($map_from(item))) + } + + fn requiring_bytes(&self) -> ByteCount { + self.0.requiring_bytes() + } + + fn is_idle(&self) -> bool { + self.0.is_idle() + } + } + impl SizedEncode for $encoder { + fn exact_requiring_bytes(&self) -> u64 { + self.0.exact_requiring_bytes() + } + } + }; +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct ChangedAddress(SocketAddr); +impl ChangedAddress { + /// The codepoint of the type of the attribute. + pub const CODEPOINT: u16 = 0x0005; + + pub fn new(addr: SocketAddr) -> Self { + ChangedAddress(addr) + } + + /// Returns the address of this instance. + pub fn address(&self) -> SocketAddr { + self.0 + } +} +impl stun_codec::Attribute for ChangedAddress { + type Decoder = ChangedAddressDecoder; + type Encoder = ChangedAddressEncoder; + + fn get_type(&self) -> AttributeType { + AttributeType::new(Self::CODEPOINT) + } + + fn before_encode( + &mut self, + message: &Message, + ) -> bytecodec::Result<()> { + self.0 = socket_addr_xor(self.0, message.transaction_id()); + Ok(()) + } + + fn after_decode( + &mut self, + message: &Message, + ) -> bytecodec::Result<()> { + self.0 = socket_addr_xor(self.0, message.transaction_id()); + Ok(()) + } +} + +#[derive(Debug, Default)] +pub struct ChangedAddressDecoder(SocketAddrDecoder); +impl ChangedAddressDecoder { + pub fn new() -> Self { + Self::default() + } +} +impl_decode!(ChangedAddressDecoder, ChangedAddress, |item| Ok( + ChangedAddress(item) +)); + +#[derive(Debug, Default)] +pub struct ChangedAddressEncoder(SocketAddrEncoder); +impl ChangedAddressEncoder { + pub fn new() -> Self { + Self::default() + } +} +impl_encode!(ChangedAddressEncoder, ChangedAddress, |item: Self::Item| { + item.0 +}); + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct SourceAddress(SocketAddr); +impl SourceAddress { + /// The codepoint of the type of the attribute. + pub const CODEPOINT: u16 = 0x0004; + + pub fn new(addr: SocketAddr) -> Self { + SourceAddress(addr) + } + + /// Returns the address of this instance. + pub fn address(&self) -> SocketAddr { + self.0 + } +} +impl stun_codec::Attribute for SourceAddress { + type Decoder = SourceAddressDecoder; + type Encoder = SourceAddressEncoder; + + fn get_type(&self) -> AttributeType { + AttributeType::new(Self::CODEPOINT) + } + + fn before_encode( + &mut self, + message: &Message, + ) -> bytecodec::Result<()> { + self.0 = socket_addr_xor(self.0, message.transaction_id()); + Ok(()) + } + + fn after_decode( + &mut self, + message: &Message, + ) -> bytecodec::Result<()> { + self.0 = socket_addr_xor(self.0, message.transaction_id()); + Ok(()) + } +} + +#[derive(Debug, Default)] +pub struct SourceAddressDecoder(SocketAddrDecoder); +impl SourceAddressDecoder { + pub fn new() -> Self { + Self::default() + } +} +impl_decode!(SourceAddressDecoder, SourceAddress, |item| Ok( + SourceAddress(item) +)); + +#[derive(Debug, Default)] +pub struct SourceAddressEncoder(SocketAddrEncoder); +impl SourceAddressEncoder { + pub fn new() -> Self { + Self::default() + } +} +impl_encode!(SourceAddressEncoder, SourceAddress, |item: Self::Item| { + item.0 +}); + +pub fn tid_to_u128(tid: &TransactionId) -> u128 { + let mut tid_buf = [0u8; 16]; + // copy bytes from msg_tid to tid_buf + tid_buf[..tid.as_bytes().len()].copy_from_slice(tid.as_bytes()); + u128::from_le_bytes(tid_buf) +} + +pub fn u128_to_tid(tid: u128) -> TransactionId { + let tid_buf = tid.to_le_bytes(); + let mut tid_arr = [0u8; 12]; + tid_arr.copy_from_slice(&tid_buf[..12]); + TransactionId::new(tid_arr) +} + +define_attribute_enums!( + Attribute, + AttributeDecoder, + AttributeEncoder, + [ + Software, + MappedAddress, + XorMappedAddress, + XorMappedAddress2, + OtherAddress, + ChangeRequest, + ChangedAddress, + SourceAddress, + ResponseOrigin + ] +);