diff --git a/easytier/proto/cli.proto b/easytier/proto/cli.proto index 675093d8..84f85fb6 100644 --- a/easytier/proto/cli.proto +++ b/easytier/proto/cli.proto @@ -166,4 +166,7 @@ message TaRpcPacket { uint32 transact_id = 4; bool is_req = 5; bytes content = 6; + + uint32 total_pieces = 7; + uint32 piece_idx = 8; } diff --git a/easytier/src/peers/peer_rpc.rs b/easytier/src/peers/peer_rpc.rs index 901d0d85..ec0694de 100644 --- a/easytier/src/peers/peer_rpc.rs +++ b/easytier/src/peers/peer_rpc.rs @@ -1,5 +1,9 @@ -use std::sync::{atomic::AtomicU32, Arc}; +use std::{ + sync::{atomic::AtomicU32, Arc}, + time::Instant, +}; +use crossbeam::atomic::AtomicCell; use dashmap::DashMap; use futures::{SinkExt, StreamExt}; use prost::Message; @@ -18,6 +22,8 @@ use crate::{ tunnel::packet_def::{PacketType, ZCPacket}, }; +const RPC_PACKET_CONTENT_MTU: usize = 1300; + type PeerRpcServiceId = u32; type PeerRpcTransactId = u32; @@ -34,6 +40,7 @@ type PacketSender = UnboundedSender; struct PeerRpcEndPoint { peer_id: PeerId, packet_sender: PacketSender, + last_used: AtomicCell, tasks: JoinSet<()>, } @@ -63,6 +70,87 @@ impl std::fmt::Debug for PeerRpcManager { } } +struct PacketMerger { + first_piece: Option, + pieces: Vec, +} + +impl PacketMerger { + fn new() -> Self { + Self { + first_piece: None, + pieces: Vec::new(), + } + } + + fn try_merge_pieces(&self) -> Option { + if self.first_piece.is_none() || self.pieces.is_empty() { + return None; + } + + for p in &self.pieces { + // some piece is missing + if p.total_pieces == 0 { + return None; + } + } + + // all pieces are received + let mut content = Vec::new(); + for p in &self.pieces { + content.extend_from_slice(&p.content); + } + + let mut tmpl_packet = self.first_piece.as_ref().unwrap().clone(); + tmpl_packet.total_pieces = 1; + tmpl_packet.piece_idx = 0; + tmpl_packet.content = content; + + Some(tmpl_packet) + } + + fn feed(&mut self, packet: ZCPacket) -> Result, Error> { + let payload = packet.payload(); + let rpc_packet = + TaRpcPacket::decode(payload).map_err(|e| Error::MessageDecodeError(e.to_string()))?; + + let total_pieces = rpc_packet.total_pieces; + let piece_idx = rpc_packet.piece_idx; + + // for compatibility with old version + if total_pieces == 0 && piece_idx == 0 { + return Ok(Some(rpc_packet)); + } + + if total_pieces > 100 || total_pieces == 0 { + return Err(Error::MessageDecodeError(format!( + "total_pieces is invalid: {}", + total_pieces + ))); + } + + if piece_idx >= total_pieces { + return Err(Error::MessageDecodeError( + "piece_idx >= total_pieces".to_owned(), + )); + } + + if self.first_piece.is_none() + || self.first_piece.as_ref().unwrap().transact_id != rpc_packet.transact_id + || self.first_piece.as_ref().unwrap().from_peer != rpc_packet.from_peer + { + self.first_piece = Some(rpc_packet.clone()); + self.pieces.clear(); + } + + self.pieces + .resize(total_pieces as usize, Default::default()); + self.pieces[piece_idx as usize] = rpc_packet; + + Ok(self.try_merge_pieces()) + } +} + impl PeerRpcManager { pub fn new(tspt: impl PeerRpcManagerTransport) -> Self { Self { @@ -104,6 +192,7 @@ impl PeerRpcManager { tasks.spawn(async move { let mut cur_req_peer_id = None; let mut cur_transact_id = 0; + let mut packet_merger = PacketMerger::new(); loop { tokio::select! { Some(resp) = client_transport.next() => { @@ -126,7 +215,7 @@ impl PeerRpcManager { continue; } - let msg = Self::build_rpc_packet( + let msgs = Self::build_rpc_packet( tspt.my_peer_id(), cur_req_peer_id, service_id, @@ -135,34 +224,33 @@ impl PeerRpcManager { serialized_resp.unwrap(), ); - if let Err(e) = tspt.send(msg, peer_id).await { - tracing::error!(error = ?e, peer_id = ?peer_id, service_id = ?service_id, "send resp to peer failed"); + for msg in msgs { + if let Err(e) = tspt.send(msg, peer_id).await { + tracing::error!(error = ?e, peer_id = ?peer_id, service_id = ?service_id, "send resp to peer failed"); + break; + } } } Some(packet) = packet_receiver.recv() => { - let info = Self::parse_rpc_packet(&packet); - tracing::debug!(?info, "server recv packet from peer"); - if let Err(e) = info { - tracing::error!(error = ?e, packet = ?packet, "parse rpc packet failed"); - continue; - } - let info = info.unwrap(); + tracing::trace!("recv packet from peer, packet: {:?}", packet); - if info.from_peer != peer_id { - tracing::warn!("recv packet from peer, but peer_id not match, ignore it"); - continue; - } + let info = match packet_merger.feed(packet) { + Err(e) => { + tracing::error!(error = ?e, "feed packet to merger failed"); + continue; + }, + Ok(None) => { + continue; + }, + Ok(Some(info)) => { + info + } + }; - if cur_req_peer_id.is_some() { - tracing::warn!("cur_req_peer_id is not none, ignore this packet"); - continue; - } - - assert_eq!(info.service_id, service_id); cur_req_peer_id = Some(info.from_peer); cur_transact_id = info.transact_id; - tracing::trace!("recv packet from peer, packet: {:?}", packet); + assert_eq!(info.service_id, service_id); let decoded_ret = postcard::from_bytes(&info.content.as_slice()); if let Err(e) = decoded_ret { @@ -191,6 +279,7 @@ impl PeerRpcManager { return PeerRpcEndPoint { peer_id, packet_sender, + last_used: AtomicCell::new(Instant::now()), tasks, }; // let resp = client_transport.next().await; @@ -222,21 +311,40 @@ impl PeerRpcManager { transact_id: PeerRpcTransactId, is_req: bool, content: Vec, - ) -> ZCPacket { - let packet = TaRpcPacket { - from_peer, - to_peer, - service_id, - transact_id, - is_req, - content, - }; - let mut buf = Vec::new(); - packet.encode(&mut buf).unwrap(); + ) -> Vec { + let mut ret = Vec::new(); + let content_mtu = RPC_PACKET_CONTENT_MTU; + let total_pieces = (content.len() + content_mtu - 1) / content_mtu; + let mut cur_offset = 0; + while cur_offset < content.len() { + let mut cur_len = content_mtu; + if cur_offset + cur_len > content.len() { + cur_len = content.len() - cur_offset; + } - let mut zc_packet = ZCPacket::new_with_payload(&buf); - zc_packet.fill_peer_manager_hdr(from_peer, to_peer, PacketType::TaRpc as u8); - zc_packet + let mut cur_content = Vec::new(); + cur_content.extend_from_slice(&content[cur_offset..cur_offset + cur_len]); + + let cur_packet = TaRpcPacket { + from_peer, + to_peer, + service_id, + transact_id, + is_req, + total_pieces: total_pieces as u32, + piece_idx: (cur_offset / content_mtu) as u32, + content: cur_content, + }; + cur_offset += cur_len; + + let mut buf = Vec::new(); + cur_packet.encode(&mut buf).unwrap(); + let mut zc_packet = ZCPacket::new_with_payload(&buf); + zc_packet.fill_peer_manager_hdr(from_peer, to_peer, PacketType::TaRpc as u8); + ret.push(zc_packet); + } + + ret } pub fn run(&self) { @@ -270,6 +378,7 @@ impl PeerRpcManager { service_registry.get(&info.service_id).unwrap()(info.from_peer) }); + endpoint.last_used.store(Instant::now()); endpoint.packet_sender.send(o).unwrap(); } else { if let Some(a) = client_resp_receivers.get(&PeerRpcClientCtxKey( @@ -287,6 +396,14 @@ impl PeerRpcManager { } } }); + + let peer_rpc_endpoints = self.peer_rpc_endpoints.clone(); + tokio::spawn(async move { + loop { + tokio::time::sleep(tokio::time::Duration::from_secs(60)).await; + peer_rpc_endpoints.retain(|_, v| v.last_used.load().elapsed().as_secs() < 60); + } + }); } #[tracing::instrument(skip(f))] @@ -327,7 +444,7 @@ impl PeerRpcManager { continue; } - let packet = Self::build_rpc_packet( + let packets = Self::build_rpc_packet( tspt.my_peer_id(), dst_peer_id, service_id, @@ -336,10 +453,13 @@ impl PeerRpcManager { a.unwrap(), ); - tracing::debug!(?packet, "client send rpc packet to peer"); + tracing::debug!(?packets, "client send rpc packet to peer"); - if let Err(e) = tspt.send(packet, dst_peer_id).await { - tracing::error!(error = ?e, dst_peer_id = ?dst_peer_id, "send to peer failed"); + for packet in packets { + if let Err(e) = tspt.send(packet, dst_peer_id).await { + tracing::error!(error = ?e, dst_peer_id = ?dst_peer_id, "send to peer failed"); + break; + } } } @@ -347,17 +467,24 @@ impl PeerRpcManager { }); tasks.spawn(async move { + let mut packet_merger = PacketMerger::new(); while let Some(packet) = packet_receiver.recv().await { tracing::trace!("tunnel recv: {:?}", packet); - let info = Self::parse_rpc_packet(&packet); - if let Err(e) = info { - tracing::error!(error = ?e, "parse rpc packet failed"); - continue; - } + let info = match packet_merger.feed(packet) { + Err(e) => { + tracing::error!(error = ?e, "feed packet to merger failed"); + continue; + } + Ok(None) => { + continue; + } + Ok(Some(info)) => info, + }; + tracing::debug!(?info, "client recv rpc packet from peer"); - let decoded = postcard::from_bytes(&info.unwrap().content.as_slice()); + let decoded = postcard::from_bytes(&info.content.as_slice()); if let Err(e) = decoded { tracing::error!(error = ?e, "decode rpc packet failed"); continue; @@ -426,6 +553,17 @@ pub mod tests { } } + fn random_string(len: usize) -> String { + use rand::distributions::Alphanumeric; + use rand::Rng; + let mut rng = rand::thread_rng(); + let s: Vec = std::iter::repeat(()) + .map(|()| rng.sample(Alphanumeric)) + .take(len) + .collect(); + String::from_utf8(s).unwrap() + } + #[tokio::test] async fn peer_rpc_basic_test() { struct MockTransport { @@ -473,16 +611,29 @@ pub mod tests { }); client_rpc_mgr.run(); + let msg = random_string(8192); let ret = client_rpc_mgr .do_client_rpc_scoped(1, server_rpc_mgr.my_peer_id(), |c| async { let c = TestRpcServiceClient::new(tarpc::client::Config::default(), c).spawn(); - let ret = c.hello(tarpc::context::current(), "abc".to_owned()).await; + let ret = c.hello(tarpc::context::current(), msg.clone()).await; ret }) .await; println!("ret: {:?}", ret); - assert_eq!(ret.unwrap(), "hello abc"); + assert_eq!(ret.unwrap(), format!("hello {}", msg)); + + let msg = random_string(10); + let ret = client_rpc_mgr + .do_client_rpc_scoped(1, server_rpc_mgr.my_peer_id(), |c| async { + let c = TestRpcServiceClient::new(tarpc::client::Config::default(), c).spawn(); + let ret = c.hello(tarpc::context::current(), msg.clone()).await; + ret + }) + .await; + + println!("ret: {:?}", ret); + assert_eq!(ret.unwrap(), format!("hello {}", msg)); } #[tokio::test] @@ -516,39 +667,42 @@ pub mod tests { }; peer_mgr_b.get_peer_rpc_mgr().run_service(1, s.serve()); + let msg = random_string(16 * 1024); let ip_list = peer_mgr_a .get_peer_rpc_mgr() .do_client_rpc_scoped(1, peer_mgr_b.my_peer_id(), |c| async { let c = TestRpcServiceClient::new(tarpc::client::Config::default(), c).spawn(); - let ret = c.hello(tarpc::context::current(), "abc".to_owned()).await; + let ret = c.hello(tarpc::context::current(), msg.clone()).await; ret }) .await; println!("ip_list: {:?}", ip_list); - assert_eq!(ip_list.as_ref().unwrap(), "hello abc"); + assert_eq!(ip_list.unwrap(), format!("hello {}", msg)); // call again + let msg = random_string(16 * 1024); let ip_list = peer_mgr_a .get_peer_rpc_mgr() .do_client_rpc_scoped(1, peer_mgr_b.my_peer_id(), |c| async { let c = TestRpcServiceClient::new(tarpc::client::Config::default(), c).spawn(); - let ret = c.hello(tarpc::context::current(), "abcd".to_owned()).await; + let ret = c.hello(tarpc::context::current(), msg.clone()).await; ret }) .await; println!("ip_list: {:?}", ip_list); - assert_eq!(ip_list.as_ref().unwrap(), "hello abcd"); + assert_eq!(ip_list.unwrap(), format!("hello {}", msg)); + let msg = random_string(16 * 1024); let ip_list = peer_mgr_c .get_peer_rpc_mgr() .do_client_rpc_scoped(1, peer_mgr_b.my_peer_id(), |c| async { let c = TestRpcServiceClient::new(tarpc::client::Config::default(), c).spawn(); - let ret = c.hello(tarpc::context::current(), "bcd".to_owned()).await; + let ret = c.hello(tarpc::context::current(), msg.clone()).await; ret }) .await; println!("ip_list: {:?}", ip_list); - assert_eq!(ip_list.as_ref().unwrap(), "hello bcd"); + assert_eq!(ip_list.unwrap(), format!("hello {}", msg)); } #[tokio::test] @@ -575,26 +729,27 @@ pub mod tests { }; peer_mgr_b.get_peer_rpc_mgr().run_service(2, b.serve()); + let msg = random_string(16 * 1024); let ip_list = peer_mgr_a .get_peer_rpc_mgr() .do_client_rpc_scoped(1, peer_mgr_b.my_peer_id(), |c| async { let c = TestRpcServiceClient::new(tarpc::client::Config::default(), c).spawn(); - let ret = c.hello(tarpc::context::current(), "abc".to_owned()).await; + let ret = c.hello(tarpc::context::current(), msg.clone()).await; ret }) .await; + assert_eq!(ip_list.unwrap(), format!("hello_a {}", msg)); - assert_eq!(ip_list.as_ref().unwrap(), "hello_a abc"); - + let msg = random_string(16 * 1024); let ip_list = peer_mgr_a .get_peer_rpc_mgr() .do_client_rpc_scoped(2, peer_mgr_b.my_peer_id(), |c| async { let c = TestRpcServiceClient::new(tarpc::client::Config::default(), c).spawn(); - let ret = c.hello(tarpc::context::current(), "abc".to_owned()).await; + let ret = c.hello(tarpc::context::current(), msg.clone()).await; ret }) .await; - assert_eq!(ip_list.as_ref().unwrap(), "hello_b abc"); + assert_eq!(ip_list.unwrap(), format!("hello_b {}", msg)); } } diff --git a/easytier/src/tunnel/common.rs b/easytier/src/tunnel/common.rs index 04632d46..5a2a4f45 100644 --- a/easytier/src/tunnel/common.rs +++ b/easytier/src/tunnel/common.rs @@ -146,7 +146,7 @@ where reserve_buf( &mut self_mut.buf, *self_mut.max_packet_size, - *self_mut.max_packet_size * 64, + *self_mut.max_packet_size * 32, ); let cap = self_mut.buf.capacity() - self_mut.buf.len(); diff --git a/easytier/src/tunnel/quic.rs b/easytier/src/tunnel/quic.rs index 497fea75..b09895ad 100644 --- a/easytier/src/tunnel/quic.rs +++ b/easytier/src/tunnel/quic.rs @@ -118,7 +118,7 @@ impl TunnelListener for QUICTunnelListener { }; Ok(Box::new(TunnelWrapper::new( - FramedReader::new_with_associate_data(r, 4500, Some(Box::new(arc_conn.clone()))), + FramedReader::new_with_associate_data(r, 2000, Some(Box::new(arc_conn.clone()))), FramedWriter::new_with_associate_data(w, Some(Box::new(arc_conn))), Some(info), ))) diff --git a/easytier/src/tunnel/tcp.rs b/easytier/src/tunnel/tcp.rs index 60cc6b69..ed652663 100644 --- a/easytier/src/tunnel/tcp.rs +++ b/easytier/src/tunnel/tcp.rs @@ -12,7 +12,7 @@ use super::{ IpVersion, Tunnel, TunnelError, TunnelListener, }; -const TCP_MTU_BYTES: usize = 64 * 1024; +const TCP_MTU_BYTES: usize = 2000; #[derive(Debug)] pub struct TcpTunnelListener { diff --git a/easytier/src/tunnel/udp.rs b/easytier/src/tunnel/udp.rs index 415385a5..9e4709bb 100644 --- a/easytier/src/tunnel/udp.rs +++ b/easytier/src/tunnel/udp.rs @@ -33,7 +33,7 @@ use super::{ IpVersion, Tunnel, TunnelConnCounter, TunnelError, TunnelListener, TunnelUrl, }; -pub const UDP_DATA_MTU: usize = 65000; +pub const UDP_DATA_MTU: usize = 2000; type UdpCloseEventSender = UnboundedSender>; type UdpCloseEventReceiver = UnboundedReceiver>; @@ -318,7 +318,7 @@ impl UdpTunnelListenerData { let socket = self.socket.as_ref().unwrap().clone(); let mut buf = BytesMut::new(); loop { - reserve_buf(&mut buf, UDP_DATA_MTU, UDP_DATA_MTU * 128); + reserve_buf(&mut buf, UDP_DATA_MTU, UDP_DATA_MTU * 16); let (dg_size, addr) = socket.recv_buf_from(&mut buf).await.unwrap(); tracing::trace!( "udp recv packet: {:?}, buf: {:?}, size: {}", @@ -555,7 +555,7 @@ impl UdpTunnelConnector { tokio::spawn(async move { let mut buf = BytesMut::new(); loop { - reserve_buf(&mut buf, UDP_DATA_MTU, UDP_DATA_MTU * 128); + reserve_buf(&mut buf, UDP_DATA_MTU, UDP_DATA_MTU * 16); let ret; tokio::select! { _ = close_event_recv.recv() => {