optimize packet def (#31)

This commit is contained in:
Sijie.Sun
2024-03-13 22:43:52 +08:00
committed by GitHub
parent b0494687b5
commit ecb385a82c
15 changed files with 240 additions and 235 deletions
+38 -14
View File
@@ -13,7 +13,7 @@ use tokio_util::{
use tracing::Instrument;
use crate::{
common::rkyv_util::{self, encode_to_bytes},
common::rkyv_util::{self, encode_to_bytes, vec_to_string},
rpc::TunnelInfo,
tunnels::{build_url_from_socket_addr, close_tunnel, TunnelConnCounter, TunnelConnector},
};
@@ -30,12 +30,11 @@ pub const UDP_DATA_MTU: usize = 2500;
#[derive(Archive, Deserialize, Serialize, Debug)]
#[archive(compare(PartialEq), check_bytes)]
// Derives can be passed through to the generated type:
#[archive_attr(derive(Debug))]
pub enum UdpPacketPayload {
Syn,
Sack,
HolePunch(Vec<u8>),
Data(Vec<u8>),
HolePunch(String),
Data(String),
}
#[derive(Archive, Deserialize, Serialize, Debug)]
@@ -46,18 +45,32 @@ pub struct UdpPacket {
pub payload: UdpPacketPayload,
}
impl std::fmt::Debug for ArchivedUdpPacketPayload {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let mut tmp = f.debug_struct("ArchivedUdpPacketPayload");
match self {
ArchivedUdpPacketPayload::Syn => tmp.field("Syn", &"").finish(),
ArchivedUdpPacketPayload::Sack => tmp.field("Sack", &"").finish(),
ArchivedUdpPacketPayload::HolePunch(s) => {
tmp.field("HolePunch", &s.as_bytes()).finish()
}
ArchivedUdpPacketPayload::Data(s) => tmp.field("Data", &s.as_bytes()).finish(),
}
}
}
impl UdpPacket {
pub fn new_data_packet(conn_id: u32, data: Vec<u8>) -> Self {
Self {
conn_id,
payload: UdpPacketPayload::Data(data),
payload: UdpPacketPayload::Data(vec_to_string(data)),
}
}
pub fn new_hole_punch_packet(data: Vec<u8>) -> Self {
Self {
conn_id: 0,
payload: UdpPacketPayload::HolePunch(data),
payload: UdpPacketPayload::HolePunch(vec_to_string(data)),
}
}
@@ -77,7 +90,7 @@ impl UdpPacket {
}
fn try_get_data_payload(mut buf: BytesMut, conn_id: u32) -> Option<BytesMut> {
let Ok(udp_packet) = rkyv_util::decode_from_bytes_checked::<UdpPacket>(&buf) else {
let Ok(udp_packet) = rkyv_util::decode_from_bytes::<UdpPacket>(&buf) else {
tracing::warn!(?buf, "udp decode error");
return None;
};
@@ -92,9 +105,13 @@ fn try_get_data_payload(mut buf: BytesMut, conn_id: u32) -> Option<BytesMut> {
return None;
};
let ptr_range = payload.as_ptr_range();
let offset = ptr_range.start as usize - buf.as_ptr() as usize;
let len = ptr_range.end as usize - ptr_range.start as usize;
let offset = payload.as_ptr() as usize - buf.as_ptr() as usize;
let len = payload.len();
if offset + len > buf.len() {
tracing::warn!(?offset, ?len, ?buf, "udp payload data out of range");
return None;
}
buf.advance(offset);
buf.truncate(len);
tracing::trace!(?offset, ?len, ?buf, "udp payload data");
@@ -138,8 +155,8 @@ fn get_tunnel_from_socket(
// TODO: two copy here, how to avoid?
let udp_packet = UdpPacket::new_data_packet(conn_id, v.to_vec());
tracing::trace!(?udp_packet, ?v, "udp send packet");
let v = encode_to_bytes::<_, UDP_DATA_MTU>(&udp_packet);
tracing::trace!(?udp_packet, ?v, "udp send packet");
Ok((v, sender_addr))
}));
@@ -301,8 +318,7 @@ impl TunnelListener for UdpTunnelListener {
_size
);
let Ok(udp_packet) = rkyv_util::decode_from_bytes_checked::<UdpPacket>(&buf)
else {
let Ok(udp_packet) = rkyv_util::decode_from_bytes::<UdpPacket>(&buf) else {
tracing::warn!(?buf, "udp decode error in forward task");
continue;
};
@@ -429,7 +445,7 @@ impl UdpTunnelConnector {
let _ = buf.split_off(usize);
let Ok(udp_packet) = rkyv_util::decode_from_bytes_checked::<UdpPacket>(&buf) else {
let Ok(udp_packet) = rkyv_util::decode_from_bytes::<UdpPacket>(&buf) else {
tracing::warn!(?buf, "udp decode error in wait sack");
return Err(super::TunnelError::ConnectError(format!(
"udp connect error, decode error. buf: {:?}",
@@ -677,4 +693,12 @@ mod tests {
let _ = tokio::join!(sender1, sender2);
}
#[tokio::test]
async fn udp_packet_print() {
let udp_packet = UdpPacket::new_data_packet(1, vec![1, 2, 3, 4, 5]);
let b = encode_to_bytes::<_, UDP_DATA_MTU>(&udp_packet);
let a_udp_packet = rkyv_util::decode_from_bytes::<UdpPacket>(&b).unwrap();
println!("{:?}, {:?}", udp_packet, a_udp_packet);
}
}