Initial Version

This commit is contained in:
sijie.sun
2023-09-23 01:53:45 +00:00
commit 9779923b87
63 changed files with 10840 additions and 0 deletions
+54
View File
@@ -0,0 +1,54 @@
use std::result::Result;
use tokio::io;
use tokio_util::{
bytes::{BufMut, Bytes, BytesMut},
codec::{Decoder, Encoder},
};
#[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash, Default)]
pub struct BytesCodec {
capacity: usize,
}
impl BytesCodec {
/// Creates a new `BytesCodec` for shipping around raw bytes.
pub fn new(capacity: usize) -> BytesCodec {
BytesCodec { capacity }
}
}
impl Decoder for BytesCodec {
type Item = BytesMut;
type Error = io::Error;
fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<BytesMut>, io::Error> {
if !buf.is_empty() {
let len = buf.len();
let ret = Some(buf.split_to(len));
buf.reserve(self.capacity);
Ok(ret)
} else {
Ok(None)
}
}
}
impl Encoder<Bytes> for BytesCodec {
type Error = io::Error;
fn encode(&mut self, data: Bytes, buf: &mut BytesMut) -> Result<(), io::Error> {
buf.reserve(data.len());
buf.put(data);
Ok(())
}
}
impl Encoder<BytesMut> for BytesCodec {
type Error = io::Error;
fn encode(&mut self, data: BytesMut, buf: &mut BytesMut) -> Result<(), io::Error> {
buf.reserve(data.len());
buf.put(data);
Ok(())
}
}
+399
View File
@@ -0,0 +1,399 @@
use std::{
collections::VecDeque,
net::IpAddr,
sync::Arc,
task::{ready, Context, Poll},
};
use async_stream::stream;
use futures::{Future, FutureExt, Sink, SinkExt, Stream, StreamExt};
use tokio::{sync::Mutex, time::error::Elapsed};
use std::pin::Pin;
use crate::tunnels::{SinkError, TunnelError};
use super::{DatagramSink, DatagramStream, SinkItem, StreamT, Tunnel, TunnelInfo};
pub struct FramedTunnel<R, W> {
read: Arc<Mutex<R>>,
write: Arc<Mutex<W>>,
info: Option<TunnelInfo>,
}
impl<R, RE, W, WE> FramedTunnel<R, W>
where
R: Stream<Item = Result<StreamT, RE>> + Send + Sync + Unpin + 'static,
W: Sink<SinkItem, Error = WE> + Send + Sync + Unpin + 'static,
RE: std::error::Error + std::fmt::Debug + Send + Sync + 'static,
WE: std::error::Error + std::fmt::Debug + Send + Sync + 'static + From<Elapsed>,
{
pub fn new(read: R, write: W, info: Option<TunnelInfo>) -> Self {
FramedTunnel {
read: Arc::new(Mutex::new(read)),
write: Arc::new(Mutex::new(write)),
info,
}
}
pub fn new_tunnel_with_info(read: R, write: W, info: TunnelInfo) -> Box<dyn Tunnel> {
Box::new(FramedTunnel::new(read, write, Some(info)))
}
pub fn recv_stream(&self) -> impl DatagramStream {
let read = self.read.clone();
let info = self.info.clone();
stream! {
loop {
let read_ret = read.lock().await.next().await;
if read_ret.is_none() {
tracing::info!(?info, "read_ret is none");
yield Err(TunnelError::CommonError("recv stream closed".to_string()));
} else {
let read_ret = read_ret.unwrap();
if read_ret.is_err() {
let err = read_ret.err().unwrap();
tracing::info!(?info, "recv stream read error");
yield Err(TunnelError::CommonError(err.to_string()));
} else {
yield Ok(read_ret.unwrap());
}
}
}
}
}
pub fn send_sink(&self) -> impl DatagramSink {
struct SendSink<W, WE> {
write: Arc<Mutex<W>>,
max_buffer_size: usize,
sending_buffers: Option<VecDeque<SinkItem>>,
send_task:
Option<Pin<Box<dyn Future<Output = Result<(), WE>> + Send + Sync + 'static>>>,
close_task:
Option<Pin<Box<dyn Future<Output = Result<(), WE>> + Send + Sync + 'static>>>,
}
impl<W, WE> SendSink<W, WE>
where
W: Sink<SinkItem, Error = WE> + Send + Sync + Unpin + 'static,
WE: std::error::Error + std::fmt::Debug + Send + Sync + From<Elapsed>,
{
fn try_send_buffser(
&mut self,
cx: &mut Context<'_>,
) -> Poll<std::result::Result<(), WE>> {
if self.send_task.is_none() {
let mut buffers = self.sending_buffers.take().unwrap();
let tun = self.write.clone();
let send_task = async move {
if buffers.is_empty() {
return Ok(());
}
let mut locked_tun = tun.lock_owned().await;
while let Some(buf) = buffers.front() {
log::trace!(
"try_send buffer, len: {:?}, buf: {:?}",
buffers.len(),
&buf
);
let timeout_task = tokio::time::timeout(
std::time::Duration::from_secs(1),
locked_tun.send(buf.clone()),
);
let send_res = timeout_task.await;
let Ok(send_res) = send_res else {
// panic!("send timeout");
let err = send_res.err().unwrap();
return Err(err.into());
};
let Ok(_) = send_res else {
let err = send_res.err().unwrap();
println!("send error: {:?}", err);
return Err(err);
};
buffers.pop_front();
}
return Ok(());
};
self.send_task = Some(Box::pin(send_task));
}
let ret = ready!(self.send_task.as_mut().unwrap().poll_unpin(cx));
self.send_task = None;
self.sending_buffers = Some(VecDeque::new());
return Poll::Ready(ret);
}
}
impl<W, WE> Sink<SinkItem> for SendSink<W, WE>
where
W: Sink<SinkItem, Error = WE> + Send + Sync + Unpin + 'static,
WE: std::error::Error + std::fmt::Debug + Send + Sync + From<Elapsed>,
{
type Error = SinkError;
fn poll_ready(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<(), Self::Error>> {
let self_mut = self.get_mut();
let sending_buf = self_mut.sending_buffers.as_ref();
// if sending_buffers is None, must already be doing flush
if sending_buf.is_none() || sending_buf.unwrap().len() > self_mut.max_buffer_size {
return self_mut.poll_flush_unpin(cx);
} else {
return Poll::Ready(Ok(()));
}
}
fn start_send(self: Pin<&mut Self>, item: SinkItem) -> Result<(), Self::Error> {
assert!(self.send_task.is_none());
let self_mut = self.get_mut();
self_mut.sending_buffers.as_mut().unwrap().push_back(item);
Ok(())
}
fn poll_flush(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<(), Self::Error>> {
let self_mut = self.get_mut();
let ret = self_mut.try_send_buffser(cx);
match ret {
Poll::Ready(Ok(())) => Poll::Ready(Ok(())),
Poll::Ready(Err(e)) => Poll::Ready(Err(SinkError::CommonError(e.to_string()))),
Poll::Pending => {
return Poll::Pending;
}
}
}
fn poll_close(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<(), Self::Error>> {
let self_mut = self.get_mut();
if self_mut.close_task.is_none() {
let tun = self_mut.write.clone();
let close_task = async move {
let mut locked_tun = tun.lock_owned().await;
return locked_tun.close().await;
};
self_mut.close_task = Some(Box::pin(close_task));
}
let ret = ready!(self_mut.close_task.as_mut().unwrap().poll_unpin(cx));
self_mut.close_task = None;
if ret.is_err() {
return Poll::Ready(Err(SinkError::CommonError(
ret.err().unwrap().to_string(),
)));
} else {
return Poll::Ready(Ok(()));
}
}
}
SendSink {
write: self.write.clone(),
max_buffer_size: 1000,
sending_buffers: Some(VecDeque::new()),
send_task: None,
close_task: None,
}
}
}
impl<R, RE, W, WE> Tunnel for FramedTunnel<R, W>
where
R: Stream<Item = Result<StreamT, RE>> + Send + Sync + Unpin + 'static,
W: Sink<SinkItem, Error = WE> + Send + Sync + Unpin + 'static,
RE: std::error::Error + std::fmt::Debug + Send + Sync + 'static,
WE: std::error::Error + std::fmt::Debug + Send + Sync + 'static + From<Elapsed>,
{
fn stream(&self) -> Box<dyn DatagramStream> {
Box::new(self.recv_stream())
}
fn sink(&self) -> Box<dyn DatagramSink> {
Box::new(self.send_sink())
}
fn info(&self) -> Option<TunnelInfo> {
if self.info.is_none() {
None
} else {
Some(self.info.clone().unwrap())
}
}
}
pub struct TunnelWithCustomInfo {
tunnel: Box<dyn Tunnel>,
info: TunnelInfo,
}
impl TunnelWithCustomInfo {
pub fn new(tunnel: Box<dyn Tunnel>, info: TunnelInfo) -> Self {
TunnelWithCustomInfo { tunnel, info }
}
}
impl Tunnel for TunnelWithCustomInfo {
fn stream(&self) -> Box<dyn DatagramStream> {
self.tunnel.stream()
}
fn sink(&self) -> Box<dyn DatagramSink> {
self.tunnel.sink()
}
fn info(&self) -> Option<TunnelInfo> {
Some(self.info.clone())
}
}
pub(crate) fn get_interface_name_by_ip(local_ip: &IpAddr) -> Option<String> {
let ifaces = pnet::datalink::interfaces();
for iface in ifaces {
for ip in iface.ips {
if ip.ip() == *local_ip {
return Some(iface.name);
}
}
}
None
}
pub mod tests {
use std::time::Instant;
use futures::SinkExt;
use tokio_stream::StreamExt;
use tokio_util::bytes::{BufMut, Bytes, BytesMut};
use crate::{
common::netns::NetNS,
tunnels::{close_tunnel, TunnelConnector, TunnelListener},
};
pub async fn _tunnel_echo_server(tunnel: Box<dyn super::Tunnel>, once: bool) {
let mut recv = Box::into_pin(tunnel.stream());
let mut send = Box::into_pin(tunnel.sink());
while let Some(ret) = recv.next().await {
if ret.is_err() {
log::trace!("recv error: {:?}", ret.err().unwrap());
break;
}
let res = ret.unwrap();
log::trace!("recv a msg, try echo back: {:?}", res);
send.send(Bytes::from(res)).await.unwrap();
if once {
break;
}
}
log::warn!("echo server exit...");
}
pub(crate) async fn _tunnel_pingpong<L, C>(listener: L, connector: C)
where
L: TunnelListener + Send + Sync + 'static,
C: TunnelConnector + Send + Sync + 'static,
{
_tunnel_pingpong_netns(listener, connector, NetNS::new(None), NetNS::new(None)).await
}
pub(crate) async fn _tunnel_pingpong_netns<L, C>(
mut listener: L,
mut connector: C,
l_netns: NetNS,
c_netns: NetNS,
) where
L: TunnelListener + Send + Sync + 'static,
C: TunnelConnector + Send + Sync + 'static,
{
l_netns
.run_async(|| async {
listener.listen().await.unwrap();
})
.await;
let lis = tokio::spawn(async move {
let ret = listener.accept().await.unwrap();
assert_eq!(
ret.info().unwrap().local_addr,
listener.local_url().to_string()
);
_tunnel_echo_server(ret, false).await
});
let tunnel = c_netns.run_async(|| connector.connect()).await.unwrap();
assert_eq!(
tunnel.info().unwrap().remote_addr,
connector.remote_url().to_string()
);
let mut send = tunnel.pin_sink();
let mut recv = tunnel.pin_stream();
let send_data = Bytes::from("abc");
send.send(send_data).await.unwrap();
let ret = tokio::time::timeout(tokio::time::Duration::from_secs(1), recv.next())
.await
.unwrap()
.unwrap()
.unwrap();
println!("echo back: {:?}", ret);
assert_eq!(ret, Bytes::from("abc"));
close_tunnel(&tunnel).await.unwrap();
if connector.remote_url().scheme() == "udp" {
lis.abort();
} else {
// lis should finish in 1 second
let ret = tokio::time::timeout(tokio::time::Duration::from_secs(1), lis).await;
assert!(ret.is_ok());
}
}
pub(crate) async fn _tunnel_bench<L, C>(mut listener: L, mut connector: C)
where
L: TunnelListener + Send + Sync + 'static,
C: TunnelConnector + Send + Sync + 'static,
{
listener.listen().await.unwrap();
let lis = tokio::spawn(async move {
let ret = listener.accept().await.unwrap();
_tunnel_echo_server(ret, false).await
});
let tunnel = connector.connect().await.unwrap();
let mut send = tunnel.pin_sink();
let mut recv = tunnel.pin_stream();
// prepare a 4k buffer with random data
let mut send_buf = BytesMut::new();
for _ in 0..64 {
send_buf.put_i128(rand::random::<i128>());
}
let now = Instant::now();
let mut count = 0;
while now.elapsed().as_secs() < 3 {
send.send(send_buf.clone().freeze()).await.unwrap();
let _ = recv.next().await.unwrap().unwrap();
count += 1;
}
println!("bps: {}", (count / 1024) * 4 / now.elapsed().as_secs());
lis.abort();
}
}
+159
View File
@@ -0,0 +1,159 @@
pub mod codec;
pub mod common;
pub mod ring_tunnel;
pub mod stats;
pub mod tcp_tunnel;
pub mod tunnel_filter;
pub mod udp_tunnel;
use std::{fmt::Debug, net::SocketAddr, pin::Pin, sync::Arc};
use async_trait::async_trait;
use easytier_rpc::TunnelInfo;
use futures::{Sink, SinkExt, Stream};
use thiserror::Error;
use tokio_util::bytes::{Bytes, BytesMut};
#[derive(Error, Debug)]
pub enum TunnelError {
#[error("Error: {0}")]
CommonError(String),
#[error("io error")]
IOError(#[from] std::io::Error),
#[error("wait resp error")]
WaitRespError(String),
#[error("Connect Error: {0}")]
ConnectError(String),
#[error("Invalid Protocol: {0}")]
InvalidProtocol(String),
#[error("Invalid Addr: {0}")]
InvalidAddr(String),
#[error("Tun Error: {0}")]
TunError(String),
#[error("timeout")]
Timeout(#[from] tokio::time::error::Elapsed),
}
pub type StreamT = BytesMut;
pub type StreamItem = Result<StreamT, TunnelError>;
pub type SinkItem = Bytes;
pub type SinkError = TunnelError;
pub trait DatagramStream: Stream<Item = StreamItem> + Send + Sync {}
impl<T> DatagramStream for T where T: Stream<Item = StreamItem> + Send + Sync {}
pub trait DatagramSink: Sink<SinkItem, Error = SinkError> + Send + Sync {}
impl<T> DatagramSink for T where T: Sink<SinkItem, Error = SinkError> + Send + Sync {}
#[auto_impl::auto_impl(Box, Arc)]
pub trait Tunnel: Send + Sync {
fn stream(&self) -> Box<dyn DatagramStream>;
fn sink(&self) -> Box<dyn DatagramSink>;
fn pin_stream(&self) -> Pin<Box<dyn DatagramStream>> {
Box::into_pin(self.stream())
}
fn pin_sink(&self) -> Pin<Box<dyn DatagramSink>> {
Box::into_pin(self.sink())
}
fn info(&self) -> Option<TunnelInfo>;
}
pub async fn close_tunnel(t: &Box<dyn Tunnel>) -> Result<(), TunnelError> {
t.pin_sink().close().await
}
#[auto_impl::auto_impl(Arc)]
pub trait TunnelConnCounter: 'static + Send + Sync + Debug {
fn get(&self) -> u32;
}
#[async_trait]
#[auto_impl::auto_impl(Box)]
pub trait TunnelListener: Send + Sync {
async fn listen(&mut self) -> Result<(), TunnelError>;
async fn accept(&mut self) -> Result<Box<dyn Tunnel>, TunnelError>;
fn local_url(&self) -> url::Url;
fn get_conn_counter(&self) -> Arc<Box<dyn TunnelConnCounter>> {
#[derive(Debug)]
struct FakeTunnelConnCounter {}
impl TunnelConnCounter for FakeTunnelConnCounter {
fn get(&self) -> u32 {
0
}
}
Arc::new(Box::new(FakeTunnelConnCounter {}))
}
}
#[async_trait]
#[auto_impl::auto_impl(Box)]
pub trait TunnelConnector {
async fn connect(&mut self) -> Result<Box<dyn Tunnel>, TunnelError>;
fn remote_url(&self) -> url::Url;
fn set_bind_addrs(&mut self, _addrs: Vec<SocketAddr>) {}
}
pub fn build_url_from_socket_addr(addr: &String, scheme: &str) -> url::Url {
url::Url::parse(format!("{}://{}", scheme, addr).as_str()).unwrap()
}
impl std::fmt::Debug for dyn Tunnel {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("Tunnel")
.field("info", &self.info())
.finish()
}
}
impl std::fmt::Debug for dyn TunnelConnector + Sync + Send {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("TunnelConnector")
.field("remote_url", &self.remote_url())
.finish()
}
}
impl std::fmt::Debug for dyn TunnelListener {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("TunnelListener")
.field("local_url", &self.local_url())
.finish()
}
}
pub(crate) trait FromUrl {
fn from_url(url: url::Url) -> Result<Self, TunnelError>
where
Self: Sized;
}
pub(crate) fn check_scheme_and_get_socket_addr<T>(
url: &url::Url,
scheme: &str,
) -> Result<T, TunnelError>
where
T: FromUrl,
{
if url.scheme() != scheme {
return Err(TunnelError::InvalidProtocol(url.scheme().to_string()));
}
Ok(T::from_url(url.clone())?)
}
impl FromUrl for SocketAddr {
fn from_url(url: url::Url) -> Result<Self, TunnelError> {
Ok(url.socket_addrs(|| None)?.pop().unwrap())
}
}
impl FromUrl for uuid::Uuid {
fn from_url(url: url::Url) -> Result<Self, TunnelError> {
let o = url.host_str().unwrap();
let o = uuid::Uuid::parse_str(o).map_err(|e| TunnelError::InvalidAddr(e.to_string()))?;
Ok(o)
}
}
+391
View File
@@ -0,0 +1,391 @@
use std::{
collections::HashMap,
sync::{atomic::AtomicBool, Arc},
task::Poll,
};
use async_stream::stream;
use crossbeam_queue::ArrayQueue;
use async_trait::async_trait;
use futures::Sink;
use once_cell::sync::Lazy;
use tokio::sync::{Mutex, Notify};
use futures::FutureExt;
use tokio_util::bytes::BytesMut;
use uuid::Uuid;
use crate::tunnels::{SinkError, SinkItem};
use super::{
build_url_from_socket_addr, check_scheme_and_get_socket_addr, DatagramSink, DatagramStream,
Tunnel, TunnelConnector, TunnelError, TunnelInfo, TunnelListener,
};
static RING_TUNNEL_CAP: usize = 1000;
pub struct RingTunnel {
id: Uuid,
ring: Arc<ArrayQueue<SinkItem>>,
consume_notify: Arc<Notify>,
produce_notify: Arc<Notify>,
closed: Arc<AtomicBool>,
}
impl RingTunnel {
pub fn new(cap: usize) -> Self {
RingTunnel {
id: Uuid::new_v4(),
ring: Arc::new(ArrayQueue::new(cap)),
consume_notify: Arc::new(Notify::new()),
produce_notify: Arc::new(Notify::new()),
closed: Arc::new(AtomicBool::new(false)),
}
}
pub fn new_with_id(id: Uuid, cap: usize) -> Self {
let mut ret = Self::new(cap);
ret.id = id;
ret
}
fn recv_stream(&self) -> impl DatagramStream {
let ring = self.ring.clone();
let produce_notify = self.produce_notify.clone();
let consume_notify = self.consume_notify.clone();
let closed = self.closed.clone();
let id = self.id;
stream! {
loop {
if closed.load(std::sync::atomic::Ordering::Relaxed) {
log::warn!("ring recv tunnel {:?} closed", id);
yield Err(TunnelError::CommonError("Closed".to_owned()));
}
match ring.pop() {
Some(v) => {
let mut out = BytesMut::new();
out.extend_from_slice(&v);
consume_notify.notify_one();
log::trace!("id: {}, recv buffer, len: {:?}, buf: {:?}", id, v.len(), &v);
yield Ok(out);
},
None => {
log::trace!("waiting recv buffer, id: {}", id);
produce_notify.notified().await;
}
}
}
}
}
fn send_sink(&self) -> impl DatagramSink {
let ring = self.ring.clone();
let produce_notify = self.produce_notify.clone();
let consume_notify = self.consume_notify.clone();
let closed = self.closed.clone();
let id = self.id;
// type T = RingTunnel;
use tokio::task::JoinHandle;
struct T {
ring: RingTunnel,
wait_consume_task: Option<JoinHandle<()>>,
}
impl T {
fn wait_ring_consume(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
expected_size: usize,
) -> std::task::Poll<()> {
let self_mut = self.get_mut();
if self_mut.ring.ring.len() <= expected_size {
return Poll::Ready(());
}
if self_mut.wait_consume_task.is_none() {
let id = self_mut.ring.id;
let consume_notify = self_mut.ring.consume_notify.clone();
let ring = self_mut.ring.ring.clone();
let task = async move {
log::trace!(
"waiting ring consume done, expected_size: {}, id: {}",
expected_size,
id
);
while ring.len() > expected_size {
consume_notify.notified().await;
}
log::trace!(
"ring consume done, expected_size: {}, id: {}",
expected_size,
id
);
};
self_mut.wait_consume_task = Some(tokio::spawn(task));
}
let task = self_mut.wait_consume_task.as_mut().unwrap();
match task.poll_unpin(cx) {
Poll::Ready(_) => {
self_mut.wait_consume_task = None;
Poll::Ready(())
}
Poll::Pending => Poll::Pending,
}
}
}
impl Sink<SinkItem> for T {
type Error = SinkError;
fn poll_ready(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), Self::Error>> {
let expected_size = self.ring.ring.capacity() - 1;
match self.wait_ring_consume(cx, expected_size) {
Poll::Ready(_) => Poll::Ready(Ok(())),
Poll::Pending => Poll::Pending,
}
}
fn start_send(
self: std::pin::Pin<&mut Self>,
item: SinkItem,
) -> Result<(), Self::Error> {
log::trace!("id: {}, send buffer, buf: {:?}", self.ring.id, &item);
self.ring.ring.push(item).unwrap();
self.ring.produce_notify.notify_one();
Ok(())
}
fn poll_flush(
self: std::pin::Pin<&mut Self>,
_cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn poll_close(
self: std::pin::Pin<&mut Self>,
_cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), Self::Error>> {
self.ring
.closed
.store(true, std::sync::atomic::Ordering::Relaxed);
log::warn!("ring tunnel send {:?} closed", self.ring.id);
self.ring.produce_notify.notify_one();
Poll::Ready(Ok(()))
}
}
T {
ring: RingTunnel {
id,
ring,
consume_notify,
produce_notify,
closed,
},
wait_consume_task: None,
}
}
}
struct Connection {
client: RingTunnel,
server: RingTunnel,
connect_notify: Arc<Notify>,
}
impl Tunnel for RingTunnel {
fn stream(&self) -> Box<dyn DatagramStream> {
Box::new(self.recv_stream())
}
fn sink(&self) -> Box<dyn DatagramSink> {
Box::new(self.send_sink())
}
fn info(&self) -> Option<TunnelInfo> {
None
// Some(TunnelInfo {
// tunnel_type: "ring".to_owned(),
// local_addr: format!("ring://{}", self.id),
// remote_addr: format!("ring://{}", self.id),
// })
}
}
static CONNECTION_MAP: Lazy<Arc<Mutex<HashMap<uuid::Uuid, Arc<Connection>>>>> =
Lazy::new(|| Arc::new(Mutex::new(HashMap::new())));
#[derive(Debug)]
pub struct RingTunnelListener {
listerner_addr: url::Url,
}
impl RingTunnelListener {
pub fn new(key: url::Url) -> Self {
RingTunnelListener {
listerner_addr: key,
}
}
}
struct ConnectionForServer {
conn: Arc<Connection>,
}
impl Tunnel for ConnectionForServer {
fn stream(&self) -> Box<dyn DatagramStream> {
Box::new(self.conn.server.recv_stream())
}
fn sink(&self) -> Box<dyn DatagramSink> {
Box::new(self.conn.client.send_sink())
}
fn info(&self) -> Option<TunnelInfo> {
Some(TunnelInfo {
tunnel_type: "ring".to_owned(),
local_addr: build_url_from_socket_addr(&self.conn.server.id.into(), "ring").into(),
remote_addr: build_url_from_socket_addr(&self.conn.client.id.into(), "ring").into(),
})
}
}
struct ConnectionForClient {
conn: Arc<Connection>,
}
impl Tunnel for ConnectionForClient {
fn stream(&self) -> Box<dyn DatagramStream> {
Box::new(self.conn.client.recv_stream())
}
fn sink(&self) -> Box<dyn DatagramSink> {
Box::new(self.conn.server.send_sink())
}
fn info(&self) -> Option<TunnelInfo> {
Some(TunnelInfo {
tunnel_type: "ring".to_owned(),
local_addr: build_url_from_socket_addr(&self.conn.client.id.into(), "ring").into(),
remote_addr: build_url_from_socket_addr(&self.conn.server.id.into(), "ring").into(),
})
}
}
impl RingTunnelListener {
async fn add_connection(listener_addr: uuid::Uuid) {
CONNECTION_MAP.lock().await.insert(
listener_addr.clone(),
Arc::new(Connection {
client: RingTunnel::new(RING_TUNNEL_CAP),
server: RingTunnel::new_with_id(listener_addr.clone(), RING_TUNNEL_CAP),
connect_notify: Arc::new(Notify::new()),
}),
);
}
fn get_addr(&self) -> Result<uuid::Uuid, TunnelError> {
check_scheme_and_get_socket_addr::<Uuid>(&self.listerner_addr, "ring")
}
}
#[async_trait]
impl TunnelListener for RingTunnelListener {
async fn listen(&mut self) -> Result<(), TunnelError> {
log::info!("listen new conn of key: {}", self.listerner_addr);
Self::add_connection(self.get_addr()?).await;
Ok(())
}
async fn accept(&mut self) -> Result<Box<dyn Tunnel>, TunnelError> {
log::info!("waiting accept new conn of key: {}", self.listerner_addr);
let val = CONNECTION_MAP
.lock()
.await
.get(&self.get_addr()?)
.unwrap()
.clone();
val.connect_notify.notified().await;
CONNECTION_MAP.lock().await.remove(&self.get_addr()?);
Self::add_connection(self.get_addr()?).await;
log::info!("accept new conn of key: {}", self.listerner_addr);
Ok(Box::new(ConnectionForServer { conn: val }))
}
fn local_url(&self) -> url::Url {
self.listerner_addr.clone()
}
}
pub struct RingTunnelConnector {
remote_addr: url::Url,
}
impl RingTunnelConnector {
pub fn new(remote_addr: url::Url) -> Self {
RingTunnelConnector { remote_addr }
}
}
#[async_trait]
impl TunnelConnector for RingTunnelConnector {
async fn connect(&mut self) -> Result<Box<dyn Tunnel>, super::TunnelError> {
let val = CONNECTION_MAP
.lock()
.await
.get(&check_scheme_and_get_socket_addr::<Uuid>(
&self.remote_addr,
"ring",
)?)
.unwrap()
.clone();
val.connect_notify.notify_one();
log::info!("connecting");
Ok(Box::new(ConnectionForClient { conn: val }))
}
fn remote_url(&self) -> url::Url {
self.remote_addr.clone()
}
}
pub fn create_ring_tunnel_pair() -> (Box<dyn Tunnel>, Box<dyn Tunnel>) {
let conn = Arc::new(Connection {
client: RingTunnel::new(RING_TUNNEL_CAP),
server: RingTunnel::new(RING_TUNNEL_CAP),
connect_notify: Arc::new(Notify::new()),
});
(
Box::new(ConnectionForServer { conn: conn.clone() }),
Box::new(ConnectionForClient { conn: conn }),
)
}
#[cfg(test)]
mod tests {
use crate::tunnels::common::tests::{_tunnel_bench, _tunnel_pingpong};
use super::*;
#[tokio::test]
async fn ring_pingpong() {
let id: url::Url = format!("ring://{}", Uuid::new_v4()).parse().unwrap();
let listener = RingTunnelListener::new(id.clone());
let connector = RingTunnelConnector::new(id.clone());
_tunnel_pingpong(listener, connector).await
}
#[tokio::test]
async fn ring_bench() {
let id: url::Url = format!("ring://{}", Uuid::new_v4()).parse().unwrap();
let listener = RingTunnelListener::new(id.clone());
let connector = RingTunnelConnector::new(id);
_tunnel_bench(listener, connector).await
}
}
+101
View File
@@ -0,0 +1,101 @@
use std::sync::atomic::{AtomicU32, AtomicU64};
pub struct WindowLatency {
latency_us_window: Vec<AtomicU64>,
latency_us_window_index: AtomicU32,
latency_us_window_size: AtomicU32,
}
impl WindowLatency {
pub fn new(window_size: u32) -> Self {
Self {
latency_us_window: (0..window_size).map(|_| AtomicU64::new(0)).collect(),
latency_us_window_index: AtomicU32::new(0),
latency_us_window_size: AtomicU32::new(window_size),
}
}
pub fn record_latency(&self, latency_us: u64) {
let index = self
.latency_us_window_index
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
let index = index
% self
.latency_us_window_size
.load(std::sync::atomic::Ordering::Relaxed);
self.latency_us_window[index as usize]
.store(latency_us, std::sync::atomic::Ordering::Relaxed);
}
pub fn get_latency_us(&self) -> u64 {
let window_size = self
.latency_us_window_size
.load(std::sync::atomic::Ordering::Relaxed);
let mut sum = 0;
let mut count = 0;
for i in 0..window_size {
let latency_us =
self.latency_us_window[i as usize].load(std::sync::atomic::Ordering::Relaxed);
if latency_us > 0 {
sum += latency_us;
count += 1;
}
}
if count == 0 {
0
} else {
sum / count
}
}
}
pub struct Throughput {
tx_bytes: AtomicU64,
rx_bytes: AtomicU64,
tx_packets: AtomicU64,
rx_packets: AtomicU64,
}
impl Throughput {
pub fn new() -> Self {
Self {
tx_bytes: AtomicU64::new(0),
rx_bytes: AtomicU64::new(0),
tx_packets: AtomicU64::new(0),
rx_packets: AtomicU64::new(0),
}
}
pub fn tx_bytes(&self) -> u64 {
self.tx_bytes.load(std::sync::atomic::Ordering::Relaxed)
}
pub fn rx_bytes(&self) -> u64 {
self.rx_bytes.load(std::sync::atomic::Ordering::Relaxed)
}
pub fn tx_packets(&self) -> u64 {
self.tx_packets.load(std::sync::atomic::Ordering::Relaxed)
}
pub fn rx_packets(&self) -> u64 {
self.rx_packets.load(std::sync::atomic::Ordering::Relaxed)
}
pub fn record_tx_bytes(&self, bytes: u64) {
self.tx_bytes
.fetch_add(bytes, std::sync::atomic::Ordering::Relaxed);
self.tx_packets
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
}
pub fn record_rx_bytes(&self, bytes: u64) {
self.rx_bytes
.fetch_add(bytes, std::sync::atomic::Ordering::Relaxed);
self.rx_packets
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
}
}
+284
View File
@@ -0,0 +1,284 @@
use std::net::SocketAddr;
use async_trait::async_trait;
use futures::{stream::FuturesUnordered, StreamExt};
use tokio::net::{TcpListener, TcpSocket, TcpStream};
use tokio_util::codec::{FramedRead, FramedWrite, LengthDelimitedCodec};
use super::{
check_scheme_and_get_socket_addr, common::FramedTunnel, Tunnel, TunnelInfo, TunnelListener,
};
#[derive(Debug)]
pub struct TcpTunnelListener {
addr: url::Url,
listener: Option<TcpListener>,
}
impl TcpTunnelListener {
pub fn new(addr: url::Url) -> Self {
TcpTunnelListener {
addr,
listener: None,
}
}
}
#[async_trait]
impl TunnelListener for TcpTunnelListener {
async fn listen(&mut self) -> Result<(), super::TunnelError> {
let addr = check_scheme_and_get_socket_addr::<SocketAddr>(&self.addr, "tcp")?;
let socket = if addr.is_ipv4() {
TcpSocket::new_v4()?
} else {
TcpSocket::new_v6()?
};
socket.set_reuseaddr(true)?;
#[cfg(all(unix, not(target_os = "solaris"), not(target_os = "illumos")))]
socket.set_reuseport(true)?;
socket.bind(addr)?;
self.listener = Some(socket.listen(1024)?);
Ok(())
}
async fn accept(&mut self) -> Result<Box<dyn Tunnel>, super::TunnelError> {
let listener = self.listener.as_ref().unwrap();
let (stream, _) = listener.accept().await?;
stream.set_nodelay(true).unwrap();
let info = TunnelInfo {
tunnel_type: "tcp".to_owned(),
local_addr: self.local_url().into(),
remote_addr: super::build_url_from_socket_addr(&stream.peer_addr()?.to_string(), "tcp")
.into(),
};
let (r, w) = tokio::io::split(stream);
Ok(FramedTunnel::new_tunnel_with_info(
FramedRead::new(r, LengthDelimitedCodec::new()),
FramedWrite::new(w, LengthDelimitedCodec::new()),
info,
))
}
fn local_url(&self) -> url::Url {
self.addr.clone()
}
}
fn get_tunnel_with_tcp_stream(
stream: TcpStream,
remote_url: url::Url,
) -> Result<Box<dyn Tunnel>, super::TunnelError> {
stream.set_nodelay(true).unwrap();
let info = TunnelInfo {
tunnel_type: "tcp".to_owned(),
local_addr: super::build_url_from_socket_addr(&stream.local_addr()?.to_string(), "tcp")
.into(),
remote_addr: remote_url.into(),
};
let (r, w) = tokio::io::split(stream);
Ok(Box::new(FramedTunnel::new_tunnel_with_info(
FramedRead::new(r, LengthDelimitedCodec::new()),
FramedWrite::new(w, LengthDelimitedCodec::new()),
info,
)))
}
#[derive(Debug)]
pub struct TcpTunnelConnector {
addr: url::Url,
bind_addrs: Vec<SocketAddr>,
}
impl TcpTunnelConnector {
pub fn new(addr: url::Url) -> Self {
TcpTunnelConnector {
addr,
bind_addrs: vec![],
}
}
async fn connect_with_default_bind(&mut self) -> Result<Box<dyn Tunnel>, super::TunnelError> {
tracing::info!(addr = ?self.addr, "connect tcp start");
let addr = check_scheme_and_get_socket_addr::<SocketAddr>(&self.addr, "tcp")?;
let stream = TcpStream::connect(addr).await?;
tracing::info!(addr = ?self.addr, "connect tcp succ");
return get_tunnel_with_tcp_stream(stream, self.addr.clone().into());
}
async fn connect_with_custom_bind(
&mut self,
is_ipv4: bool,
) -> Result<Box<dyn Tunnel>, super::TunnelError> {
let mut futures = FuturesUnordered::new();
let dst_addr = check_scheme_and_get_socket_addr::<SocketAddr>(&self.addr, "tcp")?;
for bind_addr in self.bind_addrs.iter() {
let socket = if is_ipv4 {
TcpSocket::new_v4()?
} else {
TcpSocket::new_v6()?
};
socket.set_reuseaddr(true)?;
#[cfg(all(unix, not(target_os = "solaris"), not(target_os = "illumos")))]
socket.set_reuseport(true)?;
socket.bind(*bind_addr)?;
// linux does not use interface of bind_addr to send packet, so we need to bind device
// mac can handle this with bind correctly
#[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
if let Some(dev_name) = super::common::get_interface_name_by_ip(&bind_addr.ip()) {
tracing::trace!(dev_name = ?dev_name, "bind device");
socket.bind_device(Some(dev_name.as_bytes()))?;
}
futures.push(socket.connect(dst_addr.clone()));
}
let Some(ret) = futures.next().await else {
return Err(super::TunnelError::CommonError(
"join connect futures failed".to_owned(),
));
};
return get_tunnel_with_tcp_stream(ret?, self.addr.clone().into());
}
}
#[async_trait]
impl super::TunnelConnector for TcpTunnelConnector {
async fn connect(&mut self) -> Result<Box<dyn Tunnel>, super::TunnelError> {
if self.bind_addrs.is_empty() {
self.connect_with_default_bind().await
} else if self.bind_addrs[0].is_ipv4() {
self.connect_with_custom_bind(true).await
} else {
self.connect_with_custom_bind(false).await
}
}
fn remote_url(&self) -> url::Url {
self.addr.clone()
}
fn set_bind_addrs(&mut self, addrs: Vec<SocketAddr>) {
self.bind_addrs = addrs;
}
}
#[cfg(test)]
mod tests {
use futures::SinkExt;
use crate::tunnels::{
common::tests::{_tunnel_bench, _tunnel_pingpong},
TunnelConnector,
};
use super::*;
#[tokio::test]
async fn tcp_pingpong() {
let listener = TcpTunnelListener::new("tcp://0.0.0.0:11011".parse().unwrap());
let connector = TcpTunnelConnector::new("tcp://127.0.0.1:11011".parse().unwrap());
_tunnel_pingpong(listener, connector).await
}
#[tokio::test]
async fn tcp_bench() {
let listener = TcpTunnelListener::new("tcp://0.0.0.0:11012".parse().unwrap());
let connector = TcpTunnelConnector::new("tcp://127.0.0.1:11012".parse().unwrap());
_tunnel_bench(listener, connector).await
}
#[tokio::test]
async fn tcp_bench_with_bind() {
let listener = TcpTunnelListener::new("tcp://127.0.0.1:11013".parse().unwrap());
let mut connector = TcpTunnelConnector::new("tcp://127.0.0.1:11013".parse().unwrap());
connector.set_bind_addrs(vec!["127.0.0.1:0".parse().unwrap()]);
_tunnel_pingpong(listener, connector).await
}
#[tokio::test]
#[should_panic]
async fn tcp_bench_with_bind_fail() {
let listener = TcpTunnelListener::new("tcp://127.0.0.1:11014".parse().unwrap());
let mut connector = TcpTunnelConnector::new("tcp://127.0.0.1:11014".parse().unwrap());
connector.set_bind_addrs(vec!["10.0.0.1:0".parse().unwrap()]);
_tunnel_pingpong(listener, connector).await
}
// test slow send lock in framed tunnel
#[tokio::test]
async fn tcp_multiple_sender_and_slow_receiver() {
// console_subscriber::init();
let mut listener = TcpTunnelListener::new("tcp://127.0.0.1:11014".parse().unwrap());
let mut connector = TcpTunnelConnector::new("tcp://127.0.0.1:11014".parse().unwrap());
listener.listen().await.unwrap();
let t1 = tokio::spawn(async move {
let t = listener.accept().await.unwrap();
let mut stream = t.pin_stream();
let now = tokio::time::Instant::now();
while let Some(Ok(_)) = stream.next().await {
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
if now.elapsed().as_secs() > 5 {
break;
}
}
tracing::info!("t1 exit");
});
let tunnel = connector.connect().await.unwrap();
let mut sink1 = tunnel.pin_sink();
let t2 = tokio::spawn(async move {
for i in 0..1000000 {
let a = sink1.send(b"hello".to_vec().into()).await;
if a.is_err() {
tracing::info!(?a, "t2 exit with err");
break;
}
if i % 5000 == 0 {
tracing::info!(i, "send2 1000");
}
}
tracing::info!("t2 exit");
});
let mut sink2 = tunnel.pin_sink();
let t3 = tokio::spawn(async move {
for i in 0..1000000 {
let a = sink2.send(b"hello".to_vec().into()).await;
if a.is_err() {
tracing::info!(?a, "t3 exit with err");
break;
}
if i % 5000 == 0 {
tracing::info!(i, "send2 1000");
}
}
tracing::info!("t3 exit");
});
let t4 = tokio::spawn(async move {
tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
tracing::info!("closing");
let close_ret = tunnel.pin_sink().close().await;
tracing::info!("closed {:?}", close_ret);
});
let _ = tokio::join!(t1, t2, t3, t4);
}
}
+228
View File
@@ -0,0 +1,228 @@
use std::{
sync::Arc,
task::{Context, Poll},
};
use easytier_rpc::TunnelInfo;
use futures::{Sink, SinkExt, Stream, StreamExt};
use self::stats::Throughput;
use super::*;
use crate::tunnels::{DatagramSink, DatagramStream, SinkError, SinkItem, StreamItem, Tunnel};
pub trait TunnelFilter {
fn before_send(&self, data: SinkItem) -> Result<SinkItem, SinkError>;
fn after_received(&self, data: StreamItem) -> Result<BytesMut, TunnelError>;
}
pub struct TunnelWithFilter<T, F> {
inner: T,
filter: Arc<F>,
}
impl<T, F> Tunnel for TunnelWithFilter<T, F>
where
T: Tunnel + Send + Sync + 'static,
F: TunnelFilter + Send + Sync + 'static,
{
fn sink(&self) -> Box<dyn DatagramSink> {
struct SinkWrapper<F> {
sink: Pin<Box<dyn DatagramSink>>,
filter: Arc<F>,
}
impl<F> Sink<SinkItem> for SinkWrapper<F>
where
F: TunnelFilter + Send + Sync + 'static,
{
type Error = SinkError;
fn poll_ready(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<(), Self::Error>> {
self.get_mut().sink.poll_ready_unpin(cx)
}
fn start_send(self: Pin<&mut Self>, item: SinkItem) -> Result<(), Self::Error> {
let item = self.filter.before_send(item)?;
self.get_mut().sink.start_send_unpin(item)
}
fn poll_flush(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<(), Self::Error>> {
self.get_mut().sink.poll_flush_unpin(cx)
}
fn poll_close(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<(), Self::Error>> {
self.get_mut().sink.poll_close_unpin(cx)
}
}
Box::new(SinkWrapper {
sink: self.inner.pin_sink(),
filter: self.filter.clone(),
})
}
fn stream(&self) -> Box<dyn DatagramStream> {
struct StreamWrapper<F> {
stream: Pin<Box<dyn DatagramStream>>,
filter: Arc<F>,
}
impl<F> Stream for StreamWrapper<F>
where
F: TunnelFilter + Send + Sync + 'static,
{
type Item = StreamItem;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let self_mut = self.get_mut();
match self_mut.stream.poll_next_unpin(cx) {
Poll::Ready(Some(ret)) => {
Poll::Ready(Some(self_mut.filter.after_received(ret)))
}
Poll::Ready(None) => Poll::Ready(None),
Poll::Pending => Poll::Pending,
}
}
}
Box::new(StreamWrapper {
stream: self.inner.pin_stream(),
filter: self.filter.clone(),
})
}
fn info(&self) -> Option<TunnelInfo> {
self.inner.info()
}
}
impl<T, F> TunnelWithFilter<T, F>
where
T: Tunnel + Send + Sync + 'static,
F: TunnelFilter + Send + Sync + 'static,
{
pub fn new(inner: T, filter: Arc<F>) -> Self {
Self { inner, filter }
}
}
pub struct PacketRecorderTunnelFilter {
pub received: Arc<std::sync::Mutex<Vec<Bytes>>>,
pub sent: Arc<std::sync::Mutex<Vec<Bytes>>>,
}
impl TunnelFilter for PacketRecorderTunnelFilter {
fn before_send(&self, data: SinkItem) -> Result<SinkItem, SinkError> {
self.received.lock().unwrap().push(data.clone());
Ok(data)
}
fn after_received(&self, data: StreamItem) -> Result<BytesMut, TunnelError> {
match data {
Ok(v) => {
self.sent.lock().unwrap().push(v.clone().into());
Ok(v)
}
Err(e) => Err(e),
}
}
}
impl PacketRecorderTunnelFilter {
pub fn new() -> Self {
Self {
received: Arc::new(std::sync::Mutex::new(Vec::new())),
sent: Arc::new(std::sync::Mutex::new(Vec::new())),
}
}
}
pub struct StatsRecorderTunnelFilter {
throughput: Arc<Throughput>,
}
impl TunnelFilter for StatsRecorderTunnelFilter {
fn before_send(&self, data: SinkItem) -> Result<SinkItem, SinkError> {
self.throughput.record_tx_bytes(data.len() as u64);
Ok(data)
}
fn after_received(&self, data: StreamItem) -> Result<BytesMut, TunnelError> {
match data {
Ok(v) => {
self.throughput.record_rx_bytes(v.len() as u64);
Ok(v)
}
Err(e) => Err(e),
}
}
}
impl StatsRecorderTunnelFilter {
pub fn new() -> Self {
Self {
throughput: Arc::new(Throughput::new()),
}
}
pub fn get_throughput(&self) -> Arc<Throughput> {
self.throughput.clone()
}
}
#[macro_export]
macro_rules! define_tunnel_filter_chain {
($type_name:ident $(, $field_name:ident = $filter_type:ty)+) => (
pub struct $type_name {
$($field_name: std::sync::Arc<$filter_type>,)+
}
impl $type_name {
pub fn new() -> Self {
Self {
$($field_name: std::sync::Arc::new(<$filter_type>::new()),)+
}
}
pub fn wrap_tunnel(&self, tunnel: impl Tunnel + 'static) -> impl Tunnel {
$(
let tunnel = crate::tunnels::tunnel_filter::TunnelWithFilter::new(tunnel, self.$field_name.clone());
)+
tunnel
}
}
)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::tunnels::ring_tunnel::RingTunnel;
#[tokio::test]
async fn test_nested_filter() {
define_tunnel_filter_chain!(
Filter,
a = PacketRecorderTunnelFilter,
b = PacketRecorderTunnelFilter,
c = PacketRecorderTunnelFilter
);
let filter = Filter::new();
let tunnel = filter.wrap_tunnel(RingTunnel::new(1));
let mut s = tunnel.pin_sink();
s.send(Bytes::from("hello")).await.unwrap();
assert_eq!(1, filter.a.received.lock().unwrap().len());
assert_eq!(1, filter.b.received.lock().unwrap().len());
assert_eq!(1, filter.c.received.lock().unwrap().len());
}
}
+574
View File
@@ -0,0 +1,574 @@
use std::{fmt::Debug, pin::Pin, sync::Arc};
use async_trait::async_trait;
use dashmap::DashMap;
use easytier_rpc::TunnelInfo;
use futures::{stream::FuturesUnordered, SinkExt, StreamExt};
use rkyv::{Archive, Deserialize, Serialize};
use std::net::SocketAddr;
use tokio::{net::UdpSocket, sync::Mutex, task::JoinSet};
use tokio_util::{
bytes::{Buf, Bytes, BytesMut},
udp::UdpFramed,
};
use tracing::Instrument;
use crate::{
common::rkyv_util::{self, encode_to_bytes},
tunnels::{build_url_from_socket_addr, close_tunnel, TunnelConnCounter, TunnelConnector},
};
use super::{
codec::BytesCodec,
common::{FramedTunnel, TunnelWithCustomInfo},
ring_tunnel::create_ring_tunnel_pair,
DatagramSink, DatagramStream, Tunnel, TunnelListener,
};
pub const UDP_DATA_MTU: usize = 2500;
#[derive(Archive, Deserialize, Serialize, Debug)]
#[archive(compare(PartialEq), check_bytes)]
// Derives can be passed through to the generated type:
#[archive_attr(derive(Debug))]
pub enum UdpPacketPayload {
Syn,
Sack,
HolePunch(Vec<u8>),
Data(Vec<u8>),
}
#[derive(Archive, Deserialize, Serialize, Debug)]
#[archive(compare(PartialEq), check_bytes)]
#[archive_attr(derive(Debug))]
pub struct UdpPacket {
pub conn_id: u32,
pub payload: UdpPacketPayload,
}
impl UdpPacket {
pub fn new_data_packet(conn_id: u32, data: Vec<u8>) -> Self {
Self {
conn_id,
payload: UdpPacketPayload::Data(data),
}
}
pub fn new_hole_punch_packet(data: Vec<u8>) -> Self {
Self {
conn_id: 0,
payload: UdpPacketPayload::HolePunch(data),
}
}
pub fn new_syn_packet(conn_id: u32) -> Self {
Self {
conn_id,
payload: UdpPacketPayload::Syn,
}
}
pub fn new_sack_packet(conn_id: u32) -> Self {
Self {
conn_id,
payload: UdpPacketPayload::Sack,
}
}
}
fn try_get_data_payload(mut buf: BytesMut, conn_id: u32) -> Option<BytesMut> {
let Ok(udp_packet) = rkyv_util::decode_from_bytes_checked::<UdpPacket>(&buf) else {
tracing::warn!(?buf, "udp decode error");
return None;
};
if udp_packet.conn_id != conn_id.clone() {
tracing::warn!(?udp_packet, ?conn_id, "udp conn id not match");
return None;
}
let ArchivedUdpPacketPayload::Data(payload) = &udp_packet.payload else {
tracing::warn!(?udp_packet, "udp payload not data");
return None;
};
let ptr_range = payload.as_ptr_range();
let offset = ptr_range.start as usize - buf.as_ptr() as usize;
let len = ptr_range.end as usize - ptr_range.start as usize;
buf.advance(offset);
buf.truncate(len);
tracing::trace!(?offset, ?len, ?buf, "udp payload data");
Some(buf)
}
fn get_tunnel_from_socket(
socket: Arc<UdpSocket>,
addr: SocketAddr,
conn_id: u32,
) -> Box<dyn super::Tunnel> {
let udp = UdpFramed::new(socket.clone(), BytesCodec::new(UDP_DATA_MTU));
let (sink, stream) = udp.split();
let recv_addr = addr;
let stream = stream.filter_map(move |v| async move {
tracing::trace!(?v, "udp stream recv something");
if v.is_err() {
tracing::warn!(?v, "udp stream error");
return Some(Err(super::TunnelError::CommonError(
"udp stream error".to_owned(),
)));
}
let (buf, addr) = v.unwrap();
assert_eq!(addr, recv_addr.clone());
Some(Ok(try_get_data_payload(buf, conn_id.clone())?))
});
let stream = Box::pin(stream);
let sender_addr = addr;
let sink = Box::pin(sink.with(move |v: Bytes| async move {
if false {
return Err(super::TunnelError::CommonError("udp sink error".to_owned()));
}
// TODO: two copy here, how to avoid?
let udp_packet = UdpPacket::new_data_packet(conn_id, v.to_vec());
tracing::trace!(?udp_packet, ?v, "udp send packet");
let v = encode_to_bytes::<_, UDP_DATA_MTU>(&udp_packet);
Ok((v, sender_addr))
}));
FramedTunnel::new_tunnel_with_info(
stream,
sink,
// TODO: this remote addr is not a url
super::TunnelInfo {
tunnel_type: "udp".to_owned(),
local_addr: super::build_url_from_socket_addr(
&socket.local_addr().unwrap().to_string(),
"udp",
)
.into(),
remote_addr: super::build_url_from_socket_addr(&addr.to_string(), "udp").into(),
},
)
}
struct StreamSinkPair(
Pin<Box<dyn DatagramStream>>,
Pin<Box<dyn DatagramSink>>,
u32,
);
type ArcStreamSinkPair = Arc<Mutex<StreamSinkPair>>;
pub struct UdpTunnelListener {
addr: url::Url,
socket: Option<Arc<UdpSocket>>,
sock_map: Arc<DashMap<SocketAddr, ArcStreamSinkPair>>,
forward_tasks: Arc<Mutex<JoinSet<()>>>,
conn_recv: tokio::sync::mpsc::Receiver<Box<dyn Tunnel>>,
conn_send: Option<tokio::sync::mpsc::Sender<Box<dyn Tunnel>>>,
}
impl UdpTunnelListener {
pub fn new(addr: url::Url) -> Self {
let (conn_send, conn_recv) = tokio::sync::mpsc::channel(100);
Self {
addr,
socket: None,
sock_map: Arc::new(DashMap::new()),
forward_tasks: Arc::new(Mutex::new(JoinSet::new())),
conn_recv,
conn_send: Some(conn_send),
}
}
async fn try_forward_packet(
sock_map: &DashMap<SocketAddr, ArcStreamSinkPair>,
buf: BytesMut,
addr: SocketAddr,
) -> Result<(), super::TunnelError> {
let entry = sock_map.get_mut(&addr);
if entry.is_none() {
log::warn!("udp forward packet: {:?}, {:?}, no entry", addr, buf);
return Ok(());
}
log::trace!("udp forward packet: {:?}, {:?}", addr, buf);
let entry = entry.unwrap();
let pair = entry.value().clone();
drop(entry);
let Some(buf) = try_get_data_payload(buf, pair.lock().await.2) else {
return Ok(());
};
pair.lock().await.1.send(buf.freeze()).await?;
Ok(())
}
async fn handle_connect(
socket: Arc<UdpSocket>,
addr: SocketAddr,
forward_tasks: Arc<Mutex<JoinSet<()>>>,
sock_map: Arc<DashMap<SocketAddr, ArcStreamSinkPair>>,
local_url: url::Url,
conn_id: u32,
) -> Result<Box<dyn Tunnel>, super::TunnelError> {
tracing::info!(?conn_id, ?addr, "udp connection accept handling",);
let udp_packet = UdpPacket::new_sack_packet(conn_id);
let sack_buf = encode_to_bytes::<_, UDP_DATA_MTU>(&udp_packet);
socket.send_to(&sack_buf, addr).await?;
let (ctunnel, stunnel) = create_ring_tunnel_pair();
let udp_tunnel = get_tunnel_from_socket(socket.clone(), addr, conn_id);
let ss_pair = StreamSinkPair(ctunnel.pin_stream(), ctunnel.pin_sink(), conn_id);
let addr_copy = addr.clone();
sock_map.insert(addr, Arc::new(Mutex::new(ss_pair)));
let ctunnel_stream = ctunnel.pin_stream();
forward_tasks.lock().await.spawn(async move {
let ret = ctunnel_stream
.map(|v| {
tracing::trace!(?v, "udp stream recv something in forward task");
if v.is_err() {
return Err(super::TunnelError::CommonError(
"udp stream error".to_owned(),
));
}
Ok(v.unwrap().freeze())
})
.forward(udp_tunnel.pin_sink())
.await;
if let None = sock_map.remove(&addr_copy) {
log::warn!("udp forward packet: {:?}, no entry", addr_copy);
}
close_tunnel(&ctunnel).await.unwrap();
log::warn!("udp connection forward done: {:?}, {:?}", addr_copy, ret);
});
Ok(Box::new(TunnelWithCustomInfo::new(
stunnel,
TunnelInfo {
tunnel_type: "udp".to_owned(),
local_addr: local_url.into(),
remote_addr: build_url_from_socket_addr(&addr.to_string(), "udp").into(),
},
)))
}
pub fn get_socket(&self) -> Option<Arc<UdpSocket>> {
self.socket.clone()
}
}
#[async_trait]
impl TunnelListener for UdpTunnelListener {
async fn listen(&mut self) -> Result<(), super::TunnelError> {
let addr = super::check_scheme_and_get_socket_addr::<SocketAddr>(&self.addr, "udp")?;
self.socket = Some(Arc::new(UdpSocket::bind(addr).await?));
let socket = self.socket.as_ref().unwrap().clone();
let forward_tasks = self.forward_tasks.clone();
let sock_map = self.sock_map.clone();
let conn_send = self.conn_send.take().unwrap();
let local_url = self.local_url().clone();
self.forward_tasks.lock().await.spawn(
async move {
loop {
let mut buf = BytesMut::new();
buf.resize(2500, 0);
let (_size, addr) = socket.recv_from(&mut buf).await.unwrap();
let _ = buf.split_off(_size);
log::trace!(
"udp recv packet: {:?}, buf: {:?}, size: {}",
addr,
buf,
_size
);
let Ok(udp_packet) = rkyv_util::decode_from_bytes_checked::<UdpPacket>(&buf)
else {
tracing::warn!(?buf, "udp decode error in forward task");
continue;
};
if matches!(udp_packet.payload, ArchivedUdpPacketPayload::Syn) {
let conn = Self::handle_connect(
socket.clone(),
addr,
forward_tasks.clone(),
sock_map.clone(),
local_url.clone(),
udp_packet.conn_id.into(),
)
.await
.unwrap();
if let Err(e) = conn_send.send(conn).await {
tracing::warn!(?e, "udp send conn to accept channel error");
}
} else {
Self::try_forward_packet(sock_map.as_ref(), buf, addr)
.await
.unwrap();
}
}
}
.instrument(tracing::info_span!("udp forward task", ?self.socket)),
);
// let forward_tasks_clone = self.forward_tasks.clone();
// tokio::spawn(async move {
// loop {
// let mut locked_forward_tasks = forward_tasks_clone.lock().await;
// tokio::select! {
// ret = locked_forward_tasks.join_next() => {
// tracing::warn!(?ret, "udp forward task exit");
// }
// else => {
// drop(locked_forward_tasks);
// tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
// continue;
// }
// }
// }
// });
Ok(())
}
async fn accept(&mut self) -> Result<Box<dyn super::Tunnel>, super::TunnelError> {
log::info!("start udp accept: {:?}", self.addr);
while let Some(conn) = self.conn_recv.recv().await {
return Ok(conn);
}
return Err(super::TunnelError::CommonError(
"udp accept error".to_owned(),
));
}
fn local_url(&self) -> url::Url {
self.addr.clone()
}
fn get_conn_counter(&self) -> Arc<Box<dyn TunnelConnCounter>> {
struct UdpTunnelConnCounter {
sock_map: Arc<DashMap<SocketAddr, ArcStreamSinkPair>>,
}
impl Debug for UdpTunnelConnCounter {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("UdpTunnelConnCounter")
.field("sock_map_len", &self.sock_map.len())
.finish()
}
}
impl TunnelConnCounter for UdpTunnelConnCounter {
fn get(&self) -> u32 {
self.sock_map.len() as u32
}
}
Arc::new(Box::new(UdpTunnelConnCounter {
sock_map: self.sock_map.clone(),
}))
}
}
pub struct UdpTunnelConnector {
addr: url::Url,
bind_addrs: Vec<SocketAddr>,
}
impl UdpTunnelConnector {
pub fn new(addr: url::Url) -> Self {
Self {
addr,
bind_addrs: vec![],
}
}
async fn wait_sack(
socket: &UdpSocket,
addr: SocketAddr,
conn_id: u32,
) -> Result<(), super::TunnelError> {
let mut buf = BytesMut::new();
buf.resize(128, 0);
let (usize, recv_addr) = tokio::time::timeout(
tokio::time::Duration::from_secs(3),
socket.recv_from(&mut buf),
)
.await??;
if recv_addr != addr {
return Err(super::TunnelError::ConnectError(format!(
"udp connect error, unexpected sack addr: {:?}, {:?}",
recv_addr, addr
)));
}
let _ = buf.split_off(usize);
let Ok(udp_packet) = rkyv_util::decode_from_bytes_checked::<UdpPacket>(&buf) else {
tracing::warn!(?buf, "udp decode error in wait sack");
return Err(super::TunnelError::ConnectError(format!(
"udp connect error, decode error. buf: {:?}",
buf
)));
};
if conn_id != udp_packet.conn_id {
return Err(super::TunnelError::ConnectError(format!(
"udp connect error, conn id not match. conn_id: {:?}, {:?}",
conn_id, udp_packet.conn_id
)));
}
if !matches!(udp_packet.payload, ArchivedUdpPacketPayload::Sack) {
return Err(super::TunnelError::ConnectError(format!(
"udp connect error, unexpected payload. payload: {:?}",
udp_packet.payload
)));
}
Ok(())
}
async fn wait_sack_loop(
socket: &UdpSocket,
addr: SocketAddr,
conn_id: u32,
) -> Result<(), super::TunnelError> {
while let Err(err) = Self::wait_sack(socket, addr, conn_id).await {
tracing::warn!(?err, "udp wait sack error");
}
Ok(())
}
pub async fn try_connect_with_socket(
&self,
socket: UdpSocket,
) -> Result<Box<dyn super::Tunnel>, super::TunnelError> {
let addr = super::check_scheme_and_get_socket_addr::<SocketAddr>(&self.addr, "udp")?;
log::warn!("udp connect: {:?}", self.addr);
// send syn
let conn_id = rand::random();
let udp_packet = UdpPacket::new_syn_packet(conn_id);
let b = encode_to_bytes::<_, UDP_DATA_MTU>(&udp_packet);
let ret = socket.send_to(&b, &addr).await?;
tracing::warn!(?udp_packet, ?ret, "udp send syn");
// wait sack
tokio::time::timeout(
tokio::time::Duration::from_secs(3),
Self::wait_sack_loop(&socket, addr, conn_id),
)
.await??;
// sack done
let local_addr = socket.local_addr().unwrap().to_string();
Ok(Box::new(TunnelWithCustomInfo::new(
get_tunnel_from_socket(Arc::new(socket), addr, conn_id),
TunnelInfo {
tunnel_type: "udp".to_owned(),
local_addr: super::build_url_from_socket_addr(&local_addr, "udp").into(),
remote_addr: self.remote_url().into(),
},
)))
}
async fn connect_with_default_bind(&mut self) -> Result<Box<dyn Tunnel>, super::TunnelError> {
let socket = UdpSocket::bind("0.0.0.0:0").await?;
return self.try_connect_with_socket(socket).await;
}
async fn connect_with_custom_bind(&mut self) -> Result<Box<dyn Tunnel>, super::TunnelError> {
let mut futures = FuturesUnordered::new();
for bind_addr in self.bind_addrs.iter() {
let socket = UdpSocket::bind(*bind_addr).await?;
// linux does not use interface of bind_addr to send packet, so we need to bind device
// mac can handle this with bind correctly
#[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
if let Some(dev_name) = super::common::get_interface_name_by_ip(&bind_addr.ip()) {
tracing::trace!(dev_name = ?dev_name, "bind device");
socket.bind_device(Some(dev_name.as_bytes()))?;
}
futures.push(self.try_connect_with_socket(socket));
}
let Some(ret) = futures.next().await else {
return Err(super::TunnelError::CommonError(
"join connect futures failed".to_owned(),
));
};
return ret;
}
}
#[async_trait]
impl super::TunnelConnector for UdpTunnelConnector {
async fn connect(&mut self) -> Result<Box<dyn super::Tunnel>, super::TunnelError> {
if self.bind_addrs.is_empty() {
self.connect_with_default_bind().await
} else {
self.connect_with_custom_bind().await
}
}
fn remote_url(&self) -> url::Url {
self.addr.clone()
}
fn set_bind_addrs(&mut self, addrs: Vec<SocketAddr>) {
self.bind_addrs = addrs;
}
}
#[cfg(test)]
mod tests {
use crate::tunnels::common::tests::{_tunnel_bench, _tunnel_pingpong};
use super::*;
#[tokio::test]
async fn udp_pingpong() {
let listener = UdpTunnelListener::new("udp://0.0.0.0:5556".parse().unwrap());
let connector = UdpTunnelConnector::new("udp://127.0.0.1:5556".parse().unwrap());
_tunnel_pingpong(listener, connector).await
}
#[tokio::test]
async fn udp_bench() {
let listener = UdpTunnelListener::new("udp://0.0.0.0:5555".parse().unwrap());
let connector = UdpTunnelConnector::new("udp://127.0.0.1:5555".parse().unwrap());
_tunnel_bench(listener, connector).await
}
#[tokio::test]
async fn udp_bench_with_bind() {
let listener = UdpTunnelListener::new("udp://127.0.0.1:5554".parse().unwrap());
let mut connector = UdpTunnelConnector::new("udp://127.0.0.1:5554".parse().unwrap());
connector.set_bind_addrs(vec!["127.0.0.1:0".parse().unwrap()]);
_tunnel_pingpong(listener, connector).await
}
#[tokio::test]
#[should_panic]
async fn udp_bench_with_bind_fail() {
let listener = UdpTunnelListener::new("udp://127.0.0.1:5553".parse().unwrap());
let mut connector = UdpTunnelConnector::new("udp://127.0.0.1:5553".parse().unwrap());
connector.set_bind_addrs(vec!["10.0.0.1:0".parse().unwrap()]);
_tunnel_pingpong(listener, connector).await
}
}