Initial Version
This commit is contained in:
@@ -0,0 +1,54 @@
|
||||
use std::result::Result;
|
||||
use tokio::io;
|
||||
use tokio_util::{
|
||||
bytes::{BufMut, Bytes, BytesMut},
|
||||
codec::{Decoder, Encoder},
|
||||
};
|
||||
|
||||
#[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash, Default)]
|
||||
pub struct BytesCodec {
|
||||
capacity: usize,
|
||||
}
|
||||
|
||||
impl BytesCodec {
|
||||
/// Creates a new `BytesCodec` for shipping around raw bytes.
|
||||
pub fn new(capacity: usize) -> BytesCodec {
|
||||
BytesCodec { capacity }
|
||||
}
|
||||
}
|
||||
|
||||
impl Decoder for BytesCodec {
|
||||
type Item = BytesMut;
|
||||
type Error = io::Error;
|
||||
|
||||
fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<BytesMut>, io::Error> {
|
||||
if !buf.is_empty() {
|
||||
let len = buf.len();
|
||||
let ret = Some(buf.split_to(len));
|
||||
buf.reserve(self.capacity);
|
||||
Ok(ret)
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Encoder<Bytes> for BytesCodec {
|
||||
type Error = io::Error;
|
||||
|
||||
fn encode(&mut self, data: Bytes, buf: &mut BytesMut) -> Result<(), io::Error> {
|
||||
buf.reserve(data.len());
|
||||
buf.put(data);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl Encoder<BytesMut> for BytesCodec {
|
||||
type Error = io::Error;
|
||||
|
||||
fn encode(&mut self, data: BytesMut, buf: &mut BytesMut) -> Result<(), io::Error> {
|
||||
buf.reserve(data.len());
|
||||
buf.put(data);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,399 @@
|
||||
use std::{
|
||||
collections::VecDeque,
|
||||
net::IpAddr,
|
||||
sync::Arc,
|
||||
task::{ready, Context, Poll},
|
||||
};
|
||||
|
||||
use async_stream::stream;
|
||||
use futures::{Future, FutureExt, Sink, SinkExt, Stream, StreamExt};
|
||||
use tokio::{sync::Mutex, time::error::Elapsed};
|
||||
|
||||
use std::pin::Pin;
|
||||
|
||||
use crate::tunnels::{SinkError, TunnelError};
|
||||
|
||||
use super::{DatagramSink, DatagramStream, SinkItem, StreamT, Tunnel, TunnelInfo};
|
||||
|
||||
pub struct FramedTunnel<R, W> {
|
||||
read: Arc<Mutex<R>>,
|
||||
write: Arc<Mutex<W>>,
|
||||
|
||||
info: Option<TunnelInfo>,
|
||||
}
|
||||
|
||||
impl<R, RE, W, WE> FramedTunnel<R, W>
|
||||
where
|
||||
R: Stream<Item = Result<StreamT, RE>> + Send + Sync + Unpin + 'static,
|
||||
W: Sink<SinkItem, Error = WE> + Send + Sync + Unpin + 'static,
|
||||
RE: std::error::Error + std::fmt::Debug + Send + Sync + 'static,
|
||||
WE: std::error::Error + std::fmt::Debug + Send + Sync + 'static + From<Elapsed>,
|
||||
{
|
||||
pub fn new(read: R, write: W, info: Option<TunnelInfo>) -> Self {
|
||||
FramedTunnel {
|
||||
read: Arc::new(Mutex::new(read)),
|
||||
write: Arc::new(Mutex::new(write)),
|
||||
info,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn new_tunnel_with_info(read: R, write: W, info: TunnelInfo) -> Box<dyn Tunnel> {
|
||||
Box::new(FramedTunnel::new(read, write, Some(info)))
|
||||
}
|
||||
|
||||
pub fn recv_stream(&self) -> impl DatagramStream {
|
||||
let read = self.read.clone();
|
||||
let info = self.info.clone();
|
||||
stream! {
|
||||
loop {
|
||||
let read_ret = read.lock().await.next().await;
|
||||
if read_ret.is_none() {
|
||||
tracing::info!(?info, "read_ret is none");
|
||||
yield Err(TunnelError::CommonError("recv stream closed".to_string()));
|
||||
} else {
|
||||
let read_ret = read_ret.unwrap();
|
||||
if read_ret.is_err() {
|
||||
let err = read_ret.err().unwrap();
|
||||
tracing::info!(?info, "recv stream read error");
|
||||
yield Err(TunnelError::CommonError(err.to_string()));
|
||||
} else {
|
||||
yield Ok(read_ret.unwrap());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn send_sink(&self) -> impl DatagramSink {
|
||||
struct SendSink<W, WE> {
|
||||
write: Arc<Mutex<W>>,
|
||||
max_buffer_size: usize,
|
||||
sending_buffers: Option<VecDeque<SinkItem>>,
|
||||
send_task:
|
||||
Option<Pin<Box<dyn Future<Output = Result<(), WE>> + Send + Sync + 'static>>>,
|
||||
close_task:
|
||||
Option<Pin<Box<dyn Future<Output = Result<(), WE>> + Send + Sync + 'static>>>,
|
||||
}
|
||||
|
||||
impl<W, WE> SendSink<W, WE>
|
||||
where
|
||||
W: Sink<SinkItem, Error = WE> + Send + Sync + Unpin + 'static,
|
||||
WE: std::error::Error + std::fmt::Debug + Send + Sync + From<Elapsed>,
|
||||
{
|
||||
fn try_send_buffser(
|
||||
&mut self,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<std::result::Result<(), WE>> {
|
||||
if self.send_task.is_none() {
|
||||
let mut buffers = self.sending_buffers.take().unwrap();
|
||||
let tun = self.write.clone();
|
||||
let send_task = async move {
|
||||
if buffers.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let mut locked_tun = tun.lock_owned().await;
|
||||
while let Some(buf) = buffers.front() {
|
||||
log::trace!(
|
||||
"try_send buffer, len: {:?}, buf: {:?}",
|
||||
buffers.len(),
|
||||
&buf
|
||||
);
|
||||
let timeout_task = tokio::time::timeout(
|
||||
std::time::Duration::from_secs(1),
|
||||
locked_tun.send(buf.clone()),
|
||||
);
|
||||
let send_res = timeout_task.await;
|
||||
let Ok(send_res) = send_res else {
|
||||
// panic!("send timeout");
|
||||
let err = send_res.err().unwrap();
|
||||
return Err(err.into());
|
||||
};
|
||||
let Ok(_) = send_res else {
|
||||
let err = send_res.err().unwrap();
|
||||
println!("send error: {:?}", err);
|
||||
return Err(err);
|
||||
};
|
||||
buffers.pop_front();
|
||||
}
|
||||
return Ok(());
|
||||
};
|
||||
self.send_task = Some(Box::pin(send_task));
|
||||
}
|
||||
|
||||
let ret = ready!(self.send_task.as_mut().unwrap().poll_unpin(cx));
|
||||
self.send_task = None;
|
||||
self.sending_buffers = Some(VecDeque::new());
|
||||
return Poll::Ready(ret);
|
||||
}
|
||||
}
|
||||
|
||||
impl<W, WE> Sink<SinkItem> for SendSink<W, WE>
|
||||
where
|
||||
W: Sink<SinkItem, Error = WE> + Send + Sync + Unpin + 'static,
|
||||
WE: std::error::Error + std::fmt::Debug + Send + Sync + From<Elapsed>,
|
||||
{
|
||||
type Error = SinkError;
|
||||
|
||||
fn poll_ready(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> Poll<Result<(), Self::Error>> {
|
||||
let self_mut = self.get_mut();
|
||||
let sending_buf = self_mut.sending_buffers.as_ref();
|
||||
// if sending_buffers is None, must already be doing flush
|
||||
if sending_buf.is_none() || sending_buf.unwrap().len() > self_mut.max_buffer_size {
|
||||
return self_mut.poll_flush_unpin(cx);
|
||||
} else {
|
||||
return Poll::Ready(Ok(()));
|
||||
}
|
||||
}
|
||||
|
||||
fn start_send(self: Pin<&mut Self>, item: SinkItem) -> Result<(), Self::Error> {
|
||||
assert!(self.send_task.is_none());
|
||||
let self_mut = self.get_mut();
|
||||
self_mut.sending_buffers.as_mut().unwrap().push_back(item);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn poll_flush(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> Poll<Result<(), Self::Error>> {
|
||||
let self_mut = self.get_mut();
|
||||
let ret = self_mut.try_send_buffser(cx);
|
||||
match ret {
|
||||
Poll::Ready(Ok(())) => Poll::Ready(Ok(())),
|
||||
Poll::Ready(Err(e)) => Poll::Ready(Err(SinkError::CommonError(e.to_string()))),
|
||||
Poll::Pending => {
|
||||
return Poll::Pending;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_close(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> Poll<Result<(), Self::Error>> {
|
||||
let self_mut = self.get_mut();
|
||||
if self_mut.close_task.is_none() {
|
||||
let tun = self_mut.write.clone();
|
||||
let close_task = async move {
|
||||
let mut locked_tun = tun.lock_owned().await;
|
||||
return locked_tun.close().await;
|
||||
};
|
||||
self_mut.close_task = Some(Box::pin(close_task));
|
||||
}
|
||||
|
||||
let ret = ready!(self_mut.close_task.as_mut().unwrap().poll_unpin(cx));
|
||||
self_mut.close_task = None;
|
||||
|
||||
if ret.is_err() {
|
||||
return Poll::Ready(Err(SinkError::CommonError(
|
||||
ret.err().unwrap().to_string(),
|
||||
)));
|
||||
} else {
|
||||
return Poll::Ready(Ok(()));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
SendSink {
|
||||
write: self.write.clone(),
|
||||
max_buffer_size: 1000,
|
||||
sending_buffers: Some(VecDeque::new()),
|
||||
send_task: None,
|
||||
close_task: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<R, RE, W, WE> Tunnel for FramedTunnel<R, W>
|
||||
where
|
||||
R: Stream<Item = Result<StreamT, RE>> + Send + Sync + Unpin + 'static,
|
||||
W: Sink<SinkItem, Error = WE> + Send + Sync + Unpin + 'static,
|
||||
RE: std::error::Error + std::fmt::Debug + Send + Sync + 'static,
|
||||
WE: std::error::Error + std::fmt::Debug + Send + Sync + 'static + From<Elapsed>,
|
||||
{
|
||||
fn stream(&self) -> Box<dyn DatagramStream> {
|
||||
Box::new(self.recv_stream())
|
||||
}
|
||||
|
||||
fn sink(&self) -> Box<dyn DatagramSink> {
|
||||
Box::new(self.send_sink())
|
||||
}
|
||||
|
||||
fn info(&self) -> Option<TunnelInfo> {
|
||||
if self.info.is_none() {
|
||||
None
|
||||
} else {
|
||||
Some(self.info.clone().unwrap())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct TunnelWithCustomInfo {
|
||||
tunnel: Box<dyn Tunnel>,
|
||||
info: TunnelInfo,
|
||||
}
|
||||
|
||||
impl TunnelWithCustomInfo {
|
||||
pub fn new(tunnel: Box<dyn Tunnel>, info: TunnelInfo) -> Self {
|
||||
TunnelWithCustomInfo { tunnel, info }
|
||||
}
|
||||
}
|
||||
|
||||
impl Tunnel for TunnelWithCustomInfo {
|
||||
fn stream(&self) -> Box<dyn DatagramStream> {
|
||||
self.tunnel.stream()
|
||||
}
|
||||
|
||||
fn sink(&self) -> Box<dyn DatagramSink> {
|
||||
self.tunnel.sink()
|
||||
}
|
||||
|
||||
fn info(&self) -> Option<TunnelInfo> {
|
||||
Some(self.info.clone())
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn get_interface_name_by_ip(local_ip: &IpAddr) -> Option<String> {
|
||||
let ifaces = pnet::datalink::interfaces();
|
||||
for iface in ifaces {
|
||||
for ip in iface.ips {
|
||||
if ip.ip() == *local_ip {
|
||||
return Some(iface.name);
|
||||
}
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
pub mod tests {
|
||||
use std::time::Instant;
|
||||
|
||||
use futures::SinkExt;
|
||||
use tokio_stream::StreamExt;
|
||||
use tokio_util::bytes::{BufMut, Bytes, BytesMut};
|
||||
|
||||
use crate::{
|
||||
common::netns::NetNS,
|
||||
tunnels::{close_tunnel, TunnelConnector, TunnelListener},
|
||||
};
|
||||
|
||||
pub async fn _tunnel_echo_server(tunnel: Box<dyn super::Tunnel>, once: bool) {
|
||||
let mut recv = Box::into_pin(tunnel.stream());
|
||||
let mut send = Box::into_pin(tunnel.sink());
|
||||
|
||||
while let Some(ret) = recv.next().await {
|
||||
if ret.is_err() {
|
||||
log::trace!("recv error: {:?}", ret.err().unwrap());
|
||||
break;
|
||||
}
|
||||
let res = ret.unwrap();
|
||||
log::trace!("recv a msg, try echo back: {:?}", res);
|
||||
send.send(Bytes::from(res)).await.unwrap();
|
||||
if once {
|
||||
break;
|
||||
}
|
||||
}
|
||||
log::warn!("echo server exit...");
|
||||
}
|
||||
|
||||
pub(crate) async fn _tunnel_pingpong<L, C>(listener: L, connector: C)
|
||||
where
|
||||
L: TunnelListener + Send + Sync + 'static,
|
||||
C: TunnelConnector + Send + Sync + 'static,
|
||||
{
|
||||
_tunnel_pingpong_netns(listener, connector, NetNS::new(None), NetNS::new(None)).await
|
||||
}
|
||||
|
||||
pub(crate) async fn _tunnel_pingpong_netns<L, C>(
|
||||
mut listener: L,
|
||||
mut connector: C,
|
||||
l_netns: NetNS,
|
||||
c_netns: NetNS,
|
||||
) where
|
||||
L: TunnelListener + Send + Sync + 'static,
|
||||
C: TunnelConnector + Send + Sync + 'static,
|
||||
{
|
||||
l_netns
|
||||
.run_async(|| async {
|
||||
listener.listen().await.unwrap();
|
||||
})
|
||||
.await;
|
||||
|
||||
let lis = tokio::spawn(async move {
|
||||
let ret = listener.accept().await.unwrap();
|
||||
assert_eq!(
|
||||
ret.info().unwrap().local_addr,
|
||||
listener.local_url().to_string()
|
||||
);
|
||||
_tunnel_echo_server(ret, false).await
|
||||
});
|
||||
|
||||
let tunnel = c_netns.run_async(|| connector.connect()).await.unwrap();
|
||||
|
||||
assert_eq!(
|
||||
tunnel.info().unwrap().remote_addr,
|
||||
connector.remote_url().to_string()
|
||||
);
|
||||
|
||||
let mut send = tunnel.pin_sink();
|
||||
let mut recv = tunnel.pin_stream();
|
||||
let send_data = Bytes::from("abc");
|
||||
send.send(send_data).await.unwrap();
|
||||
let ret = tokio::time::timeout(tokio::time::Duration::from_secs(1), recv.next())
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
println!("echo back: {:?}", ret);
|
||||
assert_eq!(ret, Bytes::from("abc"));
|
||||
|
||||
close_tunnel(&tunnel).await.unwrap();
|
||||
|
||||
if connector.remote_url().scheme() == "udp" {
|
||||
lis.abort();
|
||||
} else {
|
||||
// lis should finish in 1 second
|
||||
let ret = tokio::time::timeout(tokio::time::Duration::from_secs(1), lis).await;
|
||||
assert!(ret.is_ok());
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn _tunnel_bench<L, C>(mut listener: L, mut connector: C)
|
||||
where
|
||||
L: TunnelListener + Send + Sync + 'static,
|
||||
C: TunnelConnector + Send + Sync + 'static,
|
||||
{
|
||||
listener.listen().await.unwrap();
|
||||
|
||||
let lis = tokio::spawn(async move {
|
||||
let ret = listener.accept().await.unwrap();
|
||||
_tunnel_echo_server(ret, false).await
|
||||
});
|
||||
|
||||
let tunnel = connector.connect().await.unwrap();
|
||||
|
||||
let mut send = tunnel.pin_sink();
|
||||
let mut recv = tunnel.pin_stream();
|
||||
|
||||
// prepare a 4k buffer with random data
|
||||
let mut send_buf = BytesMut::new();
|
||||
for _ in 0..64 {
|
||||
send_buf.put_i128(rand::random::<i128>());
|
||||
}
|
||||
|
||||
let now = Instant::now();
|
||||
let mut count = 0;
|
||||
while now.elapsed().as_secs() < 3 {
|
||||
send.send(send_buf.clone().freeze()).await.unwrap();
|
||||
let _ = recv.next().await.unwrap().unwrap();
|
||||
count += 1;
|
||||
}
|
||||
println!("bps: {}", (count / 1024) * 4 / now.elapsed().as_secs());
|
||||
|
||||
lis.abort();
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,159 @@
|
||||
pub mod codec;
|
||||
pub mod common;
|
||||
pub mod ring_tunnel;
|
||||
pub mod stats;
|
||||
pub mod tcp_tunnel;
|
||||
pub mod tunnel_filter;
|
||||
pub mod udp_tunnel;
|
||||
|
||||
use std::{fmt::Debug, net::SocketAddr, pin::Pin, sync::Arc};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use easytier_rpc::TunnelInfo;
|
||||
use futures::{Sink, SinkExt, Stream};
|
||||
|
||||
use thiserror::Error;
|
||||
use tokio_util::bytes::{Bytes, BytesMut};
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
pub enum TunnelError {
|
||||
#[error("Error: {0}")]
|
||||
CommonError(String),
|
||||
#[error("io error")]
|
||||
IOError(#[from] std::io::Error),
|
||||
#[error("wait resp error")]
|
||||
WaitRespError(String),
|
||||
#[error("Connect Error: {0}")]
|
||||
ConnectError(String),
|
||||
#[error("Invalid Protocol: {0}")]
|
||||
InvalidProtocol(String),
|
||||
#[error("Invalid Addr: {0}")]
|
||||
InvalidAddr(String),
|
||||
#[error("Tun Error: {0}")]
|
||||
TunError(String),
|
||||
#[error("timeout")]
|
||||
Timeout(#[from] tokio::time::error::Elapsed),
|
||||
}
|
||||
|
||||
pub type StreamT = BytesMut;
|
||||
pub type StreamItem = Result<StreamT, TunnelError>;
|
||||
pub type SinkItem = Bytes;
|
||||
pub type SinkError = TunnelError;
|
||||
|
||||
pub trait DatagramStream: Stream<Item = StreamItem> + Send + Sync {}
|
||||
impl<T> DatagramStream for T where T: Stream<Item = StreamItem> + Send + Sync {}
|
||||
pub trait DatagramSink: Sink<SinkItem, Error = SinkError> + Send + Sync {}
|
||||
impl<T> DatagramSink for T where T: Sink<SinkItem, Error = SinkError> + Send + Sync {}
|
||||
|
||||
#[auto_impl::auto_impl(Box, Arc)]
|
||||
pub trait Tunnel: Send + Sync {
|
||||
fn stream(&self) -> Box<dyn DatagramStream>;
|
||||
fn sink(&self) -> Box<dyn DatagramSink>;
|
||||
|
||||
fn pin_stream(&self) -> Pin<Box<dyn DatagramStream>> {
|
||||
Box::into_pin(self.stream())
|
||||
}
|
||||
|
||||
fn pin_sink(&self) -> Pin<Box<dyn DatagramSink>> {
|
||||
Box::into_pin(self.sink())
|
||||
}
|
||||
|
||||
fn info(&self) -> Option<TunnelInfo>;
|
||||
}
|
||||
|
||||
pub async fn close_tunnel(t: &Box<dyn Tunnel>) -> Result<(), TunnelError> {
|
||||
t.pin_sink().close().await
|
||||
}
|
||||
|
||||
#[auto_impl::auto_impl(Arc)]
|
||||
pub trait TunnelConnCounter: 'static + Send + Sync + Debug {
|
||||
fn get(&self) -> u32;
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
#[auto_impl::auto_impl(Box)]
|
||||
pub trait TunnelListener: Send + Sync {
|
||||
async fn listen(&mut self) -> Result<(), TunnelError>;
|
||||
async fn accept(&mut self) -> Result<Box<dyn Tunnel>, TunnelError>;
|
||||
fn local_url(&self) -> url::Url;
|
||||
fn get_conn_counter(&self) -> Arc<Box<dyn TunnelConnCounter>> {
|
||||
#[derive(Debug)]
|
||||
struct FakeTunnelConnCounter {}
|
||||
impl TunnelConnCounter for FakeTunnelConnCounter {
|
||||
fn get(&self) -> u32 {
|
||||
0
|
||||
}
|
||||
}
|
||||
Arc::new(Box::new(FakeTunnelConnCounter {}))
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
#[auto_impl::auto_impl(Box)]
|
||||
pub trait TunnelConnector {
|
||||
async fn connect(&mut self) -> Result<Box<dyn Tunnel>, TunnelError>;
|
||||
fn remote_url(&self) -> url::Url;
|
||||
fn set_bind_addrs(&mut self, _addrs: Vec<SocketAddr>) {}
|
||||
}
|
||||
|
||||
pub fn build_url_from_socket_addr(addr: &String, scheme: &str) -> url::Url {
|
||||
url::Url::parse(format!("{}://{}", scheme, addr).as_str()).unwrap()
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for dyn Tunnel {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("Tunnel")
|
||||
.field("info", &self.info())
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for dyn TunnelConnector + Sync + Send {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("TunnelConnector")
|
||||
.field("remote_url", &self.remote_url())
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for dyn TunnelListener {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("TunnelListener")
|
||||
.field("local_url", &self.local_url())
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) trait FromUrl {
|
||||
fn from_url(url: url::Url) -> Result<Self, TunnelError>
|
||||
where
|
||||
Self: Sized;
|
||||
}
|
||||
|
||||
pub(crate) fn check_scheme_and_get_socket_addr<T>(
|
||||
url: &url::Url,
|
||||
scheme: &str,
|
||||
) -> Result<T, TunnelError>
|
||||
where
|
||||
T: FromUrl,
|
||||
{
|
||||
if url.scheme() != scheme {
|
||||
return Err(TunnelError::InvalidProtocol(url.scheme().to_string()));
|
||||
}
|
||||
|
||||
Ok(T::from_url(url.clone())?)
|
||||
}
|
||||
|
||||
impl FromUrl for SocketAddr {
|
||||
fn from_url(url: url::Url) -> Result<Self, TunnelError> {
|
||||
Ok(url.socket_addrs(|| None)?.pop().unwrap())
|
||||
}
|
||||
}
|
||||
|
||||
impl FromUrl for uuid::Uuid {
|
||||
fn from_url(url: url::Url) -> Result<Self, TunnelError> {
|
||||
let o = url.host_str().unwrap();
|
||||
let o = uuid::Uuid::parse_str(o).map_err(|e| TunnelError::InvalidAddr(e.to_string()))?;
|
||||
Ok(o)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,391 @@
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
sync::{atomic::AtomicBool, Arc},
|
||||
task::Poll,
|
||||
};
|
||||
|
||||
use async_stream::stream;
|
||||
use crossbeam_queue::ArrayQueue;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use futures::Sink;
|
||||
use once_cell::sync::Lazy;
|
||||
use tokio::sync::{Mutex, Notify};
|
||||
|
||||
use futures::FutureExt;
|
||||
use tokio_util::bytes::BytesMut;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::tunnels::{SinkError, SinkItem};
|
||||
|
||||
use super::{
|
||||
build_url_from_socket_addr, check_scheme_and_get_socket_addr, DatagramSink, DatagramStream,
|
||||
Tunnel, TunnelConnector, TunnelError, TunnelInfo, TunnelListener,
|
||||
};
|
||||
|
||||
static RING_TUNNEL_CAP: usize = 1000;
|
||||
|
||||
pub struct RingTunnel {
|
||||
id: Uuid,
|
||||
ring: Arc<ArrayQueue<SinkItem>>,
|
||||
consume_notify: Arc<Notify>,
|
||||
produce_notify: Arc<Notify>,
|
||||
closed: Arc<AtomicBool>,
|
||||
}
|
||||
|
||||
impl RingTunnel {
|
||||
pub fn new(cap: usize) -> Self {
|
||||
RingTunnel {
|
||||
id: Uuid::new_v4(),
|
||||
ring: Arc::new(ArrayQueue::new(cap)),
|
||||
consume_notify: Arc::new(Notify::new()),
|
||||
produce_notify: Arc::new(Notify::new()),
|
||||
closed: Arc::new(AtomicBool::new(false)),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn new_with_id(id: Uuid, cap: usize) -> Self {
|
||||
let mut ret = Self::new(cap);
|
||||
ret.id = id;
|
||||
ret
|
||||
}
|
||||
|
||||
fn recv_stream(&self) -> impl DatagramStream {
|
||||
let ring = self.ring.clone();
|
||||
let produce_notify = self.produce_notify.clone();
|
||||
let consume_notify = self.consume_notify.clone();
|
||||
let closed = self.closed.clone();
|
||||
let id = self.id;
|
||||
stream! {
|
||||
loop {
|
||||
if closed.load(std::sync::atomic::Ordering::Relaxed) {
|
||||
log::warn!("ring recv tunnel {:?} closed", id);
|
||||
yield Err(TunnelError::CommonError("Closed".to_owned()));
|
||||
}
|
||||
match ring.pop() {
|
||||
Some(v) => {
|
||||
let mut out = BytesMut::new();
|
||||
out.extend_from_slice(&v);
|
||||
consume_notify.notify_one();
|
||||
log::trace!("id: {}, recv buffer, len: {:?}, buf: {:?}", id, v.len(), &v);
|
||||
yield Ok(out);
|
||||
},
|
||||
None => {
|
||||
log::trace!("waiting recv buffer, id: {}", id);
|
||||
produce_notify.notified().await;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn send_sink(&self) -> impl DatagramSink {
|
||||
let ring = self.ring.clone();
|
||||
let produce_notify = self.produce_notify.clone();
|
||||
let consume_notify = self.consume_notify.clone();
|
||||
let closed = self.closed.clone();
|
||||
let id = self.id;
|
||||
|
||||
// type T = RingTunnel;
|
||||
|
||||
use tokio::task::JoinHandle;
|
||||
|
||||
struct T {
|
||||
ring: RingTunnel,
|
||||
wait_consume_task: Option<JoinHandle<()>>,
|
||||
}
|
||||
|
||||
impl T {
|
||||
fn wait_ring_consume(
|
||||
self: std::pin::Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
expected_size: usize,
|
||||
) -> std::task::Poll<()> {
|
||||
let self_mut = self.get_mut();
|
||||
if self_mut.ring.ring.len() <= expected_size {
|
||||
return Poll::Ready(());
|
||||
}
|
||||
if self_mut.wait_consume_task.is_none() {
|
||||
let id = self_mut.ring.id;
|
||||
let consume_notify = self_mut.ring.consume_notify.clone();
|
||||
let ring = self_mut.ring.ring.clone();
|
||||
let task = async move {
|
||||
log::trace!(
|
||||
"waiting ring consume done, expected_size: {}, id: {}",
|
||||
expected_size,
|
||||
id
|
||||
);
|
||||
while ring.len() > expected_size {
|
||||
consume_notify.notified().await;
|
||||
}
|
||||
log::trace!(
|
||||
"ring consume done, expected_size: {}, id: {}",
|
||||
expected_size,
|
||||
id
|
||||
);
|
||||
};
|
||||
self_mut.wait_consume_task = Some(tokio::spawn(task));
|
||||
}
|
||||
let task = self_mut.wait_consume_task.as_mut().unwrap();
|
||||
match task.poll_unpin(cx) {
|
||||
Poll::Ready(_) => {
|
||||
self_mut.wait_consume_task = None;
|
||||
Poll::Ready(())
|
||||
}
|
||||
Poll::Pending => Poll::Pending,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Sink<SinkItem> for T {
|
||||
type Error = SinkError;
|
||||
|
||||
fn poll_ready(
|
||||
self: std::pin::Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> std::task::Poll<Result<(), Self::Error>> {
|
||||
let expected_size = self.ring.ring.capacity() - 1;
|
||||
match self.wait_ring_consume(cx, expected_size) {
|
||||
Poll::Ready(_) => Poll::Ready(Ok(())),
|
||||
Poll::Pending => Poll::Pending,
|
||||
}
|
||||
}
|
||||
|
||||
fn start_send(
|
||||
self: std::pin::Pin<&mut Self>,
|
||||
item: SinkItem,
|
||||
) -> Result<(), Self::Error> {
|
||||
log::trace!("id: {}, send buffer, buf: {:?}", self.ring.id, &item);
|
||||
self.ring.ring.push(item).unwrap();
|
||||
self.ring.produce_notify.notify_one();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn poll_flush(
|
||||
self: std::pin::Pin<&mut Self>,
|
||||
_cx: &mut std::task::Context<'_>,
|
||||
) -> std::task::Poll<Result<(), Self::Error>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
fn poll_close(
|
||||
self: std::pin::Pin<&mut Self>,
|
||||
_cx: &mut std::task::Context<'_>,
|
||||
) -> std::task::Poll<Result<(), Self::Error>> {
|
||||
self.ring
|
||||
.closed
|
||||
.store(true, std::sync::atomic::Ordering::Relaxed);
|
||||
log::warn!("ring tunnel send {:?} closed", self.ring.id);
|
||||
self.ring.produce_notify.notify_one();
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
}
|
||||
|
||||
T {
|
||||
ring: RingTunnel {
|
||||
id,
|
||||
ring,
|
||||
consume_notify,
|
||||
produce_notify,
|
||||
closed,
|
||||
},
|
||||
wait_consume_task: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct Connection {
|
||||
client: RingTunnel,
|
||||
server: RingTunnel,
|
||||
connect_notify: Arc<Notify>,
|
||||
}
|
||||
|
||||
impl Tunnel for RingTunnel {
|
||||
fn stream(&self) -> Box<dyn DatagramStream> {
|
||||
Box::new(self.recv_stream())
|
||||
}
|
||||
|
||||
fn sink(&self) -> Box<dyn DatagramSink> {
|
||||
Box::new(self.send_sink())
|
||||
}
|
||||
|
||||
fn info(&self) -> Option<TunnelInfo> {
|
||||
None
|
||||
// Some(TunnelInfo {
|
||||
// tunnel_type: "ring".to_owned(),
|
||||
// local_addr: format!("ring://{}", self.id),
|
||||
// remote_addr: format!("ring://{}", self.id),
|
||||
// })
|
||||
}
|
||||
}
|
||||
|
||||
static CONNECTION_MAP: Lazy<Arc<Mutex<HashMap<uuid::Uuid, Arc<Connection>>>>> =
|
||||
Lazy::new(|| Arc::new(Mutex::new(HashMap::new())));
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct RingTunnelListener {
|
||||
listerner_addr: url::Url,
|
||||
}
|
||||
|
||||
impl RingTunnelListener {
|
||||
pub fn new(key: url::Url) -> Self {
|
||||
RingTunnelListener {
|
||||
listerner_addr: key,
|
||||
}
|
||||
}
|
||||
}
|
||||
struct ConnectionForServer {
|
||||
conn: Arc<Connection>,
|
||||
}
|
||||
|
||||
impl Tunnel for ConnectionForServer {
|
||||
fn stream(&self) -> Box<dyn DatagramStream> {
|
||||
Box::new(self.conn.server.recv_stream())
|
||||
}
|
||||
|
||||
fn sink(&self) -> Box<dyn DatagramSink> {
|
||||
Box::new(self.conn.client.send_sink())
|
||||
}
|
||||
|
||||
fn info(&self) -> Option<TunnelInfo> {
|
||||
Some(TunnelInfo {
|
||||
tunnel_type: "ring".to_owned(),
|
||||
local_addr: build_url_from_socket_addr(&self.conn.server.id.into(), "ring").into(),
|
||||
remote_addr: build_url_from_socket_addr(&self.conn.client.id.into(), "ring").into(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
struct ConnectionForClient {
|
||||
conn: Arc<Connection>,
|
||||
}
|
||||
|
||||
impl Tunnel for ConnectionForClient {
|
||||
fn stream(&self) -> Box<dyn DatagramStream> {
|
||||
Box::new(self.conn.client.recv_stream())
|
||||
}
|
||||
|
||||
fn sink(&self) -> Box<dyn DatagramSink> {
|
||||
Box::new(self.conn.server.send_sink())
|
||||
}
|
||||
|
||||
fn info(&self) -> Option<TunnelInfo> {
|
||||
Some(TunnelInfo {
|
||||
tunnel_type: "ring".to_owned(),
|
||||
local_addr: build_url_from_socket_addr(&self.conn.client.id.into(), "ring").into(),
|
||||
remote_addr: build_url_from_socket_addr(&self.conn.server.id.into(), "ring").into(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl RingTunnelListener {
|
||||
async fn add_connection(listener_addr: uuid::Uuid) {
|
||||
CONNECTION_MAP.lock().await.insert(
|
||||
listener_addr.clone(),
|
||||
Arc::new(Connection {
|
||||
client: RingTunnel::new(RING_TUNNEL_CAP),
|
||||
server: RingTunnel::new_with_id(listener_addr.clone(), RING_TUNNEL_CAP),
|
||||
connect_notify: Arc::new(Notify::new()),
|
||||
}),
|
||||
);
|
||||
}
|
||||
|
||||
fn get_addr(&self) -> Result<uuid::Uuid, TunnelError> {
|
||||
check_scheme_and_get_socket_addr::<Uuid>(&self.listerner_addr, "ring")
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl TunnelListener for RingTunnelListener {
|
||||
async fn listen(&mut self) -> Result<(), TunnelError> {
|
||||
log::info!("listen new conn of key: {}", self.listerner_addr);
|
||||
Self::add_connection(self.get_addr()?).await;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn accept(&mut self) -> Result<Box<dyn Tunnel>, TunnelError> {
|
||||
log::info!("waiting accept new conn of key: {}", self.listerner_addr);
|
||||
let val = CONNECTION_MAP
|
||||
.lock()
|
||||
.await
|
||||
.get(&self.get_addr()?)
|
||||
.unwrap()
|
||||
.clone();
|
||||
val.connect_notify.notified().await;
|
||||
CONNECTION_MAP.lock().await.remove(&self.get_addr()?);
|
||||
Self::add_connection(self.get_addr()?).await;
|
||||
log::info!("accept new conn of key: {}", self.listerner_addr);
|
||||
Ok(Box::new(ConnectionForServer { conn: val }))
|
||||
}
|
||||
|
||||
fn local_url(&self) -> url::Url {
|
||||
self.listerner_addr.clone()
|
||||
}
|
||||
}
|
||||
|
||||
pub struct RingTunnelConnector {
|
||||
remote_addr: url::Url,
|
||||
}
|
||||
|
||||
impl RingTunnelConnector {
|
||||
pub fn new(remote_addr: url::Url) -> Self {
|
||||
RingTunnelConnector { remote_addr }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl TunnelConnector for RingTunnelConnector {
|
||||
async fn connect(&mut self) -> Result<Box<dyn Tunnel>, super::TunnelError> {
|
||||
let val = CONNECTION_MAP
|
||||
.lock()
|
||||
.await
|
||||
.get(&check_scheme_and_get_socket_addr::<Uuid>(
|
||||
&self.remote_addr,
|
||||
"ring",
|
||||
)?)
|
||||
.unwrap()
|
||||
.clone();
|
||||
val.connect_notify.notify_one();
|
||||
log::info!("connecting");
|
||||
Ok(Box::new(ConnectionForClient { conn: val }))
|
||||
}
|
||||
|
||||
fn remote_url(&self) -> url::Url {
|
||||
self.remote_addr.clone()
|
||||
}
|
||||
}
|
||||
|
||||
pub fn create_ring_tunnel_pair() -> (Box<dyn Tunnel>, Box<dyn Tunnel>) {
|
||||
let conn = Arc::new(Connection {
|
||||
client: RingTunnel::new(RING_TUNNEL_CAP),
|
||||
server: RingTunnel::new(RING_TUNNEL_CAP),
|
||||
connect_notify: Arc::new(Notify::new()),
|
||||
});
|
||||
(
|
||||
Box::new(ConnectionForServer { conn: conn.clone() }),
|
||||
Box::new(ConnectionForClient { conn: conn }),
|
||||
)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::tunnels::common::tests::{_tunnel_bench, _tunnel_pingpong};
|
||||
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn ring_pingpong() {
|
||||
let id: url::Url = format!("ring://{}", Uuid::new_v4()).parse().unwrap();
|
||||
let listener = RingTunnelListener::new(id.clone());
|
||||
let connector = RingTunnelConnector::new(id.clone());
|
||||
_tunnel_pingpong(listener, connector).await
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn ring_bench() {
|
||||
let id: url::Url = format!("ring://{}", Uuid::new_v4()).parse().unwrap();
|
||||
let listener = RingTunnelListener::new(id.clone());
|
||||
let connector = RingTunnelConnector::new(id);
|
||||
_tunnel_bench(listener, connector).await
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,101 @@
|
||||
use std::sync::atomic::{AtomicU32, AtomicU64};
|
||||
|
||||
pub struct WindowLatency {
|
||||
latency_us_window: Vec<AtomicU64>,
|
||||
latency_us_window_index: AtomicU32,
|
||||
latency_us_window_size: AtomicU32,
|
||||
}
|
||||
|
||||
impl WindowLatency {
|
||||
pub fn new(window_size: u32) -> Self {
|
||||
Self {
|
||||
latency_us_window: (0..window_size).map(|_| AtomicU64::new(0)).collect(),
|
||||
latency_us_window_index: AtomicU32::new(0),
|
||||
latency_us_window_size: AtomicU32::new(window_size),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn record_latency(&self, latency_us: u64) {
|
||||
let index = self
|
||||
.latency_us_window_index
|
||||
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
|
||||
let index = index
|
||||
% self
|
||||
.latency_us_window_size
|
||||
.load(std::sync::atomic::Ordering::Relaxed);
|
||||
self.latency_us_window[index as usize]
|
||||
.store(latency_us, std::sync::atomic::Ordering::Relaxed);
|
||||
}
|
||||
|
||||
pub fn get_latency_us(&self) -> u64 {
|
||||
let window_size = self
|
||||
.latency_us_window_size
|
||||
.load(std::sync::atomic::Ordering::Relaxed);
|
||||
let mut sum = 0;
|
||||
let mut count = 0;
|
||||
for i in 0..window_size {
|
||||
let latency_us =
|
||||
self.latency_us_window[i as usize].load(std::sync::atomic::Ordering::Relaxed);
|
||||
if latency_us > 0 {
|
||||
sum += latency_us;
|
||||
count += 1;
|
||||
}
|
||||
}
|
||||
|
||||
if count == 0 {
|
||||
0
|
||||
} else {
|
||||
sum / count
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Throughput {
|
||||
tx_bytes: AtomicU64,
|
||||
rx_bytes: AtomicU64,
|
||||
|
||||
tx_packets: AtomicU64,
|
||||
rx_packets: AtomicU64,
|
||||
}
|
||||
|
||||
impl Throughput {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
tx_bytes: AtomicU64::new(0),
|
||||
rx_bytes: AtomicU64::new(0),
|
||||
|
||||
tx_packets: AtomicU64::new(0),
|
||||
rx_packets: AtomicU64::new(0),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn tx_bytes(&self) -> u64 {
|
||||
self.tx_bytes.load(std::sync::atomic::Ordering::Relaxed)
|
||||
}
|
||||
|
||||
pub fn rx_bytes(&self) -> u64 {
|
||||
self.rx_bytes.load(std::sync::atomic::Ordering::Relaxed)
|
||||
}
|
||||
|
||||
pub fn tx_packets(&self) -> u64 {
|
||||
self.tx_packets.load(std::sync::atomic::Ordering::Relaxed)
|
||||
}
|
||||
|
||||
pub fn rx_packets(&self) -> u64 {
|
||||
self.rx_packets.load(std::sync::atomic::Ordering::Relaxed)
|
||||
}
|
||||
|
||||
pub fn record_tx_bytes(&self, bytes: u64) {
|
||||
self.tx_bytes
|
||||
.fetch_add(bytes, std::sync::atomic::Ordering::Relaxed);
|
||||
self.tx_packets
|
||||
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
|
||||
}
|
||||
|
||||
pub fn record_rx_bytes(&self, bytes: u64) {
|
||||
self.rx_bytes
|
||||
.fetch_add(bytes, std::sync::atomic::Ordering::Relaxed);
|
||||
self.rx_packets
|
||||
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,284 @@
|
||||
use std::net::SocketAddr;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use futures::{stream::FuturesUnordered, StreamExt};
|
||||
use tokio::net::{TcpListener, TcpSocket, TcpStream};
|
||||
use tokio_util::codec::{FramedRead, FramedWrite, LengthDelimitedCodec};
|
||||
|
||||
use super::{
|
||||
check_scheme_and_get_socket_addr, common::FramedTunnel, Tunnel, TunnelInfo, TunnelListener,
|
||||
};
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct TcpTunnelListener {
|
||||
addr: url::Url,
|
||||
listener: Option<TcpListener>,
|
||||
}
|
||||
|
||||
impl TcpTunnelListener {
|
||||
pub fn new(addr: url::Url) -> Self {
|
||||
TcpTunnelListener {
|
||||
addr,
|
||||
listener: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl TunnelListener for TcpTunnelListener {
|
||||
async fn listen(&mut self) -> Result<(), super::TunnelError> {
|
||||
let addr = check_scheme_and_get_socket_addr::<SocketAddr>(&self.addr, "tcp")?;
|
||||
|
||||
let socket = if addr.is_ipv4() {
|
||||
TcpSocket::new_v4()?
|
||||
} else {
|
||||
TcpSocket::new_v6()?
|
||||
};
|
||||
|
||||
socket.set_reuseaddr(true)?;
|
||||
#[cfg(all(unix, not(target_os = "solaris"), not(target_os = "illumos")))]
|
||||
socket.set_reuseport(true)?;
|
||||
socket.bind(addr)?;
|
||||
|
||||
self.listener = Some(socket.listen(1024)?);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn accept(&mut self) -> Result<Box<dyn Tunnel>, super::TunnelError> {
|
||||
let listener = self.listener.as_ref().unwrap();
|
||||
let (stream, _) = listener.accept().await?;
|
||||
stream.set_nodelay(true).unwrap();
|
||||
let info = TunnelInfo {
|
||||
tunnel_type: "tcp".to_owned(),
|
||||
local_addr: self.local_url().into(),
|
||||
remote_addr: super::build_url_from_socket_addr(&stream.peer_addr()?.to_string(), "tcp")
|
||||
.into(),
|
||||
};
|
||||
|
||||
let (r, w) = tokio::io::split(stream);
|
||||
Ok(FramedTunnel::new_tunnel_with_info(
|
||||
FramedRead::new(r, LengthDelimitedCodec::new()),
|
||||
FramedWrite::new(w, LengthDelimitedCodec::new()),
|
||||
info,
|
||||
))
|
||||
}
|
||||
|
||||
fn local_url(&self) -> url::Url {
|
||||
self.addr.clone()
|
||||
}
|
||||
}
|
||||
|
||||
fn get_tunnel_with_tcp_stream(
|
||||
stream: TcpStream,
|
||||
remote_url: url::Url,
|
||||
) -> Result<Box<dyn Tunnel>, super::TunnelError> {
|
||||
stream.set_nodelay(true).unwrap();
|
||||
|
||||
let info = TunnelInfo {
|
||||
tunnel_type: "tcp".to_owned(),
|
||||
local_addr: super::build_url_from_socket_addr(&stream.local_addr()?.to_string(), "tcp")
|
||||
.into(),
|
||||
remote_addr: remote_url.into(),
|
||||
};
|
||||
|
||||
let (r, w) = tokio::io::split(stream);
|
||||
Ok(Box::new(FramedTunnel::new_tunnel_with_info(
|
||||
FramedRead::new(r, LengthDelimitedCodec::new()),
|
||||
FramedWrite::new(w, LengthDelimitedCodec::new()),
|
||||
info,
|
||||
)))
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct TcpTunnelConnector {
|
||||
addr: url::Url,
|
||||
|
||||
bind_addrs: Vec<SocketAddr>,
|
||||
}
|
||||
|
||||
impl TcpTunnelConnector {
|
||||
pub fn new(addr: url::Url) -> Self {
|
||||
TcpTunnelConnector {
|
||||
addr,
|
||||
bind_addrs: vec![],
|
||||
}
|
||||
}
|
||||
|
||||
async fn connect_with_default_bind(&mut self) -> Result<Box<dyn Tunnel>, super::TunnelError> {
|
||||
tracing::info!(addr = ?self.addr, "connect tcp start");
|
||||
let addr = check_scheme_and_get_socket_addr::<SocketAddr>(&self.addr, "tcp")?;
|
||||
let stream = TcpStream::connect(addr).await?;
|
||||
tracing::info!(addr = ?self.addr, "connect tcp succ");
|
||||
return get_tunnel_with_tcp_stream(stream, self.addr.clone().into());
|
||||
}
|
||||
|
||||
async fn connect_with_custom_bind(
|
||||
&mut self,
|
||||
is_ipv4: bool,
|
||||
) -> Result<Box<dyn Tunnel>, super::TunnelError> {
|
||||
let mut futures = FuturesUnordered::new();
|
||||
let dst_addr = check_scheme_and_get_socket_addr::<SocketAddr>(&self.addr, "tcp")?;
|
||||
|
||||
for bind_addr in self.bind_addrs.iter() {
|
||||
let socket = if is_ipv4 {
|
||||
TcpSocket::new_v4()?
|
||||
} else {
|
||||
TcpSocket::new_v6()?
|
||||
};
|
||||
socket.set_reuseaddr(true)?;
|
||||
|
||||
#[cfg(all(unix, not(target_os = "solaris"), not(target_os = "illumos")))]
|
||||
socket.set_reuseport(true)?;
|
||||
|
||||
socket.bind(*bind_addr)?;
|
||||
// linux does not use interface of bind_addr to send packet, so we need to bind device
|
||||
// mac can handle this with bind correctly
|
||||
#[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
|
||||
if let Some(dev_name) = super::common::get_interface_name_by_ip(&bind_addr.ip()) {
|
||||
tracing::trace!(dev_name = ?dev_name, "bind device");
|
||||
socket.bind_device(Some(dev_name.as_bytes()))?;
|
||||
}
|
||||
futures.push(socket.connect(dst_addr.clone()));
|
||||
}
|
||||
|
||||
let Some(ret) = futures.next().await else {
|
||||
return Err(super::TunnelError::CommonError(
|
||||
"join connect futures failed".to_owned(),
|
||||
));
|
||||
};
|
||||
|
||||
return get_tunnel_with_tcp_stream(ret?, self.addr.clone().into());
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl super::TunnelConnector for TcpTunnelConnector {
|
||||
async fn connect(&mut self) -> Result<Box<dyn Tunnel>, super::TunnelError> {
|
||||
if self.bind_addrs.is_empty() {
|
||||
self.connect_with_default_bind().await
|
||||
} else if self.bind_addrs[0].is_ipv4() {
|
||||
self.connect_with_custom_bind(true).await
|
||||
} else {
|
||||
self.connect_with_custom_bind(false).await
|
||||
}
|
||||
}
|
||||
|
||||
fn remote_url(&self) -> url::Url {
|
||||
self.addr.clone()
|
||||
}
|
||||
fn set_bind_addrs(&mut self, addrs: Vec<SocketAddr>) {
|
||||
self.bind_addrs = addrs;
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use futures::SinkExt;
|
||||
|
||||
use crate::tunnels::{
|
||||
common::tests::{_tunnel_bench, _tunnel_pingpong},
|
||||
TunnelConnector,
|
||||
};
|
||||
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn tcp_pingpong() {
|
||||
let listener = TcpTunnelListener::new("tcp://0.0.0.0:11011".parse().unwrap());
|
||||
let connector = TcpTunnelConnector::new("tcp://127.0.0.1:11011".parse().unwrap());
|
||||
_tunnel_pingpong(listener, connector).await
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn tcp_bench() {
|
||||
let listener = TcpTunnelListener::new("tcp://0.0.0.0:11012".parse().unwrap());
|
||||
let connector = TcpTunnelConnector::new("tcp://127.0.0.1:11012".parse().unwrap());
|
||||
_tunnel_bench(listener, connector).await
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn tcp_bench_with_bind() {
|
||||
let listener = TcpTunnelListener::new("tcp://127.0.0.1:11013".parse().unwrap());
|
||||
let mut connector = TcpTunnelConnector::new("tcp://127.0.0.1:11013".parse().unwrap());
|
||||
connector.set_bind_addrs(vec!["127.0.0.1:0".parse().unwrap()]);
|
||||
_tunnel_pingpong(listener, connector).await
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[should_panic]
|
||||
async fn tcp_bench_with_bind_fail() {
|
||||
let listener = TcpTunnelListener::new("tcp://127.0.0.1:11014".parse().unwrap());
|
||||
let mut connector = TcpTunnelConnector::new("tcp://127.0.0.1:11014".parse().unwrap());
|
||||
connector.set_bind_addrs(vec!["10.0.0.1:0".parse().unwrap()]);
|
||||
_tunnel_pingpong(listener, connector).await
|
||||
}
|
||||
|
||||
// test slow send lock in framed tunnel
|
||||
#[tokio::test]
|
||||
async fn tcp_multiple_sender_and_slow_receiver() {
|
||||
// console_subscriber::init();
|
||||
let mut listener = TcpTunnelListener::new("tcp://127.0.0.1:11014".parse().unwrap());
|
||||
let mut connector = TcpTunnelConnector::new("tcp://127.0.0.1:11014".parse().unwrap());
|
||||
|
||||
listener.listen().await.unwrap();
|
||||
let t1 = tokio::spawn(async move {
|
||||
let t = listener.accept().await.unwrap();
|
||||
let mut stream = t.pin_stream();
|
||||
|
||||
let now = tokio::time::Instant::now();
|
||||
|
||||
while let Some(Ok(_)) = stream.next().await {
|
||||
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
|
||||
if now.elapsed().as_secs() > 5 {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
tracing::info!("t1 exit");
|
||||
});
|
||||
|
||||
let tunnel = connector.connect().await.unwrap();
|
||||
let mut sink1 = tunnel.pin_sink();
|
||||
let t2 = tokio::spawn(async move {
|
||||
for i in 0..1000000 {
|
||||
let a = sink1.send(b"hello".to_vec().into()).await;
|
||||
if a.is_err() {
|
||||
tracing::info!(?a, "t2 exit with err");
|
||||
break;
|
||||
}
|
||||
|
||||
if i % 5000 == 0 {
|
||||
tracing::info!(i, "send2 1000");
|
||||
}
|
||||
}
|
||||
|
||||
tracing::info!("t2 exit");
|
||||
});
|
||||
|
||||
let mut sink2 = tunnel.pin_sink();
|
||||
let t3 = tokio::spawn(async move {
|
||||
for i in 0..1000000 {
|
||||
let a = sink2.send(b"hello".to_vec().into()).await;
|
||||
if a.is_err() {
|
||||
tracing::info!(?a, "t3 exit with err");
|
||||
break;
|
||||
}
|
||||
|
||||
if i % 5000 == 0 {
|
||||
tracing::info!(i, "send2 1000");
|
||||
}
|
||||
}
|
||||
|
||||
tracing::info!("t3 exit");
|
||||
});
|
||||
|
||||
let t4 = tokio::spawn(async move {
|
||||
tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
|
||||
tracing::info!("closing");
|
||||
let close_ret = tunnel.pin_sink().close().await;
|
||||
tracing::info!("closed {:?}", close_ret);
|
||||
});
|
||||
|
||||
let _ = tokio::join!(t1, t2, t3, t4);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,228 @@
|
||||
use std::{
|
||||
sync::Arc,
|
||||
task::{Context, Poll},
|
||||
};
|
||||
|
||||
use easytier_rpc::TunnelInfo;
|
||||
use futures::{Sink, SinkExt, Stream, StreamExt};
|
||||
|
||||
use self::stats::Throughput;
|
||||
|
||||
use super::*;
|
||||
use crate::tunnels::{DatagramSink, DatagramStream, SinkError, SinkItem, StreamItem, Tunnel};
|
||||
|
||||
pub trait TunnelFilter {
|
||||
fn before_send(&self, data: SinkItem) -> Result<SinkItem, SinkError>;
|
||||
fn after_received(&self, data: StreamItem) -> Result<BytesMut, TunnelError>;
|
||||
}
|
||||
|
||||
pub struct TunnelWithFilter<T, F> {
|
||||
inner: T,
|
||||
filter: Arc<F>,
|
||||
}
|
||||
|
||||
impl<T, F> Tunnel for TunnelWithFilter<T, F>
|
||||
where
|
||||
T: Tunnel + Send + Sync + 'static,
|
||||
F: TunnelFilter + Send + Sync + 'static,
|
||||
{
|
||||
fn sink(&self) -> Box<dyn DatagramSink> {
|
||||
struct SinkWrapper<F> {
|
||||
sink: Pin<Box<dyn DatagramSink>>,
|
||||
filter: Arc<F>,
|
||||
}
|
||||
impl<F> Sink<SinkItem> for SinkWrapper<F>
|
||||
where
|
||||
F: TunnelFilter + Send + Sync + 'static,
|
||||
{
|
||||
type Error = SinkError;
|
||||
|
||||
fn poll_ready(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<Result<(), Self::Error>> {
|
||||
self.get_mut().sink.poll_ready_unpin(cx)
|
||||
}
|
||||
|
||||
fn start_send(self: Pin<&mut Self>, item: SinkItem) -> Result<(), Self::Error> {
|
||||
let item = self.filter.before_send(item)?;
|
||||
self.get_mut().sink.start_send_unpin(item)
|
||||
}
|
||||
|
||||
fn poll_flush(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<Result<(), Self::Error>> {
|
||||
self.get_mut().sink.poll_flush_unpin(cx)
|
||||
}
|
||||
|
||||
fn poll_close(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<Result<(), Self::Error>> {
|
||||
self.get_mut().sink.poll_close_unpin(cx)
|
||||
}
|
||||
}
|
||||
|
||||
Box::new(SinkWrapper {
|
||||
sink: self.inner.pin_sink(),
|
||||
filter: self.filter.clone(),
|
||||
})
|
||||
}
|
||||
|
||||
fn stream(&self) -> Box<dyn DatagramStream> {
|
||||
struct StreamWrapper<F> {
|
||||
stream: Pin<Box<dyn DatagramStream>>,
|
||||
filter: Arc<F>,
|
||||
}
|
||||
impl<F> Stream for StreamWrapper<F>
|
||||
where
|
||||
F: TunnelFilter + Send + Sync + 'static,
|
||||
{
|
||||
type Item = StreamItem;
|
||||
|
||||
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
||||
let self_mut = self.get_mut();
|
||||
match self_mut.stream.poll_next_unpin(cx) {
|
||||
Poll::Ready(Some(ret)) => {
|
||||
Poll::Ready(Some(self_mut.filter.after_received(ret)))
|
||||
}
|
||||
Poll::Ready(None) => Poll::Ready(None),
|
||||
Poll::Pending => Poll::Pending,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Box::new(StreamWrapper {
|
||||
stream: self.inner.pin_stream(),
|
||||
filter: self.filter.clone(),
|
||||
})
|
||||
}
|
||||
|
||||
fn info(&self) -> Option<TunnelInfo> {
|
||||
self.inner.info()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, F> TunnelWithFilter<T, F>
|
||||
where
|
||||
T: Tunnel + Send + Sync + 'static,
|
||||
F: TunnelFilter + Send + Sync + 'static,
|
||||
{
|
||||
pub fn new(inner: T, filter: Arc<F>) -> Self {
|
||||
Self { inner, filter }
|
||||
}
|
||||
}
|
||||
|
||||
pub struct PacketRecorderTunnelFilter {
|
||||
pub received: Arc<std::sync::Mutex<Vec<Bytes>>>,
|
||||
pub sent: Arc<std::sync::Mutex<Vec<Bytes>>>,
|
||||
}
|
||||
|
||||
impl TunnelFilter for PacketRecorderTunnelFilter {
|
||||
fn before_send(&self, data: SinkItem) -> Result<SinkItem, SinkError> {
|
||||
self.received.lock().unwrap().push(data.clone());
|
||||
Ok(data)
|
||||
}
|
||||
|
||||
fn after_received(&self, data: StreamItem) -> Result<BytesMut, TunnelError> {
|
||||
match data {
|
||||
Ok(v) => {
|
||||
self.sent.lock().unwrap().push(v.clone().into());
|
||||
Ok(v)
|
||||
}
|
||||
Err(e) => Err(e),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl PacketRecorderTunnelFilter {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
received: Arc::new(std::sync::Mutex::new(Vec::new())),
|
||||
sent: Arc::new(std::sync::Mutex::new(Vec::new())),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct StatsRecorderTunnelFilter {
|
||||
throughput: Arc<Throughput>,
|
||||
}
|
||||
|
||||
impl TunnelFilter for StatsRecorderTunnelFilter {
|
||||
fn before_send(&self, data: SinkItem) -> Result<SinkItem, SinkError> {
|
||||
self.throughput.record_tx_bytes(data.len() as u64);
|
||||
Ok(data)
|
||||
}
|
||||
|
||||
fn after_received(&self, data: StreamItem) -> Result<BytesMut, TunnelError> {
|
||||
match data {
|
||||
Ok(v) => {
|
||||
self.throughput.record_rx_bytes(v.len() as u64);
|
||||
Ok(v)
|
||||
}
|
||||
Err(e) => Err(e),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl StatsRecorderTunnelFilter {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
throughput: Arc::new(Throughput::new()),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_throughput(&self) -> Arc<Throughput> {
|
||||
self.throughput.clone()
|
||||
}
|
||||
}
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! define_tunnel_filter_chain {
|
||||
($type_name:ident $(, $field_name:ident = $filter_type:ty)+) => (
|
||||
pub struct $type_name {
|
||||
$($field_name: std::sync::Arc<$filter_type>,)+
|
||||
}
|
||||
|
||||
impl $type_name {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
$($field_name: std::sync::Arc::new(<$filter_type>::new()),)+
|
||||
}
|
||||
}
|
||||
|
||||
pub fn wrap_tunnel(&self, tunnel: impl Tunnel + 'static) -> impl Tunnel {
|
||||
$(
|
||||
let tunnel = crate::tunnels::tunnel_filter::TunnelWithFilter::new(tunnel, self.$field_name.clone());
|
||||
)+
|
||||
tunnel
|
||||
}
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::tunnels::ring_tunnel::RingTunnel;
|
||||
#[tokio::test]
|
||||
async fn test_nested_filter() {
|
||||
define_tunnel_filter_chain!(
|
||||
Filter,
|
||||
a = PacketRecorderTunnelFilter,
|
||||
b = PacketRecorderTunnelFilter,
|
||||
c = PacketRecorderTunnelFilter
|
||||
);
|
||||
|
||||
let filter = Filter::new();
|
||||
let tunnel = filter.wrap_tunnel(RingTunnel::new(1));
|
||||
|
||||
let mut s = tunnel.pin_sink();
|
||||
s.send(Bytes::from("hello")).await.unwrap();
|
||||
|
||||
assert_eq!(1, filter.a.received.lock().unwrap().len());
|
||||
assert_eq!(1, filter.b.received.lock().unwrap().len());
|
||||
assert_eq!(1, filter.c.received.lock().unwrap().len());
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,574 @@
|
||||
use std::{fmt::Debug, pin::Pin, sync::Arc};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use dashmap::DashMap;
|
||||
use easytier_rpc::TunnelInfo;
|
||||
use futures::{stream::FuturesUnordered, SinkExt, StreamExt};
|
||||
use rkyv::{Archive, Deserialize, Serialize};
|
||||
use std::net::SocketAddr;
|
||||
use tokio::{net::UdpSocket, sync::Mutex, task::JoinSet};
|
||||
use tokio_util::{
|
||||
bytes::{Buf, Bytes, BytesMut},
|
||||
udp::UdpFramed,
|
||||
};
|
||||
use tracing::Instrument;
|
||||
|
||||
use crate::{
|
||||
common::rkyv_util::{self, encode_to_bytes},
|
||||
tunnels::{build_url_from_socket_addr, close_tunnel, TunnelConnCounter, TunnelConnector},
|
||||
};
|
||||
|
||||
use super::{
|
||||
codec::BytesCodec,
|
||||
common::{FramedTunnel, TunnelWithCustomInfo},
|
||||
ring_tunnel::create_ring_tunnel_pair,
|
||||
DatagramSink, DatagramStream, Tunnel, TunnelListener,
|
||||
};
|
||||
|
||||
pub const UDP_DATA_MTU: usize = 2500;
|
||||
|
||||
#[derive(Archive, Deserialize, Serialize, Debug)]
|
||||
#[archive(compare(PartialEq), check_bytes)]
|
||||
// Derives can be passed through to the generated type:
|
||||
#[archive_attr(derive(Debug))]
|
||||
pub enum UdpPacketPayload {
|
||||
Syn,
|
||||
Sack,
|
||||
HolePunch(Vec<u8>),
|
||||
Data(Vec<u8>),
|
||||
}
|
||||
|
||||
#[derive(Archive, Deserialize, Serialize, Debug)]
|
||||
#[archive(compare(PartialEq), check_bytes)]
|
||||
#[archive_attr(derive(Debug))]
|
||||
pub struct UdpPacket {
|
||||
pub conn_id: u32,
|
||||
pub payload: UdpPacketPayload,
|
||||
}
|
||||
|
||||
impl UdpPacket {
|
||||
pub fn new_data_packet(conn_id: u32, data: Vec<u8>) -> Self {
|
||||
Self {
|
||||
conn_id,
|
||||
payload: UdpPacketPayload::Data(data),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn new_hole_punch_packet(data: Vec<u8>) -> Self {
|
||||
Self {
|
||||
conn_id: 0,
|
||||
payload: UdpPacketPayload::HolePunch(data),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn new_syn_packet(conn_id: u32) -> Self {
|
||||
Self {
|
||||
conn_id,
|
||||
payload: UdpPacketPayload::Syn,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn new_sack_packet(conn_id: u32) -> Self {
|
||||
Self {
|
||||
conn_id,
|
||||
payload: UdpPacketPayload::Sack,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn try_get_data_payload(mut buf: BytesMut, conn_id: u32) -> Option<BytesMut> {
|
||||
let Ok(udp_packet) = rkyv_util::decode_from_bytes_checked::<UdpPacket>(&buf) else {
|
||||
tracing::warn!(?buf, "udp decode error");
|
||||
return None;
|
||||
};
|
||||
|
||||
if udp_packet.conn_id != conn_id.clone() {
|
||||
tracing::warn!(?udp_packet, ?conn_id, "udp conn id not match");
|
||||
return None;
|
||||
}
|
||||
|
||||
let ArchivedUdpPacketPayload::Data(payload) = &udp_packet.payload else {
|
||||
tracing::warn!(?udp_packet, "udp payload not data");
|
||||
return None;
|
||||
};
|
||||
|
||||
let ptr_range = payload.as_ptr_range();
|
||||
let offset = ptr_range.start as usize - buf.as_ptr() as usize;
|
||||
let len = ptr_range.end as usize - ptr_range.start as usize;
|
||||
buf.advance(offset);
|
||||
buf.truncate(len);
|
||||
tracing::trace!(?offset, ?len, ?buf, "udp payload data");
|
||||
|
||||
Some(buf)
|
||||
}
|
||||
|
||||
fn get_tunnel_from_socket(
|
||||
socket: Arc<UdpSocket>,
|
||||
addr: SocketAddr,
|
||||
conn_id: u32,
|
||||
) -> Box<dyn super::Tunnel> {
|
||||
let udp = UdpFramed::new(socket.clone(), BytesCodec::new(UDP_DATA_MTU));
|
||||
let (sink, stream) = udp.split();
|
||||
|
||||
let recv_addr = addr;
|
||||
let stream = stream.filter_map(move |v| async move {
|
||||
tracing::trace!(?v, "udp stream recv something");
|
||||
if v.is_err() {
|
||||
tracing::warn!(?v, "udp stream error");
|
||||
return Some(Err(super::TunnelError::CommonError(
|
||||
"udp stream error".to_owned(),
|
||||
)));
|
||||
}
|
||||
|
||||
let (buf, addr) = v.unwrap();
|
||||
assert_eq!(addr, recv_addr.clone());
|
||||
Some(Ok(try_get_data_payload(buf, conn_id.clone())?))
|
||||
});
|
||||
let stream = Box::pin(stream);
|
||||
|
||||
let sender_addr = addr;
|
||||
let sink = Box::pin(sink.with(move |v: Bytes| async move {
|
||||
if false {
|
||||
return Err(super::TunnelError::CommonError("udp sink error".to_owned()));
|
||||
}
|
||||
|
||||
// TODO: two copy here, how to avoid?
|
||||
let udp_packet = UdpPacket::new_data_packet(conn_id, v.to_vec());
|
||||
tracing::trace!(?udp_packet, ?v, "udp send packet");
|
||||
let v = encode_to_bytes::<_, UDP_DATA_MTU>(&udp_packet);
|
||||
|
||||
Ok((v, sender_addr))
|
||||
}));
|
||||
|
||||
FramedTunnel::new_tunnel_with_info(
|
||||
stream,
|
||||
sink,
|
||||
// TODO: this remote addr is not a url
|
||||
super::TunnelInfo {
|
||||
tunnel_type: "udp".to_owned(),
|
||||
local_addr: super::build_url_from_socket_addr(
|
||||
&socket.local_addr().unwrap().to_string(),
|
||||
"udp",
|
||||
)
|
||||
.into(),
|
||||
remote_addr: super::build_url_from_socket_addr(&addr.to_string(), "udp").into(),
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
struct StreamSinkPair(
|
||||
Pin<Box<dyn DatagramStream>>,
|
||||
Pin<Box<dyn DatagramSink>>,
|
||||
u32,
|
||||
);
|
||||
type ArcStreamSinkPair = Arc<Mutex<StreamSinkPair>>;
|
||||
|
||||
pub struct UdpTunnelListener {
|
||||
addr: url::Url,
|
||||
socket: Option<Arc<UdpSocket>>,
|
||||
|
||||
sock_map: Arc<DashMap<SocketAddr, ArcStreamSinkPair>>,
|
||||
forward_tasks: Arc<Mutex<JoinSet<()>>>,
|
||||
|
||||
conn_recv: tokio::sync::mpsc::Receiver<Box<dyn Tunnel>>,
|
||||
conn_send: Option<tokio::sync::mpsc::Sender<Box<dyn Tunnel>>>,
|
||||
}
|
||||
|
||||
impl UdpTunnelListener {
|
||||
pub fn new(addr: url::Url) -> Self {
|
||||
let (conn_send, conn_recv) = tokio::sync::mpsc::channel(100);
|
||||
Self {
|
||||
addr,
|
||||
socket: None,
|
||||
sock_map: Arc::new(DashMap::new()),
|
||||
forward_tasks: Arc::new(Mutex::new(JoinSet::new())),
|
||||
conn_recv,
|
||||
conn_send: Some(conn_send),
|
||||
}
|
||||
}
|
||||
|
||||
async fn try_forward_packet(
|
||||
sock_map: &DashMap<SocketAddr, ArcStreamSinkPair>,
|
||||
buf: BytesMut,
|
||||
addr: SocketAddr,
|
||||
) -> Result<(), super::TunnelError> {
|
||||
let entry = sock_map.get_mut(&addr);
|
||||
if entry.is_none() {
|
||||
log::warn!("udp forward packet: {:?}, {:?}, no entry", addr, buf);
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
log::trace!("udp forward packet: {:?}, {:?}", addr, buf);
|
||||
let entry = entry.unwrap();
|
||||
let pair = entry.value().clone();
|
||||
drop(entry);
|
||||
|
||||
let Some(buf) = try_get_data_payload(buf, pair.lock().await.2) else {
|
||||
return Ok(());
|
||||
};
|
||||
pair.lock().await.1.send(buf.freeze()).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn handle_connect(
|
||||
socket: Arc<UdpSocket>,
|
||||
addr: SocketAddr,
|
||||
forward_tasks: Arc<Mutex<JoinSet<()>>>,
|
||||
sock_map: Arc<DashMap<SocketAddr, ArcStreamSinkPair>>,
|
||||
local_url: url::Url,
|
||||
conn_id: u32,
|
||||
) -> Result<Box<dyn Tunnel>, super::TunnelError> {
|
||||
tracing::info!(?conn_id, ?addr, "udp connection accept handling",);
|
||||
|
||||
let udp_packet = UdpPacket::new_sack_packet(conn_id);
|
||||
let sack_buf = encode_to_bytes::<_, UDP_DATA_MTU>(&udp_packet);
|
||||
socket.send_to(&sack_buf, addr).await?;
|
||||
|
||||
let (ctunnel, stunnel) = create_ring_tunnel_pair();
|
||||
let udp_tunnel = get_tunnel_from_socket(socket.clone(), addr, conn_id);
|
||||
let ss_pair = StreamSinkPair(ctunnel.pin_stream(), ctunnel.pin_sink(), conn_id);
|
||||
let addr_copy = addr.clone();
|
||||
sock_map.insert(addr, Arc::new(Mutex::new(ss_pair)));
|
||||
let ctunnel_stream = ctunnel.pin_stream();
|
||||
forward_tasks.lock().await.spawn(async move {
|
||||
let ret = ctunnel_stream
|
||||
.map(|v| {
|
||||
tracing::trace!(?v, "udp stream recv something in forward task");
|
||||
if v.is_err() {
|
||||
return Err(super::TunnelError::CommonError(
|
||||
"udp stream error".to_owned(),
|
||||
));
|
||||
}
|
||||
Ok(v.unwrap().freeze())
|
||||
})
|
||||
.forward(udp_tunnel.pin_sink())
|
||||
.await;
|
||||
if let None = sock_map.remove(&addr_copy) {
|
||||
log::warn!("udp forward packet: {:?}, no entry", addr_copy);
|
||||
}
|
||||
close_tunnel(&ctunnel).await.unwrap();
|
||||
log::warn!("udp connection forward done: {:?}, {:?}", addr_copy, ret);
|
||||
});
|
||||
|
||||
Ok(Box::new(TunnelWithCustomInfo::new(
|
||||
stunnel,
|
||||
TunnelInfo {
|
||||
tunnel_type: "udp".to_owned(),
|
||||
local_addr: local_url.into(),
|
||||
remote_addr: build_url_from_socket_addr(&addr.to_string(), "udp").into(),
|
||||
},
|
||||
)))
|
||||
}
|
||||
|
||||
pub fn get_socket(&self) -> Option<Arc<UdpSocket>> {
|
||||
self.socket.clone()
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl TunnelListener for UdpTunnelListener {
|
||||
async fn listen(&mut self) -> Result<(), super::TunnelError> {
|
||||
let addr = super::check_scheme_and_get_socket_addr::<SocketAddr>(&self.addr, "udp")?;
|
||||
self.socket = Some(Arc::new(UdpSocket::bind(addr).await?));
|
||||
|
||||
let socket = self.socket.as_ref().unwrap().clone();
|
||||
let forward_tasks = self.forward_tasks.clone();
|
||||
let sock_map = self.sock_map.clone();
|
||||
let conn_send = self.conn_send.take().unwrap();
|
||||
let local_url = self.local_url().clone();
|
||||
self.forward_tasks.lock().await.spawn(
|
||||
async move {
|
||||
loop {
|
||||
let mut buf = BytesMut::new();
|
||||
buf.resize(2500, 0);
|
||||
let (_size, addr) = socket.recv_from(&mut buf).await.unwrap();
|
||||
let _ = buf.split_off(_size);
|
||||
log::trace!(
|
||||
"udp recv packet: {:?}, buf: {:?}, size: {}",
|
||||
addr,
|
||||
buf,
|
||||
_size
|
||||
);
|
||||
|
||||
let Ok(udp_packet) = rkyv_util::decode_from_bytes_checked::<UdpPacket>(&buf)
|
||||
else {
|
||||
tracing::warn!(?buf, "udp decode error in forward task");
|
||||
continue;
|
||||
};
|
||||
|
||||
if matches!(udp_packet.payload, ArchivedUdpPacketPayload::Syn) {
|
||||
let conn = Self::handle_connect(
|
||||
socket.clone(),
|
||||
addr,
|
||||
forward_tasks.clone(),
|
||||
sock_map.clone(),
|
||||
local_url.clone(),
|
||||
udp_packet.conn_id.into(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
if let Err(e) = conn_send.send(conn).await {
|
||||
tracing::warn!(?e, "udp send conn to accept channel error");
|
||||
}
|
||||
} else {
|
||||
Self::try_forward_packet(sock_map.as_ref(), buf, addr)
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
}
|
||||
}
|
||||
.instrument(tracing::info_span!("udp forward task", ?self.socket)),
|
||||
);
|
||||
|
||||
// let forward_tasks_clone = self.forward_tasks.clone();
|
||||
// tokio::spawn(async move {
|
||||
// loop {
|
||||
// let mut locked_forward_tasks = forward_tasks_clone.lock().await;
|
||||
// tokio::select! {
|
||||
// ret = locked_forward_tasks.join_next() => {
|
||||
// tracing::warn!(?ret, "udp forward task exit");
|
||||
// }
|
||||
// else => {
|
||||
// drop(locked_forward_tasks);
|
||||
// tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
|
||||
// continue;
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// });
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn accept(&mut self) -> Result<Box<dyn super::Tunnel>, super::TunnelError> {
|
||||
log::info!("start udp accept: {:?}", self.addr);
|
||||
while let Some(conn) = self.conn_recv.recv().await {
|
||||
return Ok(conn);
|
||||
}
|
||||
return Err(super::TunnelError::CommonError(
|
||||
"udp accept error".to_owned(),
|
||||
));
|
||||
}
|
||||
|
||||
fn local_url(&self) -> url::Url {
|
||||
self.addr.clone()
|
||||
}
|
||||
|
||||
fn get_conn_counter(&self) -> Arc<Box<dyn TunnelConnCounter>> {
|
||||
struct UdpTunnelConnCounter {
|
||||
sock_map: Arc<DashMap<SocketAddr, ArcStreamSinkPair>>,
|
||||
}
|
||||
|
||||
impl Debug for UdpTunnelConnCounter {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("UdpTunnelConnCounter")
|
||||
.field("sock_map_len", &self.sock_map.len())
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl TunnelConnCounter for UdpTunnelConnCounter {
|
||||
fn get(&self) -> u32 {
|
||||
self.sock_map.len() as u32
|
||||
}
|
||||
}
|
||||
|
||||
Arc::new(Box::new(UdpTunnelConnCounter {
|
||||
sock_map: self.sock_map.clone(),
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
pub struct UdpTunnelConnector {
|
||||
addr: url::Url,
|
||||
bind_addrs: Vec<SocketAddr>,
|
||||
}
|
||||
|
||||
impl UdpTunnelConnector {
|
||||
pub fn new(addr: url::Url) -> Self {
|
||||
Self {
|
||||
addr,
|
||||
bind_addrs: vec![],
|
||||
}
|
||||
}
|
||||
|
||||
async fn wait_sack(
|
||||
socket: &UdpSocket,
|
||||
addr: SocketAddr,
|
||||
conn_id: u32,
|
||||
) -> Result<(), super::TunnelError> {
|
||||
let mut buf = BytesMut::new();
|
||||
buf.resize(128, 0);
|
||||
|
||||
let (usize, recv_addr) = tokio::time::timeout(
|
||||
tokio::time::Duration::from_secs(3),
|
||||
socket.recv_from(&mut buf),
|
||||
)
|
||||
.await??;
|
||||
|
||||
if recv_addr != addr {
|
||||
return Err(super::TunnelError::ConnectError(format!(
|
||||
"udp connect error, unexpected sack addr: {:?}, {:?}",
|
||||
recv_addr, addr
|
||||
)));
|
||||
}
|
||||
|
||||
let _ = buf.split_off(usize);
|
||||
|
||||
let Ok(udp_packet) = rkyv_util::decode_from_bytes_checked::<UdpPacket>(&buf) else {
|
||||
tracing::warn!(?buf, "udp decode error in wait sack");
|
||||
return Err(super::TunnelError::ConnectError(format!(
|
||||
"udp connect error, decode error. buf: {:?}",
|
||||
buf
|
||||
)));
|
||||
};
|
||||
|
||||
if conn_id != udp_packet.conn_id {
|
||||
return Err(super::TunnelError::ConnectError(format!(
|
||||
"udp connect error, conn id not match. conn_id: {:?}, {:?}",
|
||||
conn_id, udp_packet.conn_id
|
||||
)));
|
||||
}
|
||||
|
||||
if !matches!(udp_packet.payload, ArchivedUdpPacketPayload::Sack) {
|
||||
return Err(super::TunnelError::ConnectError(format!(
|
||||
"udp connect error, unexpected payload. payload: {:?}",
|
||||
udp_packet.payload
|
||||
)));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn wait_sack_loop(
|
||||
socket: &UdpSocket,
|
||||
addr: SocketAddr,
|
||||
conn_id: u32,
|
||||
) -> Result<(), super::TunnelError> {
|
||||
while let Err(err) = Self::wait_sack(socket, addr, conn_id).await {
|
||||
tracing::warn!(?err, "udp wait sack error");
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn try_connect_with_socket(
|
||||
&self,
|
||||
socket: UdpSocket,
|
||||
) -> Result<Box<dyn super::Tunnel>, super::TunnelError> {
|
||||
let addr = super::check_scheme_and_get_socket_addr::<SocketAddr>(&self.addr, "udp")?;
|
||||
log::warn!("udp connect: {:?}", self.addr);
|
||||
|
||||
// send syn
|
||||
let conn_id = rand::random();
|
||||
let udp_packet = UdpPacket::new_syn_packet(conn_id);
|
||||
let b = encode_to_bytes::<_, UDP_DATA_MTU>(&udp_packet);
|
||||
let ret = socket.send_to(&b, &addr).await?;
|
||||
tracing::warn!(?udp_packet, ?ret, "udp send syn");
|
||||
|
||||
// wait sack
|
||||
tokio::time::timeout(
|
||||
tokio::time::Duration::from_secs(3),
|
||||
Self::wait_sack_loop(&socket, addr, conn_id),
|
||||
)
|
||||
.await??;
|
||||
|
||||
// sack done
|
||||
let local_addr = socket.local_addr().unwrap().to_string();
|
||||
Ok(Box::new(TunnelWithCustomInfo::new(
|
||||
get_tunnel_from_socket(Arc::new(socket), addr, conn_id),
|
||||
TunnelInfo {
|
||||
tunnel_type: "udp".to_owned(),
|
||||
local_addr: super::build_url_from_socket_addr(&local_addr, "udp").into(),
|
||||
remote_addr: self.remote_url().into(),
|
||||
},
|
||||
)))
|
||||
}
|
||||
|
||||
async fn connect_with_default_bind(&mut self) -> Result<Box<dyn Tunnel>, super::TunnelError> {
|
||||
let socket = UdpSocket::bind("0.0.0.0:0").await?;
|
||||
return self.try_connect_with_socket(socket).await;
|
||||
}
|
||||
|
||||
async fn connect_with_custom_bind(&mut self) -> Result<Box<dyn Tunnel>, super::TunnelError> {
|
||||
let mut futures = FuturesUnordered::new();
|
||||
|
||||
for bind_addr in self.bind_addrs.iter() {
|
||||
let socket = UdpSocket::bind(*bind_addr).await?;
|
||||
|
||||
// linux does not use interface of bind_addr to send packet, so we need to bind device
|
||||
// mac can handle this with bind correctly
|
||||
#[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
|
||||
if let Some(dev_name) = super::common::get_interface_name_by_ip(&bind_addr.ip()) {
|
||||
tracing::trace!(dev_name = ?dev_name, "bind device");
|
||||
socket.bind_device(Some(dev_name.as_bytes()))?;
|
||||
}
|
||||
|
||||
futures.push(self.try_connect_with_socket(socket));
|
||||
}
|
||||
|
||||
let Some(ret) = futures.next().await else {
|
||||
return Err(super::TunnelError::CommonError(
|
||||
"join connect futures failed".to_owned(),
|
||||
));
|
||||
};
|
||||
|
||||
return ret;
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl super::TunnelConnector for UdpTunnelConnector {
|
||||
async fn connect(&mut self) -> Result<Box<dyn super::Tunnel>, super::TunnelError> {
|
||||
if self.bind_addrs.is_empty() {
|
||||
self.connect_with_default_bind().await
|
||||
} else {
|
||||
self.connect_with_custom_bind().await
|
||||
}
|
||||
}
|
||||
|
||||
fn remote_url(&self) -> url::Url {
|
||||
self.addr.clone()
|
||||
}
|
||||
|
||||
fn set_bind_addrs(&mut self, addrs: Vec<SocketAddr>) {
|
||||
self.bind_addrs = addrs;
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::tunnels::common::tests::{_tunnel_bench, _tunnel_pingpong};
|
||||
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn udp_pingpong() {
|
||||
let listener = UdpTunnelListener::new("udp://0.0.0.0:5556".parse().unwrap());
|
||||
let connector = UdpTunnelConnector::new("udp://127.0.0.1:5556".parse().unwrap());
|
||||
_tunnel_pingpong(listener, connector).await
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn udp_bench() {
|
||||
let listener = UdpTunnelListener::new("udp://0.0.0.0:5555".parse().unwrap());
|
||||
let connector = UdpTunnelConnector::new("udp://127.0.0.1:5555".parse().unwrap());
|
||||
_tunnel_bench(listener, connector).await
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn udp_bench_with_bind() {
|
||||
let listener = UdpTunnelListener::new("udp://127.0.0.1:5554".parse().unwrap());
|
||||
let mut connector = UdpTunnelConnector::new("udp://127.0.0.1:5554".parse().unwrap());
|
||||
connector.set_bind_addrs(vec!["127.0.0.1:0".parse().unwrap()]);
|
||||
_tunnel_pingpong(listener, connector).await
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[should_panic]
|
||||
async fn udp_bench_with_bind_fail() {
|
||||
let listener = UdpTunnelListener::new("udp://127.0.0.1:5553".parse().unwrap());
|
||||
let mut connector = UdpTunnelConnector::new("udp://127.0.0.1:5553".parse().unwrap());
|
||||
connector.set_bind_addrs(vec!["10.0.0.1:0".parse().unwrap()]);
|
||||
_tunnel_pingpong(listener, connector).await
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user