diff --git a/Cargo.toml b/Cargo.toml index 6ac2945c..44c26045 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,17 +1,12 @@ [workspace] -resolver= "2" -members = [ - "easytier-core", - "easytier-cli" -] +resolver = "2" +members = ["easytier-core", "easytier-cli"] -default-members = [ - "easytier-core", - "easytier-cli" -] +default-members = ["easytier-core", "easytier-cli"] [profile.dev] panic = "abort" [profile.release] panic = "abort" +lto = true diff --git a/easytier-core/Cargo.toml b/easytier-core/Cargo.toml index 696d464d..b70a7882 100644 --- a/easytier-core/Cargo.toml +++ b/easytier-core/Cargo.toml @@ -56,7 +56,8 @@ crossbeam-queue = "0.3" once_cell = "1.18.0" # for packet -rkyv = { "version" = "0.7.42", features = ["validation", "archive_le"] } +rkyv = { "version" = "0.7.42", features = ["validation", "archive_le", "strict", "copy_unsafe", "arbitrary_enum_discriminant"] } +postcard = {"version"= "*", features = ["alloc"]} # for rpc tonic = "0.10" diff --git a/easytier-core/src/common/rkyv_util.rs b/easytier-core/src/common/rkyv_util.rs index 65881210..0dfa730d 100644 --- a/easytier-core/src/common/rkyv_util.rs +++ b/easytier-core/src/common/rkyv_util.rs @@ -1,4 +1,5 @@ use rkyv::{ + string::ArchivedString, validation::{validators::DefaultValidator, CheckTypeError}, vec::ArchivedVec, Archive, CheckBytes, Serialize, @@ -43,6 +44,19 @@ pub fn extract_bytes_from_archived_vec(raw_data: &Bytes, archived_data: &Archive return raw_data.slice(offset..offset + len); } +pub fn extract_bytes_from_archived_string( + raw_data: &Bytes, + archived_data: &ArchivedString, +) -> Bytes { + let offset = archived_data.as_ptr() as usize - raw_data.as_ptr() as usize; + let len = archived_data.len(); + if offset + len > raw_data.len() { + return Bytes::new(); + } + + return raw_data.slice(offset..offset + archived_data.len()); +} + pub fn extract_bytes_mut_from_archived_vec( raw_data: &mut BytesMut, archived_data: &ArchivedVec, @@ -52,3 +66,7 @@ pub fn extract_bytes_mut_from_archived_vec( let len = ptr_range.end as usize - ptr_range.start as usize; raw_data.split_off(offset).split_to(len) } + +pub fn vec_to_string(vec: Vec) -> String { + unsafe { String::from_utf8_unchecked(vec) } +} diff --git a/easytier-core/src/gateway/icmp_proxy.rs b/easytier-core/src/gateway/icmp_proxy.rs index 5970e8d2..c660b458 100644 --- a/easytier-core/src/gateway/icmp_proxy.rs +++ b/easytier-core/src/gateway/icmp_proxy.rs @@ -151,11 +151,11 @@ impl PeerPacketFilter for IcmpProxy { ) -> Option<()> { let _ = self.global_ctx.get_ipv4()?; - let packet::ArchivedPacketBody::Data(x) = &packet.body else { + if packet.packet_type != packet::PacketType::Data { return None; }; - let ipv4 = Ipv4Packet::new(&x)?; + let ipv4 = Ipv4Packet::new(&packet.payload.as_bytes())?; if ipv4.get_version() != 4 || ipv4.get_next_level_protocol() != IpNextHeaderProtocols::Icmp { diff --git a/easytier-core/src/gateway/tcp_proxy.rs b/easytier-core/src/gateway/tcp_proxy.rs index 15d13085..2f1e2b76 100644 --- a/easytier-core/src/gateway/tcp_proxy.rs +++ b/easytier-core/src/gateway/tcp_proxy.rs @@ -84,11 +84,13 @@ impl PeerPacketFilter for TcpProxy { async fn try_process_packet_from_peer(&self, packet: &ArchivedPacket, _: &Bytes) -> Option<()> { let ipv4_addr = self.global_ctx.get_ipv4()?; - let packet::ArchivedPacketBody::Data(x) = &packet.body else { + if packet.packet_type != packet::PacketType::Data { return None; }; - let ipv4 = Ipv4Packet::new(&x)?; + let payload_bytes = packet.payload.as_bytes(); + + let ipv4 = Ipv4Packet::new(payload_bytes)?; if ipv4.get_version() != 4 || ipv4.get_next_level_protocol() != IpNextHeaderProtocols::Tcp { return None; } @@ -99,8 +101,8 @@ impl PeerPacketFilter for TcpProxy { tracing::trace!(ipv4 = ?ipv4, cidr_set = ?self.cidr_set, "proxy tcp packet received"); - let mut packet_buffer = BytesMut::with_capacity(x.len()); - packet_buffer.extend_from_slice(&x.to_vec()); + let mut packet_buffer = BytesMut::with_capacity(payload_bytes.len()); + packet_buffer.extend_from_slice(&payload_bytes.to_vec()); let (ip_buffer, tcp_buffer) = packet_buffer.split_at_mut(ipv4.get_header_length() as usize * 4); diff --git a/easytier-core/src/gateway/udp_proxy.rs b/easytier-core/src/gateway/udp_proxy.rs index 41e72a30..970629a0 100644 --- a/easytier-core/src/gateway/udp_proxy.rs +++ b/easytier-core/src/gateway/udp_proxy.rs @@ -242,11 +242,11 @@ impl PeerPacketFilter for UdpProxy { let _ = self.global_ctx.get_ipv4()?; - let packet::ArchivedPacketBody::Data(x) = &packet.body else { + if packet.packet_type != packet::PacketType::Data { return None; }; - let ipv4 = Ipv4Packet::new(&x)?; + let ipv4 = Ipv4Packet::new(packet.payload.as_bytes())?; if ipv4.get_version() != 4 || ipv4.get_next_level_protocol() != IpNextHeaderProtocols::Udp { return None; diff --git a/easytier-core/src/peer_center/instance.rs b/easytier-core/src/peer_center/instance.rs index d2069959..dbf7ef78 100644 --- a/easytier-core/src/peer_center/instance.rs +++ b/easytier-core/src/peer_center/instance.rs @@ -89,7 +89,7 @@ impl PeerCenterBase { tokio::time::sleep(Duration::from_secs(1)).await; continue; }; - tracing::info!(?center_peer, "run periodic job"); + tracing::trace!(?center_peer, "run periodic job"); let rpc_mgr = peer_mgr.get_peer_rpc_mgr(); let _g = lock.lock().await; let ret = rpc_mgr diff --git a/easytier-core/src/peer_center/server.rs b/easytier-core/src/peer_center/server.rs index cf6d40e8..89a929f8 100644 --- a/easytier-core/src/peer_center/server.rs +++ b/easytier-core/src/peer_center/server.rs @@ -98,7 +98,7 @@ impl PeerCenterService for PeerCenterServer { peers: Option, digest: Digest, ) -> Result<(), Error> { - tracing::info!("receive report_peers"); + tracing::trace!("receive report_peers"); let data = get_global_data(self.my_node_id); let mut locked_data = data.write().await; diff --git a/easytier-core/src/peers/foreign_network_manager.rs b/easytier-core/src/peers/foreign_network_manager.rs index 32743b4d..e231443f 100644 --- a/easytier-core/src/peers/foreign_network_manager.rs +++ b/easytier-core/src/peers/foreign_network_manager.rs @@ -24,7 +24,7 @@ use crate::common::{ }; use super::{ - packet::{self, ArchivedPacketBody}, + packet::{self}, peer_conn::PeerConn, peer_map::PeerMap, peer_rpc::{PeerRpcManager, PeerRpcManagerTransport}, @@ -245,7 +245,7 @@ impl ForeignNetworkManager { let from_peer_id = packet.from_peer.into(); let to_peer_id = packet.to_peer.into(); if to_peer_id == my_node_id { - if let ArchivedPacketBody::TaRpc(..) = &packet.body { + if packet.packet_type == packet::PacketType::TaRpc { rpc_sender.send(packet_bytes.clone()).unwrap(); continue; } diff --git a/easytier-core/src/peers/packet.rs b/easytier-core/src/peers/packet.rs index 6ca3e142..02d9ac2c 100644 --- a/easytier-core/src/peers/packet.rs +++ b/easytier-core/src/peers/packet.rs @@ -5,7 +5,7 @@ use tokio_util::bytes::Bytes; use crate::common::{ global_ctx::NetworkIdentity, - rkyv_util::{decode_from_bytes, encode_to_bytes}, + rkyv_util::{decode_from_bytes, encode_to_bytes, vec_to_string}, PeerId, }; @@ -50,69 +50,23 @@ impl From<&ArchivedUUID> for UUID { } } -#[derive(Archive, Deserialize, Serialize)] -#[archive(compare(PartialEq), check_bytes)] -// Derives can be passed through to the generated type: -pub struct NetworkIdentityForPacket(Vec); - -impl From for NetworkIdentityForPacket { - fn from(network: NetworkIdentity) -> Self { - Self(bincode::serialize(&network).unwrap()) - } -} - -impl From for NetworkIdentity { - fn from(network: NetworkIdentityForPacket) -> Self { - bincode::deserialize(&network.0).unwrap() - } -} - -impl From<&ArchivedNetworkIdentityForPacket> for NetworkIdentity { - fn from(network: &ArchivedNetworkIdentityForPacket) -> Self { - NetworkIdentityForPacket(network.0.to_vec()).into() - } -} - -impl Debug for NetworkIdentityForPacket { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let network: NetworkIdentity = bincode::deserialize(&self.0).unwrap(); - write!(f, "{:?}", network) - } -} - -impl Debug for ArchivedNetworkIdentityForPacket { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let network: NetworkIdentity = bincode::deserialize(&self.0).unwrap(); - write!(f, "{:?}", network) - } -} - -#[derive(Archive, Deserialize, Serialize, Debug)] -#[archive(compare(PartialEq), check_bytes)] -// Derives can be passed through to the generated type: -#[archive_attr(derive(Debug))] +#[derive(serde::Serialize, serde::Deserialize, Debug)] pub struct HandShake { pub magic: u32, pub my_peer_id: PeerId, pub version: u32, pub features: Vec, - pub network_identity: NetworkIdentityForPacket, + pub network_identity: NetworkIdentity, } -#[derive(Archive, Deserialize, Serialize, Debug)] -#[archive(compare(PartialEq), check_bytes)] -#[archive_attr(derive(Debug))] +#[derive(serde::Serialize, serde::Deserialize, Debug)] pub struct RoutePacket { pub route_id: u8, pub body: Vec, } -#[derive(Archive, Deserialize, Serialize, Debug)] -#[archive(compare(PartialEq), check_bytes)] -// Derives can be passed through to the generated type: -#[archive_attr(derive(Debug))] -pub enum PacketBody { - Data(Vec), +#[derive(Debug, serde::Serialize, serde::Deserialize)] +pub enum CtrlPacketPayload { HandShake(HandShake), RoutePacket(RoutePacket), Ping(u32), @@ -120,20 +74,72 @@ pub enum PacketBody { TaRpc(u32, bool, Vec), // u32: service_id, bool: is_req, Vec: rpc body } +impl CtrlPacketPayload { + pub fn from_packet(p: &ArchivedPacket) -> CtrlPacketPayload { + assert_ne!(p.packet_type, PacketType::Data); + postcard::from_bytes(p.payload.as_bytes()).unwrap() + } + + pub fn from_packet2(p: &Packet) -> CtrlPacketPayload { + postcard::from_bytes(p.payload.as_bytes()).unwrap() + } +} + +#[repr(u8)] #[derive(Archive, Deserialize, Serialize, Debug)] #[archive(compare(PartialEq), check_bytes)] // Derives can be passed through to the generated type: #[archive_attr(derive(Debug))] +pub enum PacketType { + Data = 1, + HandShake = 2, + RoutePacket = 3, + Ping = 4, + Pong = 5, + TaRpc = 6, +} + +#[derive(Archive, Deserialize, Serialize, Debug)] +#[archive(compare(PartialEq), check_bytes)] +// Derives can be passed through to the generated type: pub struct Packet { pub from_peer: PeerId, pub to_peer: PeerId, - pub body: PacketBody, + pub packet_type: PacketType, + pub payload: String, +} + +impl std::fmt::Debug for ArchivedPacket { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "Packet {{ from_peer: {}, to_peer: {}, packet_type: {:?}, payload: {:?} }}", + self.from_peer, + self.to_peer, + self.packet_type, + &self.payload.as_bytes() + ) + } } impl Packet { pub fn decode(v: &[u8]) -> &ArchivedPacket { decode_from_bytes::(v).unwrap() } + + pub fn new( + from_peer: PeerId, + to_peer: PeerId, + packet_type: PacketType, + payload: Vec, + ) -> Self { + Packet { + from_peer, + to_peer, + packet_type, + payload: vec_to_string(payload), + } + } } impl From for Bytes { @@ -144,52 +150,56 @@ impl From for Bytes { impl Packet { pub fn new_handshake(from_peer: PeerId, network: &NetworkIdentity) -> Self { - Packet { - from_peer: from_peer.into(), - to_peer: 0, - body: PacketBody::HandShake(HandShake { - magic: MAGIC, - my_peer_id: from_peer, - version: VERSION, - features: Vec::new(), - network_identity: network.clone().into(), - }), - } + let handshake = CtrlPacketPayload::HandShake(HandShake { + magic: MAGIC, + my_peer_id: from_peer, + version: VERSION, + features: Vec::new(), + network_identity: network.clone().into(), + }); + Packet::new( + from_peer.into(), + 0, + PacketType::HandShake, + postcard::to_allocvec(&handshake).unwrap(), + ) } pub fn new_data_packet(from_peer: PeerId, to_peer: PeerId, data: &[u8]) -> Self { - Packet { - from_peer, - to_peer, - body: PacketBody::Data(data.to_vec()), - } + Packet::new(from_peer, to_peer, PacketType::Data, data.to_vec()) } pub fn new_route_packet(from_peer: PeerId, to_peer: PeerId, route_id: u8, data: &[u8]) -> Self { - Packet { + let route = CtrlPacketPayload::RoutePacket(RoutePacket { + route_id, + body: data.to_vec(), + }); + Packet::new( from_peer, to_peer, - body: PacketBody::RoutePacket(RoutePacket { - route_id, - body: data.to_vec(), - }), - } + PacketType::RoutePacket, + postcard::to_allocvec(&route).unwrap(), + ) } pub fn new_ping_packet(from_peer: PeerId, to_peer: PeerId, seq: u32) -> Self { - Packet { + let ping = CtrlPacketPayload::Ping(seq); + Packet::new( from_peer, to_peer, - body: PacketBody::Ping(seq), - } + PacketType::Ping, + postcard::to_allocvec(&ping).unwrap(), + ) } pub fn new_pong_packet(from_peer: PeerId, to_peer: PeerId, seq: u32) -> Self { - Packet { + let pong = CtrlPacketPayload::Pong(seq); + Packet::new( from_peer, to_peer, - body: PacketBody::Pong(seq), - } + PacketType::Pong, + postcard::to_allocvec(&pong).unwrap(), + ) } pub fn new_tarpc_packet( @@ -199,11 +209,13 @@ impl Packet { is_req: bool, body: Vec, ) -> Self { - Packet { + let ta_rpc = CtrlPacketPayload::TaRpc(service_id, is_req, body); + Packet::new( from_peer, to_peer, - body: PacketBody::TaRpc(service_id, is_req, body), - } + PacketType::TaRpc, + postcard::to_allocvec(&ta_rpc).unwrap(), + ) } } diff --git a/easytier-core/src/peers/peer_conn.rs b/easytier-core/src/peers/peer_conn.rs index 44a36a1a..416cff16 100644 --- a/easytier-core/src/peers/peer_conn.rs +++ b/easytier-core/src/peers/peer_conn.rs @@ -16,10 +16,7 @@ use tokio::{ time::{timeout, Duration}, }; -use tokio_util::{ - bytes::{Bytes, BytesMut}, - sync::PollSender, -}; +use tokio_util::{bytes::Bytes, sync::PollSender}; use tracing::Instrument; use crate::{ @@ -28,6 +25,7 @@ use crate::{ PeerId, }, define_tunnel_filter_chain, + peers::packet::{ArchivedPacketType, CtrlPacketPayload}, rpc::{PeerConnInfo, PeerConnStats}, tunnels::{ stats::{Throughput, WindowLatency}, @@ -36,7 +34,7 @@ use crate::{ }, }; -use super::packet::{self, ArchivedHandShake, Packet}; +use super::packet::{self, HandShake, Packet}; pub type PacketRecvChan = mpsc::Sender; @@ -54,7 +52,8 @@ macro_rules! wait_response { let $out_var; let rsp_bytes = Packet::decode(&rsp_vec); - match &rsp_bytes.body { + let resp_payload = CtrlPacketPayload::from_packet(&rsp_bytes); + match &resp_payload { $pattern => $out_var = $value, _ => { log::error!( @@ -68,19 +67,6 @@ macro_rules! wait_response { }; } -fn build_ctrl_msg(msg: Bytes, is_req: bool) -> Bytes { - let prefix: &'static [u8] = if is_req { - CTRL_REQ_PACKET_PREFIX - } else { - CTRL_RESP_PACKET_PREFIX - }; - let mut new_msg = BytesMut::new(); - new_msg.reserve(prefix.len() + msg.len()); - new_msg.extend_from_slice(prefix); - new_msg.extend_from_slice(&msg); - new_msg.into() -} - pub struct PeerInfo { magic: u32, pub my_peer_id: PeerId, @@ -90,15 +76,15 @@ pub struct PeerInfo { pub network_identity: NetworkIdentity, } -impl<'a> From<&ArchivedHandShake> for PeerInfo { - fn from(hs: &ArchivedHandShake) -> Self { +impl<'a> From<&HandShake> for PeerInfo { + fn from(hs: &HandShake) -> Self { PeerInfo { magic: hs.magic.into(), my_peer_id: hs.my_peer_id.into(), version: hs.version.into(), features: hs.features.iter().map(|x| x.to_string()).collect(), interfaces: Vec::new(), - network_identity: (&hs.network_identity).into(), + network_identity: hs.network_identity.clone(), } } } @@ -150,10 +136,7 @@ impl PeerConnPinger { seq: u32, ) -> Result { // should add seq here. so latency can be calculated more accurately - let req = build_ctrl_msg( - packet::Packet::new_ping_packet(my_node_id, peer_id, seq).into(), - true, - ); + let req = packet::Packet::new_ping_packet(my_node_id, peer_id, seq).into(); tracing::trace!("send ping packet: {:?}", req); sink.lock().await.send(req).await.map_err(|e| { tracing::warn!("send ping packet error: {:?}", e); @@ -167,9 +150,10 @@ impl PeerConnPinger { loop { match receiver.recv().await { Ok(p) => { - if let packet::ArchivedPacketBody::Pong(resp_seq) = &Packet::decode(&p).body - { - if *resp_seq == seq { + let ctrl_payload = + packet::CtrlPacketPayload::from_packet(Packet::decode(&p)); + if let packet::CtrlPacketPayload::Pong(resp_seq) = ctrl_payload { + if resp_seq == seq { break; } } @@ -247,7 +231,7 @@ impl PeerConnPinger { }); req_seq += 1; - tokio::time::sleep(Duration::from_millis(350)).await; + tokio::time::sleep(Duration::from_millis(1000)).await; } }); @@ -332,9 +316,6 @@ enum PeerConnPacketType { CtrlResp(Bytes), } -static CTRL_REQ_PACKET_PREFIX: &[u8] = &[0x12, 0x34, 0x56, 0x78, 0x9a, 0xbc, 0xde, 0xf0]; -static CTRL_RESP_PACKET_PREFIX: &[u8] = &[0x12, 0x34, 0x56, 0x78, 0x9a, 0xbc, 0xde, 0xf1]; - impl PeerConn { pub fn new(my_peer_id: PeerId, global_ctx: ArcGlobalCtx, tunnel: Box) -> Self { let (ctrl_sender, _ctrl_receiver) = broadcast::channel(100); @@ -371,7 +352,7 @@ impl PeerConn { let mut stream = self.tunnel.pin_stream(); let mut sink = self.tunnel.pin_sink(); - wait_response!(stream, hs_req, packet::ArchivedPacketBody::HandShake(x) => x); + wait_response!(stream, hs_req, CtrlPacketPayload::HandShake(x) => x); self.info = Some(PeerInfo::from(hs_req)); log::info!("handshake request: {:?}", hs_req); @@ -394,7 +375,7 @@ impl PeerConn { .run(|| packet::Packet::new_handshake(self.my_peer_id, &self.global_ctx.network)); sink.send(hs_req.into()).await?; - wait_response!(stream, hs_rsp, packet::ArchivedPacketBody::HandShake(x) => x); + wait_response!(stream, hs_rsp, CtrlPacketPayload::HandShake(x) => x); self.info = Some(PeerInfo::from(hs_rsp)); log::info!("handshake response: {:?}", hs_rsp); @@ -405,41 +386,6 @@ impl PeerConn { self.info.is_some() } - fn get_packet_type(mut bytes_item: Bytes) -> PeerConnPacketType { - if bytes_item.starts_with(CTRL_REQ_PACKET_PREFIX) { - PeerConnPacketType::CtrlReq(bytes_item.split_off(CTRL_REQ_PACKET_PREFIX.len())) - } else if bytes_item.starts_with(CTRL_RESP_PACKET_PREFIX) { - PeerConnPacketType::CtrlResp(bytes_item.split_off(CTRL_RESP_PACKET_PREFIX.len())) - } else { - PeerConnPacketType::Data(bytes_item) - } - } - - fn handle_ctrl_req_packet( - bytes_item: Bytes, - conn_info: &PeerConnInfo, - ) -> Result { - let packet = Packet::decode(&bytes_item); - match packet.body { - packet::ArchivedPacketBody::Ping(seq) => { - log::trace!("recv ping packet: {:?}", packet); - Ok(build_ctrl_msg( - packet::Packet::new_pong_packet( - conn_info.my_peer_id, - conn_info.peer_id, - seq.into(), - ) - .into(), - false, - )) - } - _ => { - log::error!("unexpected packet: {:?}", packet); - Err(TunnelError::CommonError("unexpected packet".to_owned())) - } - } - } - pub fn start_pingpong(&mut self) { let mut pingpong = PeerConnPinger::new( self.my_peer_id, @@ -487,23 +433,36 @@ impl PeerConn { break; } - match Self::get_packet_type(ret.unwrap().into()) { - PeerConnPacketType::Data(item) => { - if sender.send(item).await.is_err() { - break; - } - } - PeerConnPacketType::CtrlReq(item) => { - let ret = Self::handle_ctrl_req_packet(item, &conn_info).unwrap(); - if let Err(e) = sink.send(ret).await { + let buf = ret.unwrap(); + let p = Packet::decode(&buf); + match p.packet_type { + ArchivedPacketType::Ping => { + let CtrlPacketPayload::Ping(seq) = CtrlPacketPayload::from_packet(p) + else { + log::error!("unexpected packet: {:?}", p); + continue; + }; + + let pong = packet::Packet::new_pong_packet( + conn_info.my_peer_id, + conn_info.peer_id, + seq.into(), + ); + + if let Err(e) = sink.send(pong.into()).await { tracing::error!(?e, "peer conn send req error"); } } - PeerConnPacketType::CtrlResp(item) => { - if let Err(e) = ctrl_sender.send(item) { + ArchivedPacketType::Pong => { + if let Err(e) = ctrl_sender.send(buf.into()) { tracing::error!(?e, "peer conn send ctrl resp error"); } } + _ => { + if sender.send(buf.into()).await.is_err() { + break; + } + } } } @@ -676,7 +635,7 @@ mod tests { c_peer.start_recv_loop(tokio::sync::mpsc::channel(200).0); // wait 5s, conn should not be disconnected - tokio::time::sleep(Duration::from_secs(5)).await; + tokio::time::sleep(Duration::from_secs(15)).await; if conn_closed { assert!(close_recv.try_recv().is_ok()); diff --git a/easytier-core/src/peers/peer_manager.rs b/easytier-core/src/peers/peer_manager.rs index a37d4833..a36a197f 100644 --- a/easytier-core/src/peers/peer_manager.rs +++ b/easytier-core/src/peers/peer_manager.rs @@ -15,7 +15,8 @@ use tokio_util::bytes::{Bytes, BytesMut}; use crate::{ common::{ - error::Error, global_ctx::ArcGlobalCtx, rkyv_util::extract_bytes_from_archived_vec, PeerId, + error::Error, global_ctx::ArcGlobalCtx, rkyv_util::extract_bytes_from_archived_string, + PeerId, }, peers::{ packet, peer_conn::PeerConn, peer_rpc::PeerRpcManagerTransport, route_trait::RouteInterface, @@ -287,8 +288,6 @@ impl PeerManager { } async fn init_packet_process_pipeline(&self) { - use packet::ArchivedPacketBody; - // for tun/tap ip/eth packet. struct NicPacketProcessor { nic_channel: mpsc::Sender, @@ -300,10 +299,10 @@ impl PeerManager { packet: &packet::ArchivedPacket, data: &Bytes, ) -> Option<()> { - if let packet::ArchivedPacketBody::Data(x) = &packet.body { + if packet.packet_type == packet::PacketType::Data { // TODO: use a function to get the body ref directly for zero copy self.nic_channel - .send(extract_bytes_from_archived_vec(&data, &x)) + .send(extract_bytes_from_archived_string(data, &packet.payload)) .await .unwrap(); Some(()) @@ -333,7 +332,7 @@ impl PeerManager { packet: &packet::ArchivedPacket, data: &Bytes, ) -> Option<()> { - if let ArchivedPacketBody::TaRpc(..) = &packet.body { + if packet.packet_type == packet::PacketType::TaRpc { self.peer_rpc_tspt_sender.send(data.clone()).unwrap(); Some(()) } else { diff --git a/easytier-core/src/peers/peer_rip_route.rs b/easytier-core/src/peers/peer_rip_route.rs index a98b5c75..ae39c759 100644 --- a/easytier-core/src/peers/peer_rip_route.rs +++ b/easytier-core/src/peers/peer_rip_route.rs @@ -6,7 +6,6 @@ use std::{ use async_trait::async_trait; use dashmap::DashMap; -use rkyv::{Archive, Deserialize, Serialize}; use tokio::{ sync::{Mutex, RwLock}, task::JoinSet, @@ -15,21 +14,15 @@ use tokio_util::bytes::Bytes; use tracing::Instrument; use crate::{ - common::{ - error::Error, - global_ctx::ArcGlobalCtx, - rkyv_util::{decode_from_bytes, encode_to_bytes, extract_bytes_from_archived_vec}, - stun::StunInfoCollectorTrait, - PeerId, - }, + common::{error::Error, global_ctx::ArcGlobalCtx, stun::StunInfoCollectorTrait, PeerId}, peers::{ - packet::{self}, + packet, route_trait::{Route, RouteInterfaceBox}, }, rpc::{NatType, StunInfo}, }; -use super::{packet::ArchivedPacketBody, peer_manager::PeerPacketFilter}; +use super::{packet::CtrlPacketPayload, peer_manager::PeerPacketFilter}; const SEND_ROUTE_PERIOD_SEC: u64 = 60; const SEND_ROUTE_FAST_REPLY_SEC: u64 = 5; @@ -37,10 +30,8 @@ const ROUTE_EXPIRED_SEC: u64 = 70; type Version = u32; -#[derive(Archive, Deserialize, Serialize, Clone, Debug, PartialEq)] -#[archive(compare(PartialEq), check_bytes)] +#[derive(serde::Deserialize, serde::Serialize, Clone, Debug, PartialEq)] // Derives can be passed through to the generated type: -#[archive_attr(derive(Debug))] pub struct SyncPeerInfo { // means next hop in route table. pub peer_id: PeerId, @@ -82,10 +73,7 @@ impl SyncPeerInfo { } } -#[derive(Archive, Deserialize, Serialize, Clone, Debug)] -#[archive(compare(PartialEq), check_bytes)] -// Derives can be passed through to the generated type: -#[archive_attr(derive(Debug))] +#[derive(serde::Deserialize, serde::Serialize, Clone, Debug)] pub struct SyncPeer { pub myself: SyncPeerInfo, pub neighbors: Vec, @@ -341,7 +329,7 @@ impl BasicRoute { ); // TODO: this may exceed the MTU of the tunnel interface - .send_route_packet(encode_to_bytes::<_, 4096>(&msg), 1, peer_id) + .send_route_packet(postcard::to_allocvec(&msg).unwrap().into(), 1, peer_id) .await } @@ -380,7 +368,7 @@ impl BasicRoute { continue; } - tracing::info!( + tracing::trace!( my_id = ?my_peer_id, dst_peer_id = ?peer, version = version.get(), @@ -504,8 +492,8 @@ impl BasicRoute { #[tracing::instrument(skip(self, packet), fields(my_id = ?self.my_peer_id, ctx = ?self.global_ctx))] async fn handle_route_packet(&self, src_peer_id: PeerId, packet: Bytes) { - let packet = decode_from_bytes::(&packet).unwrap(); - let p: SyncPeer = packet.deserialize(&mut rkyv::Infallible).unwrap(); + let packet = postcard::from_bytes::(&packet).unwrap(); + let p = &packet; let mut updated = true; assert_eq!(packet.myself.peer_id, src_peer_id); self.sync_peer_from_remote @@ -639,12 +627,18 @@ impl PeerPacketFilter for BasicRoute { async fn try_process_packet_from_peer( &self, packet: &packet::ArchivedPacket, - data: &Bytes, + _data: &Bytes, ) -> Option<()> { - if let ArchivedPacketBody::RoutePacket(route_packet) = &packet.body { + if packet.packet_type == packet::PacketType::RoutePacket { + let CtrlPacketPayload::RoutePacket(route_packet) = + CtrlPacketPayload::from_packet(packet) + else { + return None; + }; + self.handle_route_packet( packet.from_peer.into(), - extract_bytes_from_archived_vec(&data, &route_packet.body), + route_packet.body.into_boxed_slice().into(), ) .await; Some(()) diff --git a/easytier-core/src/peers/peer_rpc.rs b/easytier-core/src/peers/peer_rpc.rs index b1f19881..c88fe6bd 100644 --- a/easytier-core/src/peers/peer_rpc.rs +++ b/easytier-core/src/peers/peer_rpc.rs @@ -16,7 +16,7 @@ use crate::{ peers::packet::Packet, }; -use super::packet::PacketBody; +use super::packet::CtrlPacketPayload; type PeerRpcServiceId = u32; @@ -206,8 +206,9 @@ impl PeerRpcManager { } fn parse_rpc_packet(packet: &Packet) -> Result { - match &packet.body { - PacketBody::TaRpc(id, is_req, body) => Ok(TaRpcPacketInfo { + let ctrl_packet_payload = CtrlPacketPayload::from_packet2(&packet); + match &ctrl_packet_payload { + CtrlPacketPayload::TaRpc(id, is_req, body) => Ok(TaRpcPacketInfo { from_peer: packet.from_peer.into(), to_peer: packet.to_peer.into(), service_id: *id, diff --git a/easytier-core/src/tunnels/udp_tunnel.rs b/easytier-core/src/tunnels/udp_tunnel.rs index b2d649a1..366169a0 100644 --- a/easytier-core/src/tunnels/udp_tunnel.rs +++ b/easytier-core/src/tunnels/udp_tunnel.rs @@ -13,7 +13,7 @@ use tokio_util::{ use tracing::Instrument; use crate::{ - common::rkyv_util::{self, encode_to_bytes}, + common::rkyv_util::{self, encode_to_bytes, vec_to_string}, rpc::TunnelInfo, tunnels::{build_url_from_socket_addr, close_tunnel, TunnelConnCounter, TunnelConnector}, }; @@ -30,12 +30,11 @@ pub const UDP_DATA_MTU: usize = 2500; #[derive(Archive, Deserialize, Serialize, Debug)] #[archive(compare(PartialEq), check_bytes)] // Derives can be passed through to the generated type: -#[archive_attr(derive(Debug))] pub enum UdpPacketPayload { Syn, Sack, - HolePunch(Vec), - Data(Vec), + HolePunch(String), + Data(String), } #[derive(Archive, Deserialize, Serialize, Debug)] @@ -46,18 +45,32 @@ pub struct UdpPacket { pub payload: UdpPacketPayload, } +impl std::fmt::Debug for ArchivedUdpPacketPayload { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let mut tmp = f.debug_struct("ArchivedUdpPacketPayload"); + match self { + ArchivedUdpPacketPayload::Syn => tmp.field("Syn", &"").finish(), + ArchivedUdpPacketPayload::Sack => tmp.field("Sack", &"").finish(), + ArchivedUdpPacketPayload::HolePunch(s) => { + tmp.field("HolePunch", &s.as_bytes()).finish() + } + ArchivedUdpPacketPayload::Data(s) => tmp.field("Data", &s.as_bytes()).finish(), + } + } +} + impl UdpPacket { pub fn new_data_packet(conn_id: u32, data: Vec) -> Self { Self { conn_id, - payload: UdpPacketPayload::Data(data), + payload: UdpPacketPayload::Data(vec_to_string(data)), } } pub fn new_hole_punch_packet(data: Vec) -> Self { Self { conn_id: 0, - payload: UdpPacketPayload::HolePunch(data), + payload: UdpPacketPayload::HolePunch(vec_to_string(data)), } } @@ -77,7 +90,7 @@ impl UdpPacket { } fn try_get_data_payload(mut buf: BytesMut, conn_id: u32) -> Option { - let Ok(udp_packet) = rkyv_util::decode_from_bytes_checked::(&buf) else { + let Ok(udp_packet) = rkyv_util::decode_from_bytes::(&buf) else { tracing::warn!(?buf, "udp decode error"); return None; }; @@ -92,9 +105,13 @@ fn try_get_data_payload(mut buf: BytesMut, conn_id: u32) -> Option { return None; }; - let ptr_range = payload.as_ptr_range(); - let offset = ptr_range.start as usize - buf.as_ptr() as usize; - let len = ptr_range.end as usize - ptr_range.start as usize; + let offset = payload.as_ptr() as usize - buf.as_ptr() as usize; + let len = payload.len(); + if offset + len > buf.len() { + tracing::warn!(?offset, ?len, ?buf, "udp payload data out of range"); + return None; + } + buf.advance(offset); buf.truncate(len); tracing::trace!(?offset, ?len, ?buf, "udp payload data"); @@ -138,8 +155,8 @@ fn get_tunnel_from_socket( // TODO: two copy here, how to avoid? let udp_packet = UdpPacket::new_data_packet(conn_id, v.to_vec()); - tracing::trace!(?udp_packet, ?v, "udp send packet"); let v = encode_to_bytes::<_, UDP_DATA_MTU>(&udp_packet); + tracing::trace!(?udp_packet, ?v, "udp send packet"); Ok((v, sender_addr)) })); @@ -301,8 +318,7 @@ impl TunnelListener for UdpTunnelListener { _size ); - let Ok(udp_packet) = rkyv_util::decode_from_bytes_checked::(&buf) - else { + let Ok(udp_packet) = rkyv_util::decode_from_bytes::(&buf) else { tracing::warn!(?buf, "udp decode error in forward task"); continue; }; @@ -429,7 +445,7 @@ impl UdpTunnelConnector { let _ = buf.split_off(usize); - let Ok(udp_packet) = rkyv_util::decode_from_bytes_checked::(&buf) else { + let Ok(udp_packet) = rkyv_util::decode_from_bytes::(&buf) else { tracing::warn!(?buf, "udp decode error in wait sack"); return Err(super::TunnelError::ConnectError(format!( "udp connect error, decode error. buf: {:?}", @@ -677,4 +693,12 @@ mod tests { let _ = tokio::join!(sender1, sender2); } + + #[tokio::test] + async fn udp_packet_print() { + let udp_packet = UdpPacket::new_data_packet(1, vec![1, 2, 3, 4, 5]); + let b = encode_to_bytes::<_, UDP_DATA_MTU>(&udp_packet); + let a_udp_packet = rkyv_util::decode_from_bytes::(&b).unwrap(); + println!("{:?}, {:?}", udp_packet, a_udp_packet); + } }