Initial Version
This commit is contained in:
@@ -0,0 +1,301 @@
|
||||
use std::{
|
||||
mem::MaybeUninit,
|
||||
net::{IpAddr, Ipv4Addr, SocketAddrV4},
|
||||
sync::Arc,
|
||||
thread,
|
||||
};
|
||||
|
||||
use pnet::packet::{
|
||||
icmp::{self, IcmpTypes},
|
||||
ip::IpNextHeaderProtocols,
|
||||
ipv4::{self, Ipv4Packet, MutableIpv4Packet},
|
||||
Packet,
|
||||
};
|
||||
use socket2::Socket;
|
||||
use tokio::{
|
||||
sync::{mpsc::UnboundedSender, Mutex},
|
||||
task::JoinSet,
|
||||
};
|
||||
use tokio_util::bytes::Bytes;
|
||||
use tracing::Instrument;
|
||||
|
||||
use crate::{
|
||||
common::{error::Error, global_ctx::ArcGlobalCtx},
|
||||
peers::{
|
||||
packet,
|
||||
peer_manager::{PeerManager, PeerPacketFilter},
|
||||
PeerId,
|
||||
},
|
||||
};
|
||||
|
||||
use super::CidrSet;
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||
struct IcmpNatKey {
|
||||
dst_ip: std::net::IpAddr,
|
||||
icmp_id: u16,
|
||||
icmp_seq: u16,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct IcmpNatEntry {
|
||||
src_peer_id: PeerId,
|
||||
my_peer_id: PeerId,
|
||||
src_ip: IpAddr,
|
||||
start_time: std::time::Instant,
|
||||
}
|
||||
|
||||
impl IcmpNatEntry {
|
||||
fn new(src_peer_id: PeerId, my_peer_id: PeerId, src_ip: IpAddr) -> Result<Self, Error> {
|
||||
Ok(Self {
|
||||
src_peer_id,
|
||||
my_peer_id,
|
||||
src_ip,
|
||||
start_time: std::time::Instant::now(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type IcmpNatTable = Arc<dashmap::DashMap<IcmpNatKey, IcmpNatEntry>>;
|
||||
type NewPacketSender = tokio::sync::mpsc::UnboundedSender<IcmpNatKey>;
|
||||
type NewPacketReceiver = tokio::sync::mpsc::UnboundedReceiver<IcmpNatKey>;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct IcmpProxy {
|
||||
global_ctx: ArcGlobalCtx,
|
||||
peer_manager: Arc<PeerManager>,
|
||||
|
||||
cidr_set: CidrSet,
|
||||
socket: socket2::Socket,
|
||||
|
||||
nat_table: IcmpNatTable,
|
||||
|
||||
tasks: Mutex<JoinSet<()>>,
|
||||
}
|
||||
|
||||
fn socket_recv(socket: &Socket, buf: &mut [MaybeUninit<u8>]) -> Result<(usize, IpAddr), Error> {
|
||||
let (size, addr) = socket.recv_from(buf)?;
|
||||
let addr = match addr.as_socket() {
|
||||
None => IpAddr::V4(Ipv4Addr::UNSPECIFIED),
|
||||
Some(add) => add.ip(),
|
||||
};
|
||||
Ok((size, addr))
|
||||
}
|
||||
|
||||
fn socket_recv_loop(
|
||||
socket: Socket,
|
||||
nat_table: IcmpNatTable,
|
||||
sender: UnboundedSender<packet::Packet>,
|
||||
) {
|
||||
let mut buf = [0u8; 4096];
|
||||
let data: &mut [MaybeUninit<u8>] = unsafe { std::mem::transmute(&mut buf[12..]) };
|
||||
|
||||
loop {
|
||||
let Ok((len, peer_ip)) = socket_recv(&socket, data) else {
|
||||
continue;
|
||||
};
|
||||
|
||||
if !peer_ip.is_ipv4() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let Some(mut ipv4_packet) = MutableIpv4Packet::new(&mut buf[12..12 + len]) else {
|
||||
continue;
|
||||
};
|
||||
|
||||
let Some(icmp_packet) = icmp::echo_reply::EchoReplyPacket::new(ipv4_packet.payload())
|
||||
else {
|
||||
continue;
|
||||
};
|
||||
|
||||
if icmp_packet.get_icmp_type() != IcmpTypes::EchoReply {
|
||||
continue;
|
||||
}
|
||||
|
||||
let key = IcmpNatKey {
|
||||
dst_ip: peer_ip,
|
||||
icmp_id: icmp_packet.get_identifier(),
|
||||
icmp_seq: icmp_packet.get_sequence_number(),
|
||||
};
|
||||
|
||||
let Some((_, v)) = nat_table.remove(&key) else {
|
||||
continue;
|
||||
};
|
||||
|
||||
// send packet back to the peer where this request origin.
|
||||
let IpAddr::V4(dest_ip) = v.src_ip else {
|
||||
continue;
|
||||
};
|
||||
|
||||
ipv4_packet.set_destination(dest_ip);
|
||||
ipv4_packet.set_checksum(ipv4::checksum(&ipv4_packet.to_immutable()));
|
||||
|
||||
let peer_packet = packet::Packet::new_data_packet(
|
||||
v.my_peer_id,
|
||||
v.src_peer_id,
|
||||
&ipv4_packet.to_immutable().packet(),
|
||||
);
|
||||
|
||||
if let Err(e) = sender.send(peer_packet) {
|
||||
tracing::error!("send icmp packet to peer failed: {:?}, may exiting..", e);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl PeerPacketFilter for IcmpProxy {
|
||||
async fn try_process_packet_from_peer(
|
||||
&self,
|
||||
packet: &packet::ArchivedPacket,
|
||||
_: &Bytes,
|
||||
) -> Option<()> {
|
||||
let _ = self.global_ctx.get_ipv4()?;
|
||||
|
||||
let packet::ArchivedPacketBody::Data(x) = &packet.body else {
|
||||
return None;
|
||||
};
|
||||
|
||||
let ipv4 = Ipv4Packet::new(&x.data)?;
|
||||
|
||||
if ipv4.get_version() != 4 || ipv4.get_next_level_protocol() != IpNextHeaderProtocols::Icmp
|
||||
{
|
||||
return None;
|
||||
}
|
||||
|
||||
if !self.cidr_set.contains_v4(ipv4.get_destination()) {
|
||||
return None;
|
||||
}
|
||||
|
||||
let icmp_packet = icmp::echo_request::EchoRequestPacket::new(&ipv4.payload())?;
|
||||
|
||||
if icmp_packet.get_icmp_type() != IcmpTypes::EchoRequest {
|
||||
// drop it because we do not support other icmp types
|
||||
tracing::trace!("unsupported icmp type: {:?}", icmp_packet.get_icmp_type());
|
||||
return Some(());
|
||||
}
|
||||
|
||||
let icmp_id = icmp_packet.get_identifier();
|
||||
let icmp_seq = icmp_packet.get_sequence_number();
|
||||
|
||||
let key = IcmpNatKey {
|
||||
dst_ip: ipv4.get_destination().into(),
|
||||
icmp_id,
|
||||
icmp_seq,
|
||||
};
|
||||
|
||||
if packet.to_peer.is_none() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let value = IcmpNatEntry::new(
|
||||
packet.from_peer.to_uuid(),
|
||||
packet.to_peer.as_ref().unwrap().to_uuid(),
|
||||
ipv4.get_source().into(),
|
||||
)
|
||||
.ok()?;
|
||||
|
||||
if let Some(old) = self.nat_table.insert(key, value) {
|
||||
tracing::info!("icmp nat table entry replaced: {:?}", old);
|
||||
}
|
||||
|
||||
if let Err(e) = self.send_icmp_packet(ipv4.get_destination(), &icmp_packet) {
|
||||
tracing::error!("send icmp packet failed: {:?}", e);
|
||||
}
|
||||
|
||||
Some(())
|
||||
}
|
||||
}
|
||||
|
||||
impl IcmpProxy {
|
||||
pub fn new(
|
||||
global_ctx: ArcGlobalCtx,
|
||||
peer_manager: Arc<PeerManager>,
|
||||
) -> Result<Arc<Self>, Error> {
|
||||
let cidr_set = CidrSet::new(global_ctx.clone());
|
||||
|
||||
let _g = global_ctx.net_ns.guard();
|
||||
let socket = socket2::Socket::new(
|
||||
socket2::Domain::IPV4,
|
||||
socket2::Type::RAW,
|
||||
Some(socket2::Protocol::ICMPV4),
|
||||
)?;
|
||||
socket.bind(&socket2::SockAddr::from(SocketAddrV4::new(
|
||||
std::net::Ipv4Addr::UNSPECIFIED,
|
||||
0,
|
||||
)))?;
|
||||
|
||||
let ret = Self {
|
||||
global_ctx,
|
||||
peer_manager,
|
||||
cidr_set,
|
||||
socket,
|
||||
|
||||
nat_table: Arc::new(dashmap::DashMap::new()),
|
||||
tasks: Mutex::new(JoinSet::new()),
|
||||
};
|
||||
|
||||
Ok(Arc::new(ret))
|
||||
}
|
||||
|
||||
pub async fn start(self: &Arc<Self>) -> Result<(), Error> {
|
||||
self.start_icmp_proxy().await?;
|
||||
self.start_nat_table_cleaner().await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn start_nat_table_cleaner(self: &Arc<Self>) -> Result<(), Error> {
|
||||
let nat_table = self.nat_table.clone();
|
||||
self.tasks.lock().await.spawn(
|
||||
async move {
|
||||
loop {
|
||||
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
|
||||
nat_table.retain(|_, v| v.start_time.elapsed().as_secs() < 20);
|
||||
}
|
||||
}
|
||||
.instrument(tracing::info_span!("icmp proxy nat table cleaner")),
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn start_icmp_proxy(self: &Arc<Self>) -> Result<(), Error> {
|
||||
let socket = self.socket.try_clone()?;
|
||||
let (sender, mut receiver) = tokio::sync::mpsc::unbounded_channel();
|
||||
let nat_table = self.nat_table.clone();
|
||||
thread::spawn(|| {
|
||||
socket_recv_loop(socket, nat_table, sender);
|
||||
});
|
||||
|
||||
let peer_manager = self.peer_manager.clone();
|
||||
self.tasks.lock().await.spawn(
|
||||
async move {
|
||||
while let Some(msg) = receiver.recv().await {
|
||||
let to_peer_id: uuid::Uuid = msg.to_peer.as_ref().unwrap().clone().into();
|
||||
let ret = peer_manager.send_msg(msg.into(), &to_peer_id).await;
|
||||
if ret.is_err() {
|
||||
tracing::error!("send icmp packet to peer failed: {:?}", ret);
|
||||
}
|
||||
}
|
||||
}
|
||||
.instrument(tracing::info_span!("icmp proxy send loop")),
|
||||
);
|
||||
|
||||
self.peer_manager
|
||||
.add_packet_process_pipeline(Box::new(self.clone()))
|
||||
.await;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn send_icmp_packet(
|
||||
&self,
|
||||
dst_ip: Ipv4Addr,
|
||||
icmp_packet: &icmp::echo_request::EchoRequestPacket,
|
||||
) -> Result<(), Error> {
|
||||
self.socket.send_to(
|
||||
icmp_packet.packet(),
|
||||
&SocketAddrV4::new(dst_ip.into(), 0).into(),
|
||||
)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,51 @@
|
||||
use dashmap::DashSet;
|
||||
use std::sync::Arc;
|
||||
use tokio::task::JoinSet;
|
||||
|
||||
use crate::common::global_ctx::ArcGlobalCtx;
|
||||
|
||||
pub mod icmp_proxy;
|
||||
pub mod tcp_proxy;
|
||||
|
||||
#[derive(Debug)]
|
||||
struct CidrSet {
|
||||
global_ctx: ArcGlobalCtx,
|
||||
cidr_set: Arc<DashSet<cidr::IpCidr>>,
|
||||
tasks: JoinSet<()>,
|
||||
}
|
||||
|
||||
impl CidrSet {
|
||||
pub fn new(global_ctx: ArcGlobalCtx) -> Self {
|
||||
let mut ret = Self {
|
||||
global_ctx,
|
||||
cidr_set: Arc::new(DashSet::new()),
|
||||
tasks: JoinSet::new(),
|
||||
};
|
||||
ret.run_cidr_updater();
|
||||
ret
|
||||
}
|
||||
|
||||
fn run_cidr_updater(&mut self) {
|
||||
let global_ctx = self.global_ctx.clone();
|
||||
let cidr_set = self.cidr_set.clone();
|
||||
self.tasks.spawn(async move {
|
||||
let mut last_cidrs = vec![];
|
||||
loop {
|
||||
let cidrs = global_ctx.get_proxy_cidrs();
|
||||
if cidrs != last_cidrs {
|
||||
last_cidrs = cidrs.clone();
|
||||
cidr_set.clear();
|
||||
for cidr in cidrs.iter() {
|
||||
cidr_set.insert(cidr.clone());
|
||||
}
|
||||
}
|
||||
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
pub fn contains_v4(&self, ip: std::net::Ipv4Addr) -> bool {
|
||||
let ip = ip.into();
|
||||
return self.cidr_set.iter().any(|cidr| cidr.contains(&ip));
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,402 @@
|
||||
use crossbeam::atomic::AtomicCell;
|
||||
use dashmap::DashMap;
|
||||
use pnet::packet::ip::IpNextHeaderProtocols;
|
||||
use pnet::packet::ipv4::{Ipv4Packet, MutableIpv4Packet};
|
||||
use pnet::packet::tcp::{ipv4_checksum, MutableTcpPacket};
|
||||
use std::net::{IpAddr, Ipv4Addr, SocketAddr, SocketAddrV4};
|
||||
use std::sync::atomic::AtomicU16;
|
||||
use std::sync::Arc;
|
||||
use std::time::{Duration, Instant};
|
||||
use tokio::io::copy_bidirectional;
|
||||
use tokio::net::{TcpListener, TcpSocket, TcpStream};
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::task::JoinSet;
|
||||
use tokio_util::bytes::{Bytes, BytesMut};
|
||||
use tracing::Instrument;
|
||||
|
||||
use crate::common::error::Result;
|
||||
use crate::common::global_ctx::GlobalCtx;
|
||||
use crate::common::netns::NetNS;
|
||||
use crate::peers::packet::{self, ArchivedPacket};
|
||||
use crate::peers::peer_manager::{NicPacketFilter, PeerManager, PeerPacketFilter};
|
||||
|
||||
use super::CidrSet;
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq)]
|
||||
enum NatDstEntryState {
|
||||
// receive syn packet but not start connecting to dst
|
||||
SynReceived,
|
||||
// connecting to dst
|
||||
ConnectingDst,
|
||||
// connected to dst
|
||||
Connected,
|
||||
// connection closed
|
||||
Closed,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct NatDstEntry {
|
||||
id: uuid::Uuid,
|
||||
src: SocketAddr,
|
||||
dst: SocketAddr,
|
||||
start_time: Instant,
|
||||
tasks: Mutex<JoinSet<()>>,
|
||||
state: AtomicCell<NatDstEntryState>,
|
||||
}
|
||||
|
||||
impl NatDstEntry {
|
||||
pub fn new(src: SocketAddr, dst: SocketAddr) -> Self {
|
||||
Self {
|
||||
id: uuid::Uuid::new_v4(),
|
||||
src,
|
||||
dst,
|
||||
start_time: Instant::now(),
|
||||
tasks: Mutex::new(JoinSet::new()),
|
||||
state: AtomicCell::new(NatDstEntryState::SynReceived),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type ArcNatDstEntry = Arc<NatDstEntry>;
|
||||
|
||||
type SynSockMap = Arc<DashMap<SocketAddr, ArcNatDstEntry>>;
|
||||
type ConnSockMap = Arc<DashMap<uuid::Uuid, ArcNatDstEntry>>;
|
||||
// peer src addr to nat entry, when respond tcp packet, should modify the tcp src addr to the nat entry's dst addr
|
||||
type AddrConnSockMap = Arc<DashMap<SocketAddr, ArcNatDstEntry>>;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct TcpProxy {
|
||||
global_ctx: Arc<GlobalCtx>,
|
||||
peer_manager: Arc<PeerManager>,
|
||||
local_port: AtomicU16,
|
||||
|
||||
tasks: Arc<Mutex<JoinSet<()>>>,
|
||||
|
||||
syn_map: SynSockMap,
|
||||
conn_map: ConnSockMap,
|
||||
addr_conn_map: AddrConnSockMap,
|
||||
|
||||
cidr_set: CidrSet,
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl PeerPacketFilter for TcpProxy {
|
||||
async fn try_process_packet_from_peer(&self, packet: &ArchivedPacket, _: &Bytes) -> Option<()> {
|
||||
let ipv4_addr = self.global_ctx.get_ipv4()?;
|
||||
|
||||
let packet::ArchivedPacketBody::Data(x) = &packet.body else {
|
||||
return None;
|
||||
};
|
||||
|
||||
let ipv4 = Ipv4Packet::new(&x.data)?;
|
||||
if ipv4.get_version() != 4 || ipv4.get_next_level_protocol() != IpNextHeaderProtocols::Tcp {
|
||||
return None;
|
||||
}
|
||||
|
||||
if !self.cidr_set.contains_v4(ipv4.get_destination()) {
|
||||
return None;
|
||||
}
|
||||
|
||||
tracing::trace!(ipv4 = ?ipv4, cidr_set = ?self.cidr_set, "proxy tcp packet received");
|
||||
|
||||
let mut packet_buffer = BytesMut::with_capacity(x.data.len());
|
||||
packet_buffer.extend_from_slice(&x.data.to_vec());
|
||||
|
||||
let (ip_buffer, tcp_buffer) =
|
||||
packet_buffer.split_at_mut(ipv4.get_header_length() as usize * 4);
|
||||
|
||||
let mut ip_packet = MutableIpv4Packet::new(ip_buffer).unwrap();
|
||||
let mut tcp_packet = MutableTcpPacket::new(tcp_buffer).unwrap();
|
||||
|
||||
let is_tcp_syn = tcp_packet.get_flags() & pnet::packet::tcp::TcpFlags::SYN != 0;
|
||||
if is_tcp_syn {
|
||||
let source_ip = ip_packet.get_source();
|
||||
let source_port = tcp_packet.get_source();
|
||||
let src = SocketAddr::V4(SocketAddrV4::new(source_ip, source_port));
|
||||
|
||||
let dest_ip = ip_packet.get_destination();
|
||||
let dest_port = tcp_packet.get_destination();
|
||||
let dst = SocketAddr::V4(SocketAddrV4::new(dest_ip, dest_port));
|
||||
|
||||
let old_val = self
|
||||
.syn_map
|
||||
.insert(src, Arc::new(NatDstEntry::new(src, dst)));
|
||||
tracing::trace!(src = ?src, dst = ?dst, old_entry = ?old_val, "tcp syn received");
|
||||
}
|
||||
|
||||
ip_packet.set_destination(ipv4_addr);
|
||||
tcp_packet.set_destination(self.get_local_port());
|
||||
Self::update_ipv4_packet_checksum(&mut ip_packet, &mut tcp_packet);
|
||||
|
||||
tracing::trace!(ip_packet = ?ip_packet, tcp_packet = ?tcp_packet, "tcp packet forwarded");
|
||||
|
||||
if let Err(e) = self
|
||||
.peer_manager
|
||||
.get_nic_channel()
|
||||
.send(packet_buffer.freeze())
|
||||
.await
|
||||
{
|
||||
tracing::error!("send to nic failed: {:?}", e);
|
||||
}
|
||||
|
||||
Some(())
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl NicPacketFilter for TcpProxy {
|
||||
async fn try_process_packet_from_nic(&self, mut data: BytesMut) -> BytesMut {
|
||||
let Some(my_ipv4) = self.global_ctx.get_ipv4() else {
|
||||
return data;
|
||||
};
|
||||
|
||||
let header_len = {
|
||||
let Some(ipv4) = &Ipv4Packet::new(&data[..]) else {
|
||||
return data;
|
||||
};
|
||||
|
||||
if ipv4.get_version() != 4
|
||||
|| ipv4.get_source() != my_ipv4
|
||||
|| ipv4.get_next_level_protocol() != IpNextHeaderProtocols::Tcp
|
||||
{
|
||||
return data;
|
||||
}
|
||||
|
||||
ipv4.get_header_length() as usize * 4
|
||||
};
|
||||
|
||||
let (ip_buffer, tcp_buffer) = data.split_at_mut(header_len);
|
||||
let mut ip_packet = MutableIpv4Packet::new(ip_buffer).unwrap();
|
||||
let mut tcp_packet = MutableTcpPacket::new(tcp_buffer).unwrap();
|
||||
|
||||
if tcp_packet.get_source() != self.get_local_port() {
|
||||
return data;
|
||||
}
|
||||
|
||||
let dst_addr = SocketAddr::V4(SocketAddrV4::new(
|
||||
ip_packet.get_destination(),
|
||||
tcp_packet.get_destination(),
|
||||
));
|
||||
|
||||
tracing::trace!(dst_addr = ?dst_addr, "tcp packet try find entry");
|
||||
let entry = if let Some(entry) = self.addr_conn_map.get(&dst_addr) {
|
||||
entry
|
||||
} else {
|
||||
let Some(syn_entry) = self.syn_map.get(&dst_addr) else {
|
||||
return data;
|
||||
};
|
||||
syn_entry
|
||||
};
|
||||
let nat_entry = entry.clone();
|
||||
drop(entry);
|
||||
assert_eq!(nat_entry.src, dst_addr);
|
||||
|
||||
let IpAddr::V4(ip) = nat_entry.dst.ip() else {
|
||||
panic!("v4 nat entry src ip is not v4");
|
||||
};
|
||||
|
||||
ip_packet.set_source(ip);
|
||||
tcp_packet.set_source(nat_entry.dst.port());
|
||||
Self::update_ipv4_packet_checksum(&mut ip_packet, &mut tcp_packet);
|
||||
|
||||
tracing::trace!(dst_addr = ?dst_addr, nat_entry = ?nat_entry, packet = ?ip_packet, "tcp packet after modified");
|
||||
|
||||
data
|
||||
}
|
||||
}
|
||||
|
||||
impl TcpProxy {
|
||||
pub fn new(global_ctx: Arc<GlobalCtx>, peer_manager: Arc<PeerManager>) -> Arc<Self> {
|
||||
Arc::new(Self {
|
||||
global_ctx: global_ctx.clone(),
|
||||
peer_manager,
|
||||
|
||||
local_port: AtomicU16::new(0),
|
||||
tasks: Arc::new(Mutex::new(JoinSet::new())),
|
||||
|
||||
syn_map: Arc::new(DashMap::new()),
|
||||
conn_map: Arc::new(DashMap::new()),
|
||||
addr_conn_map: Arc::new(DashMap::new()),
|
||||
|
||||
cidr_set: CidrSet::new(global_ctx),
|
||||
})
|
||||
}
|
||||
|
||||
fn update_ipv4_packet_checksum(
|
||||
ipv4_packet: &mut MutableIpv4Packet,
|
||||
tcp_packet: &mut MutableTcpPacket,
|
||||
) {
|
||||
tcp_packet.set_checksum(ipv4_checksum(
|
||||
&tcp_packet.to_immutable(),
|
||||
&ipv4_packet.get_source(),
|
||||
&ipv4_packet.get_destination(),
|
||||
));
|
||||
|
||||
ipv4_packet.set_checksum(pnet::packet::ipv4::checksum(&ipv4_packet.to_immutable()));
|
||||
}
|
||||
|
||||
pub async fn start(self: &Arc<Self>) -> Result<()> {
|
||||
self.run_syn_map_cleaner().await?;
|
||||
self.run_listener().await?;
|
||||
self.peer_manager
|
||||
.add_packet_process_pipeline(Box::new(self.clone()))
|
||||
.await;
|
||||
self.peer_manager
|
||||
.add_nic_packet_process_pipeline(Box::new(self.clone()))
|
||||
.await;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn run_syn_map_cleaner(&self) -> Result<()> {
|
||||
let syn_map = self.syn_map.clone();
|
||||
let tasks = self.tasks.clone();
|
||||
let syn_map_cleaner_task = async move {
|
||||
loop {
|
||||
syn_map.retain(|_, entry| {
|
||||
if entry.start_time.elapsed() > Duration::from_secs(30) {
|
||||
tracing::warn!(entry = ?entry, "syn nat entry expired");
|
||||
entry.state.store(NatDstEntryState::Closed);
|
||||
false
|
||||
} else {
|
||||
true
|
||||
}
|
||||
});
|
||||
tokio::time::sleep(Duration::from_secs(10)).await;
|
||||
}
|
||||
};
|
||||
tasks.lock().await.spawn(syn_map_cleaner_task);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn run_listener(&self) -> Result<()> {
|
||||
// bind on both v4 & v6
|
||||
let listen_addr = SocketAddr::new(Ipv4Addr::UNSPECIFIED.into(), 0);
|
||||
|
||||
let net_ns = self.global_ctx.net_ns.clone();
|
||||
let tcp_listener = net_ns
|
||||
.run_async(|| async { TcpListener::bind(&listen_addr).await })
|
||||
.await?;
|
||||
|
||||
self.local_port.store(
|
||||
tcp_listener.local_addr()?.port(),
|
||||
std::sync::atomic::Ordering::Relaxed,
|
||||
);
|
||||
|
||||
let tasks = self.tasks.clone();
|
||||
let syn_map = self.syn_map.clone();
|
||||
let conn_map = self.conn_map.clone();
|
||||
let addr_conn_map = self.addr_conn_map.clone();
|
||||
let accept_task = async move {
|
||||
tracing::info!(listener = ?tcp_listener, "tcp connection start accepting");
|
||||
|
||||
let conn_map = conn_map.clone();
|
||||
while let Ok((tcp_stream, socket_addr)) = tcp_listener.accept().await {
|
||||
let Some(entry) = syn_map.get(&socket_addr) else {
|
||||
tracing::error!("tcp connection from unknown source: {:?}", socket_addr);
|
||||
continue;
|
||||
};
|
||||
assert_eq!(entry.state.load(), NatDstEntryState::SynReceived);
|
||||
|
||||
let entry_clone = entry.clone();
|
||||
drop(entry);
|
||||
syn_map.remove_if(&socket_addr, |_, entry| entry.id == entry_clone.id);
|
||||
|
||||
entry_clone.state.store(NatDstEntryState::ConnectingDst);
|
||||
|
||||
let _ = addr_conn_map.insert(entry_clone.src, entry_clone.clone());
|
||||
let old_nat_val = conn_map.insert(entry_clone.id, entry_clone.clone());
|
||||
assert!(old_nat_val.is_none());
|
||||
|
||||
tasks.lock().await.spawn(Self::connect_to_nat_dst(
|
||||
net_ns.clone(),
|
||||
tcp_stream,
|
||||
conn_map.clone(),
|
||||
addr_conn_map.clone(),
|
||||
entry_clone,
|
||||
));
|
||||
}
|
||||
tracing::error!("nat tcp listener exited");
|
||||
panic!("nat tcp listener exited");
|
||||
};
|
||||
self.tasks
|
||||
.lock()
|
||||
.await
|
||||
.spawn(accept_task.instrument(tracing::info_span!("tcp_proxy_listener")));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn remove_entry_from_all_conn_map(
|
||||
conn_map: ConnSockMap,
|
||||
addr_conn_map: AddrConnSockMap,
|
||||
nat_entry: ArcNatDstEntry,
|
||||
) {
|
||||
conn_map.remove(&nat_entry.id);
|
||||
addr_conn_map.remove_if(&nat_entry.src, |_, entry| entry.id == nat_entry.id);
|
||||
}
|
||||
|
||||
async fn connect_to_nat_dst(
|
||||
net_ns: NetNS,
|
||||
src_tcp_stream: TcpStream,
|
||||
conn_map: ConnSockMap,
|
||||
addr_conn_map: AddrConnSockMap,
|
||||
nat_entry: ArcNatDstEntry,
|
||||
) {
|
||||
if let Err(e) = src_tcp_stream.set_nodelay(true) {
|
||||
tracing::warn!("set_nodelay failed, ignore it: {:?}", e);
|
||||
}
|
||||
|
||||
let _guard = net_ns.guard();
|
||||
let socket = TcpSocket::new_v4().unwrap();
|
||||
if let Err(e) = socket.set_nodelay(true) {
|
||||
tracing::warn!("set_nodelay failed, ignore it: {:?}", e);
|
||||
}
|
||||
let Ok(Ok(dst_tcp_stream)) = tokio::time::timeout(
|
||||
Duration::from_secs(10),
|
||||
TcpSocket::new_v4().unwrap().connect(nat_entry.dst),
|
||||
)
|
||||
.await
|
||||
else {
|
||||
tracing::error!("connect to dst failed: {:?}", nat_entry);
|
||||
nat_entry.state.store(NatDstEntryState::Closed);
|
||||
Self::remove_entry_from_all_conn_map(conn_map, addr_conn_map, nat_entry);
|
||||
return;
|
||||
};
|
||||
drop(_guard);
|
||||
|
||||
assert_eq!(nat_entry.state.load(), NatDstEntryState::ConnectingDst);
|
||||
nat_entry.state.store(NatDstEntryState::Connected);
|
||||
|
||||
Self::handle_nat_connection(
|
||||
src_tcp_stream,
|
||||
dst_tcp_stream,
|
||||
conn_map,
|
||||
addr_conn_map,
|
||||
nat_entry,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
async fn handle_nat_connection(
|
||||
mut src_tcp_stream: TcpStream,
|
||||
mut dst_tcp_stream: TcpStream,
|
||||
conn_map: ConnSockMap,
|
||||
addr_conn_map: AddrConnSockMap,
|
||||
nat_entry: ArcNatDstEntry,
|
||||
) {
|
||||
let nat_entry_clone = nat_entry.clone();
|
||||
nat_entry.tasks.lock().await.spawn(async move {
|
||||
let ret = copy_bidirectional(&mut src_tcp_stream, &mut dst_tcp_stream).await;
|
||||
tracing::trace!(nat_entry = ?nat_entry_clone, ret = ?ret, "nat tcp connection closed");
|
||||
nat_entry_clone.state.store(NatDstEntryState::Closed);
|
||||
|
||||
Self::remove_entry_from_all_conn_map(conn_map, addr_conn_map, nat_entry_clone);
|
||||
});
|
||||
}
|
||||
|
||||
pub fn get_local_port(&self) -> u16 {
|
||||
self.local_port.load(std::sync::atomic::Ordering::Relaxed)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user