optimize packet def (#31)

This commit is contained in:
Sijie.Sun
2024-03-13 22:43:52 +08:00
committed by GitHub
parent b0494687b5
commit ecb385a82c
15 changed files with 240 additions and 235 deletions
@@ -24,7 +24,7 @@ use crate::common::{
};
use super::{
packet::{self, ArchivedPacketBody},
packet::{self},
peer_conn::PeerConn,
peer_map::PeerMap,
peer_rpc::{PeerRpcManager, PeerRpcManagerTransport},
@@ -245,7 +245,7 @@ impl ForeignNetworkManager {
let from_peer_id = packet.from_peer.into();
let to_peer_id = packet.to_peer.into();
if to_peer_id == my_node_id {
if let ArchivedPacketBody::TaRpc(..) = &packet.body {
if packet.packet_type == packet::PacketType::TaRpc {
rpc_sender.send(packet_bytes.clone()).unwrap();
continue;
}
+96 -84
View File
@@ -5,7 +5,7 @@ use tokio_util::bytes::Bytes;
use crate::common::{
global_ctx::NetworkIdentity,
rkyv_util::{decode_from_bytes, encode_to_bytes},
rkyv_util::{decode_from_bytes, encode_to_bytes, vec_to_string},
PeerId,
};
@@ -50,69 +50,23 @@ impl From<&ArchivedUUID> for UUID {
}
}
#[derive(Archive, Deserialize, Serialize)]
#[archive(compare(PartialEq), check_bytes)]
// Derives can be passed through to the generated type:
pub struct NetworkIdentityForPacket(Vec<u8>);
impl From<NetworkIdentity> for NetworkIdentityForPacket {
fn from(network: NetworkIdentity) -> Self {
Self(bincode::serialize(&network).unwrap())
}
}
impl From<NetworkIdentityForPacket> for NetworkIdentity {
fn from(network: NetworkIdentityForPacket) -> Self {
bincode::deserialize(&network.0).unwrap()
}
}
impl From<&ArchivedNetworkIdentityForPacket> for NetworkIdentity {
fn from(network: &ArchivedNetworkIdentityForPacket) -> Self {
NetworkIdentityForPacket(network.0.to_vec()).into()
}
}
impl Debug for NetworkIdentityForPacket {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let network: NetworkIdentity = bincode::deserialize(&self.0).unwrap();
write!(f, "{:?}", network)
}
}
impl Debug for ArchivedNetworkIdentityForPacket {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let network: NetworkIdentity = bincode::deserialize(&self.0).unwrap();
write!(f, "{:?}", network)
}
}
#[derive(Archive, Deserialize, Serialize, Debug)]
#[archive(compare(PartialEq), check_bytes)]
// Derives can be passed through to the generated type:
#[archive_attr(derive(Debug))]
#[derive(serde::Serialize, serde::Deserialize, Debug)]
pub struct HandShake {
pub magic: u32,
pub my_peer_id: PeerId,
pub version: u32,
pub features: Vec<String>,
pub network_identity: NetworkIdentityForPacket,
pub network_identity: NetworkIdentity,
}
#[derive(Archive, Deserialize, Serialize, Debug)]
#[archive(compare(PartialEq), check_bytes)]
#[archive_attr(derive(Debug))]
#[derive(serde::Serialize, serde::Deserialize, Debug)]
pub struct RoutePacket {
pub route_id: u8,
pub body: Vec<u8>,
}
#[derive(Archive, Deserialize, Serialize, Debug)]
#[archive(compare(PartialEq), check_bytes)]
// Derives can be passed through to the generated type:
#[archive_attr(derive(Debug))]
pub enum PacketBody {
Data(Vec<u8>),
#[derive(Debug, serde::Serialize, serde::Deserialize)]
pub enum CtrlPacketPayload {
HandShake(HandShake),
RoutePacket(RoutePacket),
Ping(u32),
@@ -120,20 +74,72 @@ pub enum PacketBody {
TaRpc(u32, bool, Vec<u8>), // u32: service_id, bool: is_req, Vec<u8>: rpc body
}
impl CtrlPacketPayload {
pub fn from_packet(p: &ArchivedPacket) -> CtrlPacketPayload {
assert_ne!(p.packet_type, PacketType::Data);
postcard::from_bytes(p.payload.as_bytes()).unwrap()
}
pub fn from_packet2(p: &Packet) -> CtrlPacketPayload {
postcard::from_bytes(p.payload.as_bytes()).unwrap()
}
}
#[repr(u8)]
#[derive(Archive, Deserialize, Serialize, Debug)]
#[archive(compare(PartialEq), check_bytes)]
// Derives can be passed through to the generated type:
#[archive_attr(derive(Debug))]
pub enum PacketType {
Data = 1,
HandShake = 2,
RoutePacket = 3,
Ping = 4,
Pong = 5,
TaRpc = 6,
}
#[derive(Archive, Deserialize, Serialize, Debug)]
#[archive(compare(PartialEq), check_bytes)]
// Derives can be passed through to the generated type:
pub struct Packet {
pub from_peer: PeerId,
pub to_peer: PeerId,
pub body: PacketBody,
pub packet_type: PacketType,
pub payload: String,
}
impl std::fmt::Debug for ArchivedPacket {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"Packet {{ from_peer: {}, to_peer: {}, packet_type: {:?}, payload: {:?} }}",
self.from_peer,
self.to_peer,
self.packet_type,
&self.payload.as_bytes()
)
}
}
impl Packet {
pub fn decode(v: &[u8]) -> &ArchivedPacket {
decode_from_bytes::<Packet>(v).unwrap()
}
pub fn new(
from_peer: PeerId,
to_peer: PeerId,
packet_type: PacketType,
payload: Vec<u8>,
) -> Self {
Packet {
from_peer,
to_peer,
packet_type,
payload: vec_to_string(payload),
}
}
}
impl From<Packet> for Bytes {
@@ -144,52 +150,56 @@ impl From<Packet> for Bytes {
impl Packet {
pub fn new_handshake(from_peer: PeerId, network: &NetworkIdentity) -> Self {
Packet {
from_peer: from_peer.into(),
to_peer: 0,
body: PacketBody::HandShake(HandShake {
magic: MAGIC,
my_peer_id: from_peer,
version: VERSION,
features: Vec::new(),
network_identity: network.clone().into(),
}),
}
let handshake = CtrlPacketPayload::HandShake(HandShake {
magic: MAGIC,
my_peer_id: from_peer,
version: VERSION,
features: Vec::new(),
network_identity: network.clone().into(),
});
Packet::new(
from_peer.into(),
0,
PacketType::HandShake,
postcard::to_allocvec(&handshake).unwrap(),
)
}
pub fn new_data_packet(from_peer: PeerId, to_peer: PeerId, data: &[u8]) -> Self {
Packet {
from_peer,
to_peer,
body: PacketBody::Data(data.to_vec()),
}
Packet::new(from_peer, to_peer, PacketType::Data, data.to_vec())
}
pub fn new_route_packet(from_peer: PeerId, to_peer: PeerId, route_id: u8, data: &[u8]) -> Self {
Packet {
let route = CtrlPacketPayload::RoutePacket(RoutePacket {
route_id,
body: data.to_vec(),
});
Packet::new(
from_peer,
to_peer,
body: PacketBody::RoutePacket(RoutePacket {
route_id,
body: data.to_vec(),
}),
}
PacketType::RoutePacket,
postcard::to_allocvec(&route).unwrap(),
)
}
pub fn new_ping_packet(from_peer: PeerId, to_peer: PeerId, seq: u32) -> Self {
Packet {
let ping = CtrlPacketPayload::Ping(seq);
Packet::new(
from_peer,
to_peer,
body: PacketBody::Ping(seq),
}
PacketType::Ping,
postcard::to_allocvec(&ping).unwrap(),
)
}
pub fn new_pong_packet(from_peer: PeerId, to_peer: PeerId, seq: u32) -> Self {
Packet {
let pong = CtrlPacketPayload::Pong(seq);
Packet::new(
from_peer,
to_peer,
body: PacketBody::Pong(seq),
}
PacketType::Pong,
postcard::to_allocvec(&pong).unwrap(),
)
}
pub fn new_tarpc_packet(
@@ -199,11 +209,13 @@ impl Packet {
is_req: bool,
body: Vec<u8>,
) -> Self {
Packet {
let ta_rpc = CtrlPacketPayload::TaRpc(service_id, is_req, body);
Packet::new(
from_peer,
to_peer,
body: PacketBody::TaRpc(service_id, is_req, body),
}
PacketType::TaRpc,
postcard::to_allocvec(&ta_rpc).unwrap(),
)
}
}
+41 -82
View File
@@ -16,10 +16,7 @@ use tokio::{
time::{timeout, Duration},
};
use tokio_util::{
bytes::{Bytes, BytesMut},
sync::PollSender,
};
use tokio_util::{bytes::Bytes, sync::PollSender};
use tracing::Instrument;
use crate::{
@@ -28,6 +25,7 @@ use crate::{
PeerId,
},
define_tunnel_filter_chain,
peers::packet::{ArchivedPacketType, CtrlPacketPayload},
rpc::{PeerConnInfo, PeerConnStats},
tunnels::{
stats::{Throughput, WindowLatency},
@@ -36,7 +34,7 @@ use crate::{
},
};
use super::packet::{self, ArchivedHandShake, Packet};
use super::packet::{self, HandShake, Packet};
pub type PacketRecvChan = mpsc::Sender<Bytes>;
@@ -54,7 +52,8 @@ macro_rules! wait_response {
let $out_var;
let rsp_bytes = Packet::decode(&rsp_vec);
match &rsp_bytes.body {
let resp_payload = CtrlPacketPayload::from_packet(&rsp_bytes);
match &resp_payload {
$pattern => $out_var = $value,
_ => {
log::error!(
@@ -68,19 +67,6 @@ macro_rules! wait_response {
};
}
fn build_ctrl_msg(msg: Bytes, is_req: bool) -> Bytes {
let prefix: &'static [u8] = if is_req {
CTRL_REQ_PACKET_PREFIX
} else {
CTRL_RESP_PACKET_PREFIX
};
let mut new_msg = BytesMut::new();
new_msg.reserve(prefix.len() + msg.len());
new_msg.extend_from_slice(prefix);
new_msg.extend_from_slice(&msg);
new_msg.into()
}
pub struct PeerInfo {
magic: u32,
pub my_peer_id: PeerId,
@@ -90,15 +76,15 @@ pub struct PeerInfo {
pub network_identity: NetworkIdentity,
}
impl<'a> From<&ArchivedHandShake> for PeerInfo {
fn from(hs: &ArchivedHandShake) -> Self {
impl<'a> From<&HandShake> for PeerInfo {
fn from(hs: &HandShake) -> Self {
PeerInfo {
magic: hs.magic.into(),
my_peer_id: hs.my_peer_id.into(),
version: hs.version.into(),
features: hs.features.iter().map(|x| x.to_string()).collect(),
interfaces: Vec::new(),
network_identity: (&hs.network_identity).into(),
network_identity: hs.network_identity.clone(),
}
}
}
@@ -150,10 +136,7 @@ impl PeerConnPinger {
seq: u32,
) -> Result<u128, TunnelError> {
// should add seq here. so latency can be calculated more accurately
let req = build_ctrl_msg(
packet::Packet::new_ping_packet(my_node_id, peer_id, seq).into(),
true,
);
let req = packet::Packet::new_ping_packet(my_node_id, peer_id, seq).into();
tracing::trace!("send ping packet: {:?}", req);
sink.lock().await.send(req).await.map_err(|e| {
tracing::warn!("send ping packet error: {:?}", e);
@@ -167,9 +150,10 @@ impl PeerConnPinger {
loop {
match receiver.recv().await {
Ok(p) => {
if let packet::ArchivedPacketBody::Pong(resp_seq) = &Packet::decode(&p).body
{
if *resp_seq == seq {
let ctrl_payload =
packet::CtrlPacketPayload::from_packet(Packet::decode(&p));
if let packet::CtrlPacketPayload::Pong(resp_seq) = ctrl_payload {
if resp_seq == seq {
break;
}
}
@@ -247,7 +231,7 @@ impl PeerConnPinger {
});
req_seq += 1;
tokio::time::sleep(Duration::from_millis(350)).await;
tokio::time::sleep(Duration::from_millis(1000)).await;
}
});
@@ -332,9 +316,6 @@ enum PeerConnPacketType {
CtrlResp(Bytes),
}
static CTRL_REQ_PACKET_PREFIX: &[u8] = &[0x12, 0x34, 0x56, 0x78, 0x9a, 0xbc, 0xde, 0xf0];
static CTRL_RESP_PACKET_PREFIX: &[u8] = &[0x12, 0x34, 0x56, 0x78, 0x9a, 0xbc, 0xde, 0xf1];
impl PeerConn {
pub fn new(my_peer_id: PeerId, global_ctx: ArcGlobalCtx, tunnel: Box<dyn Tunnel>) -> Self {
let (ctrl_sender, _ctrl_receiver) = broadcast::channel(100);
@@ -371,7 +352,7 @@ impl PeerConn {
let mut stream = self.tunnel.pin_stream();
let mut sink = self.tunnel.pin_sink();
wait_response!(stream, hs_req, packet::ArchivedPacketBody::HandShake(x) => x);
wait_response!(stream, hs_req, CtrlPacketPayload::HandShake(x) => x);
self.info = Some(PeerInfo::from(hs_req));
log::info!("handshake request: {:?}", hs_req);
@@ -394,7 +375,7 @@ impl PeerConn {
.run(|| packet::Packet::new_handshake(self.my_peer_id, &self.global_ctx.network));
sink.send(hs_req.into()).await?;
wait_response!(stream, hs_rsp, packet::ArchivedPacketBody::HandShake(x) => x);
wait_response!(stream, hs_rsp, CtrlPacketPayload::HandShake(x) => x);
self.info = Some(PeerInfo::from(hs_rsp));
log::info!("handshake response: {:?}", hs_rsp);
@@ -405,41 +386,6 @@ impl PeerConn {
self.info.is_some()
}
fn get_packet_type(mut bytes_item: Bytes) -> PeerConnPacketType {
if bytes_item.starts_with(CTRL_REQ_PACKET_PREFIX) {
PeerConnPacketType::CtrlReq(bytes_item.split_off(CTRL_REQ_PACKET_PREFIX.len()))
} else if bytes_item.starts_with(CTRL_RESP_PACKET_PREFIX) {
PeerConnPacketType::CtrlResp(bytes_item.split_off(CTRL_RESP_PACKET_PREFIX.len()))
} else {
PeerConnPacketType::Data(bytes_item)
}
}
fn handle_ctrl_req_packet(
bytes_item: Bytes,
conn_info: &PeerConnInfo,
) -> Result<Bytes, TunnelError> {
let packet = Packet::decode(&bytes_item);
match packet.body {
packet::ArchivedPacketBody::Ping(seq) => {
log::trace!("recv ping packet: {:?}", packet);
Ok(build_ctrl_msg(
packet::Packet::new_pong_packet(
conn_info.my_peer_id,
conn_info.peer_id,
seq.into(),
)
.into(),
false,
))
}
_ => {
log::error!("unexpected packet: {:?}", packet);
Err(TunnelError::CommonError("unexpected packet".to_owned()))
}
}
}
pub fn start_pingpong(&mut self) {
let mut pingpong = PeerConnPinger::new(
self.my_peer_id,
@@ -487,23 +433,36 @@ impl PeerConn {
break;
}
match Self::get_packet_type(ret.unwrap().into()) {
PeerConnPacketType::Data(item) => {
if sender.send(item).await.is_err() {
break;
}
}
PeerConnPacketType::CtrlReq(item) => {
let ret = Self::handle_ctrl_req_packet(item, &conn_info).unwrap();
if let Err(e) = sink.send(ret).await {
let buf = ret.unwrap();
let p = Packet::decode(&buf);
match p.packet_type {
ArchivedPacketType::Ping => {
let CtrlPacketPayload::Ping(seq) = CtrlPacketPayload::from_packet(p)
else {
log::error!("unexpected packet: {:?}", p);
continue;
};
let pong = packet::Packet::new_pong_packet(
conn_info.my_peer_id,
conn_info.peer_id,
seq.into(),
);
if let Err(e) = sink.send(pong.into()).await {
tracing::error!(?e, "peer conn send req error");
}
}
PeerConnPacketType::CtrlResp(item) => {
if let Err(e) = ctrl_sender.send(item) {
ArchivedPacketType::Pong => {
if let Err(e) = ctrl_sender.send(buf.into()) {
tracing::error!(?e, "peer conn send ctrl resp error");
}
}
_ => {
if sender.send(buf.into()).await.is_err() {
break;
}
}
}
}
@@ -676,7 +635,7 @@ mod tests {
c_peer.start_recv_loop(tokio::sync::mpsc::channel(200).0);
// wait 5s, conn should not be disconnected
tokio::time::sleep(Duration::from_secs(5)).await;
tokio::time::sleep(Duration::from_secs(15)).await;
if conn_closed {
assert!(close_recv.try_recv().is_ok());
+5 -6
View File
@@ -15,7 +15,8 @@ use tokio_util::bytes::{Bytes, BytesMut};
use crate::{
common::{
error::Error, global_ctx::ArcGlobalCtx, rkyv_util::extract_bytes_from_archived_vec, PeerId,
error::Error, global_ctx::ArcGlobalCtx, rkyv_util::extract_bytes_from_archived_string,
PeerId,
},
peers::{
packet, peer_conn::PeerConn, peer_rpc::PeerRpcManagerTransport, route_trait::RouteInterface,
@@ -287,8 +288,6 @@ impl PeerManager {
}
async fn init_packet_process_pipeline(&self) {
use packet::ArchivedPacketBody;
// for tun/tap ip/eth packet.
struct NicPacketProcessor {
nic_channel: mpsc::Sender<SinkItem>,
@@ -300,10 +299,10 @@ impl PeerManager {
packet: &packet::ArchivedPacket,
data: &Bytes,
) -> Option<()> {
if let packet::ArchivedPacketBody::Data(x) = &packet.body {
if packet.packet_type == packet::PacketType::Data {
// TODO: use a function to get the body ref directly for zero copy
self.nic_channel
.send(extract_bytes_from_archived_vec(&data, &x))
.send(extract_bytes_from_archived_string(data, &packet.payload))
.await
.unwrap();
Some(())
@@ -333,7 +332,7 @@ impl PeerManager {
packet: &packet::ArchivedPacket,
data: &Bytes,
) -> Option<()> {
if let ArchivedPacketBody::TaRpc(..) = &packet.body {
if packet.packet_type == packet::PacketType::TaRpc {
self.peer_rpc_tspt_sender.send(data.clone()).unwrap();
Some(())
} else {
+18 -24
View File
@@ -6,7 +6,6 @@ use std::{
use async_trait::async_trait;
use dashmap::DashMap;
use rkyv::{Archive, Deserialize, Serialize};
use tokio::{
sync::{Mutex, RwLock},
task::JoinSet,
@@ -15,21 +14,15 @@ use tokio_util::bytes::Bytes;
use tracing::Instrument;
use crate::{
common::{
error::Error,
global_ctx::ArcGlobalCtx,
rkyv_util::{decode_from_bytes, encode_to_bytes, extract_bytes_from_archived_vec},
stun::StunInfoCollectorTrait,
PeerId,
},
common::{error::Error, global_ctx::ArcGlobalCtx, stun::StunInfoCollectorTrait, PeerId},
peers::{
packet::{self},
packet,
route_trait::{Route, RouteInterfaceBox},
},
rpc::{NatType, StunInfo},
};
use super::{packet::ArchivedPacketBody, peer_manager::PeerPacketFilter};
use super::{packet::CtrlPacketPayload, peer_manager::PeerPacketFilter};
const SEND_ROUTE_PERIOD_SEC: u64 = 60;
const SEND_ROUTE_FAST_REPLY_SEC: u64 = 5;
@@ -37,10 +30,8 @@ const ROUTE_EXPIRED_SEC: u64 = 70;
type Version = u32;
#[derive(Archive, Deserialize, Serialize, Clone, Debug, PartialEq)]
#[archive(compare(PartialEq), check_bytes)]
#[derive(serde::Deserialize, serde::Serialize, Clone, Debug, PartialEq)]
// Derives can be passed through to the generated type:
#[archive_attr(derive(Debug))]
pub struct SyncPeerInfo {
// means next hop in route table.
pub peer_id: PeerId,
@@ -82,10 +73,7 @@ impl SyncPeerInfo {
}
}
#[derive(Archive, Deserialize, Serialize, Clone, Debug)]
#[archive(compare(PartialEq), check_bytes)]
// Derives can be passed through to the generated type:
#[archive_attr(derive(Debug))]
#[derive(serde::Deserialize, serde::Serialize, Clone, Debug)]
pub struct SyncPeer {
pub myself: SyncPeerInfo,
pub neighbors: Vec<SyncPeerInfo>,
@@ -341,7 +329,7 @@ impl BasicRoute {
);
// TODO: this may exceed the MTU of the tunnel
interface
.send_route_packet(encode_to_bytes::<_, 4096>(&msg), 1, peer_id)
.send_route_packet(postcard::to_allocvec(&msg).unwrap().into(), 1, peer_id)
.await
}
@@ -380,7 +368,7 @@ impl BasicRoute {
continue;
}
tracing::info!(
tracing::trace!(
my_id = ?my_peer_id,
dst_peer_id = ?peer,
version = version.get(),
@@ -504,8 +492,8 @@ impl BasicRoute {
#[tracing::instrument(skip(self, packet), fields(my_id = ?self.my_peer_id, ctx = ?self.global_ctx))]
async fn handle_route_packet(&self, src_peer_id: PeerId, packet: Bytes) {
let packet = decode_from_bytes::<SyncPeer>(&packet).unwrap();
let p: SyncPeer = packet.deserialize(&mut rkyv::Infallible).unwrap();
let packet = postcard::from_bytes::<SyncPeer>(&packet).unwrap();
let p = &packet;
let mut updated = true;
assert_eq!(packet.myself.peer_id, src_peer_id);
self.sync_peer_from_remote
@@ -639,12 +627,18 @@ impl PeerPacketFilter for BasicRoute {
async fn try_process_packet_from_peer(
&self,
packet: &packet::ArchivedPacket,
data: &Bytes,
_data: &Bytes,
) -> Option<()> {
if let ArchivedPacketBody::RoutePacket(route_packet) = &packet.body {
if packet.packet_type == packet::PacketType::RoutePacket {
let CtrlPacketPayload::RoutePacket(route_packet) =
CtrlPacketPayload::from_packet(packet)
else {
return None;
};
self.handle_route_packet(
packet.from_peer.into(),
extract_bytes_from_archived_vec(&data, &route_packet.body),
route_packet.body.into_boxed_slice().into(),
)
.await;
Some(())
+4 -3
View File
@@ -16,7 +16,7 @@ use crate::{
peers::packet::Packet,
};
use super::packet::PacketBody;
use super::packet::CtrlPacketPayload;
type PeerRpcServiceId = u32;
@@ -206,8 +206,9 @@ impl PeerRpcManager {
}
fn parse_rpc_packet(packet: &Packet) -> Result<TaRpcPacketInfo, Error> {
match &packet.body {
PacketBody::TaRpc(id, is_req, body) => Ok(TaRpcPacketInfo {
let ctrl_packet_payload = CtrlPacketPayload::from_packet2(&packet);
match &ctrl_packet_payload {
CtrlPacketPayload::TaRpc(id, is_req, body) => Ok(TaRpcPacketInfo {
from_peer: packet.from_peer.into(),
to_peer: packet.to_peer.into(),
service_id: *id,