Initial Version

This commit is contained in:
sijie.sun
2023-09-23 01:53:45 +00:00
commit 9779923b87
63 changed files with 10840 additions and 0 deletions
+14
View File
@@ -0,0 +1,14 @@
pub mod packet;
pub mod peer;
pub mod peer_conn;
pub mod peer_manager;
pub mod peer_map;
pub mod peer_rpc;
pub mod rip_route;
pub mod route_trait;
pub mod rpc_service;
#[cfg(test)]
pub mod tests;
pub type PeerId = uuid::Uuid;
+205
View File
@@ -0,0 +1,205 @@
use rkyv::{Archive, Deserialize, Serialize};
use tokio_util::bytes::Bytes;
use crate::common::rkyv_util::{decode_from_bytes, encode_to_bytes};
const MAGIC: u32 = 0xd1e1a5e1;
const VERSION: u32 = 1;
#[derive(Archive, Deserialize, Serialize, PartialEq, Clone)]
#[archive(compare(PartialEq), check_bytes)]
// Derives can be passed through to the generated type:
#[archive_attr(derive(Debug))]
pub struct UUID(uuid::Bytes);
// impl Debug for UUID
impl std::fmt::Debug for UUID {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let uuid = uuid::Uuid::from_bytes(self.0);
write!(f, "{}", uuid)
}
}
impl From<uuid::Uuid> for UUID {
fn from(uuid: uuid::Uuid) -> Self {
UUID(*uuid.as_bytes())
}
}
impl From<UUID> for uuid::Uuid {
fn from(uuid: UUID) -> Self {
uuid::Uuid::from_bytes(uuid.0)
}
}
impl ArchivedUUID {
pub fn to_uuid(&self) -> uuid::Uuid {
uuid::Uuid::from_bytes(self.0)
}
}
impl From<&ArchivedUUID> for UUID {
fn from(uuid: &ArchivedUUID) -> Self {
UUID(uuid.0)
}
}
#[derive(Archive, Deserialize, Serialize, Debug)]
#[archive(compare(PartialEq), check_bytes)]
// Derives can be passed through to the generated type:
#[archive_attr(derive(Debug))]
pub struct HandShake {
pub magic: u32,
pub my_peer_id: UUID,
pub version: u32,
pub features: Vec<String>,
// pub interfaces: Vec<String>,
}
#[derive(Archive, Deserialize, Serialize, Debug)]
#[archive(compare(PartialEq), check_bytes)]
#[archive_attr(derive(Debug))]
pub struct RoutePacket {
pub route_id: u8,
pub body: Vec<u8>,
}
#[derive(Archive, Deserialize, Serialize, Debug)]
#[archive(compare(PartialEq), check_bytes)]
// Derives can be passed through to the generated type:
#[archive_attr(derive(Debug))]
pub enum CtrlPacketBody {
HandShake(HandShake),
RoutePacket(RoutePacket),
Ping,
Pong,
TaRpc(u32, bool, Vec<u8>), // u32: service_id, bool: is_req, Vec<u8>: rpc body
}
#[derive(Archive, Deserialize, Serialize, Debug)]
#[archive(compare(PartialEq), check_bytes)]
// Derives can be passed through to the generated type:
#[archive_attr(derive(Debug))]
pub struct DataPacketBody {
pub data: Vec<u8>,
}
#[derive(Archive, Deserialize, Serialize, Debug)]
#[archive(compare(PartialEq), check_bytes)]
// Derives can be passed through to the generated type:
#[archive_attr(derive(Debug))]
pub enum PacketBody {
Ctrl(CtrlPacketBody),
Data(DataPacketBody),
}
#[derive(Archive, Deserialize, Serialize, Debug)]
#[archive(compare(PartialEq), check_bytes)]
// Derives can be passed through to the generated type:
#[archive_attr(derive(Debug))]
pub struct Packet {
pub from_peer: UUID,
pub to_peer: Option<UUID>,
pub body: PacketBody,
}
impl Packet {
pub fn decode(v: &[u8]) -> &ArchivedPacket {
decode_from_bytes::<Packet>(v).unwrap()
}
}
impl From<Packet> for Bytes {
fn from(val: Packet) -> Self {
encode_to_bytes::<_, 4096>(&val)
}
}
impl Packet {
pub fn new_handshake(from_peer: uuid::Uuid) -> Self {
Packet {
from_peer: from_peer.into(),
to_peer: None,
body: PacketBody::Ctrl(CtrlPacketBody::HandShake(HandShake {
magic: MAGIC,
my_peer_id: from_peer.into(),
version: VERSION,
features: Vec::new(),
})),
}
}
pub fn new_data_packet(from_peer: uuid::Uuid, to_peer: uuid::Uuid, data: &[u8]) -> Self {
Packet {
from_peer: from_peer.into(),
to_peer: Some(to_peer.into()),
body: PacketBody::Data(DataPacketBody {
data: data.to_vec(),
}),
}
}
pub fn new_route_packet(
from_peer: uuid::Uuid,
to_peer: uuid::Uuid,
route_id: u8,
data: &[u8],
) -> Self {
Packet {
from_peer: from_peer.into(),
to_peer: Some(to_peer.into()),
body: PacketBody::Ctrl(CtrlPacketBody::RoutePacket(RoutePacket {
route_id,
body: data.to_vec(),
})),
}
}
pub fn new_ping_packet(from_peer: uuid::Uuid, to_peer: uuid::Uuid) -> Self {
Packet {
from_peer: from_peer.into(),
to_peer: Some(to_peer.into()),
body: PacketBody::Ctrl(CtrlPacketBody::Ping),
}
}
pub fn new_pong_packet(from_peer: uuid::Uuid, to_peer: uuid::Uuid) -> Self {
Packet {
from_peer: from_peer.into(),
to_peer: Some(to_peer.into()),
body: PacketBody::Ctrl(CtrlPacketBody::Pong),
}
}
pub fn new_tarpc_packet(
from_peer: uuid::Uuid,
to_peer: uuid::Uuid,
service_id: u32,
is_req: bool,
body: Vec<u8>,
) -> Self {
Packet {
from_peer: from_peer.into(),
to_peer: Some(to_peer.into()),
body: PacketBody::Ctrl(CtrlPacketBody::TaRpc(service_id, is_req, body)),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn serialize() {
let a = "abcde";
let out = Packet::new_data_packet(uuid::Uuid::new_v4(), uuid::Uuid::new_v4(), a.as_bytes());
// let out = T::new(a.as_bytes());
let out_bytes: Bytes = out.into();
println!("out str: {:?}", a.as_bytes());
println!("out bytes: {:?}", out_bytes);
let archived = Packet::decode(&out_bytes[..]);
println!("in packet: {:?}", archived);
}
}
+218
View File
@@ -0,0 +1,218 @@
use std::sync::Arc;
use dashmap::DashMap;
use easytier_rpc::PeerConnInfo;
use tokio::{
select,
sync::{mpsc, Mutex},
task::JoinHandle,
};
use tokio_util::bytes::Bytes;
use tracing::Instrument;
use uuid::Uuid;
use crate::common::{
error::Error,
global_ctx::{ArcGlobalCtx, GlobalCtxEvent},
};
use super::peer_conn::PeerConn;
type ArcPeerConn = Arc<Mutex<PeerConn>>;
type ConnMap = Arc<DashMap<Uuid, ArcPeerConn>>;
pub struct Peer {
pub peer_node_id: uuid::Uuid,
conns: ConnMap,
global_ctx: ArcGlobalCtx,
packet_recv_chan: mpsc::Sender<Bytes>,
close_event_sender: mpsc::Sender<Uuid>,
close_event_listener: JoinHandle<()>,
shutdown_notifier: Arc<tokio::sync::Notify>,
}
impl Peer {
pub fn new(
peer_node_id: uuid::Uuid,
packet_recv_chan: mpsc::Sender<Bytes>,
global_ctx: ArcGlobalCtx,
) -> Self {
let conns: ConnMap = Arc::new(DashMap::new());
let (close_event_sender, mut close_event_receiver) = mpsc::channel(10);
let shutdown_notifier = Arc::new(tokio::sync::Notify::new());
let conns_copy = conns.clone();
let shutdown_notifier_copy = shutdown_notifier.clone();
let global_ctx_copy = global_ctx.clone();
let close_event_listener = tokio::spawn(
async move {
loop {
select! {
ret = close_event_receiver.recv() => {
if ret.is_none() {
break;
}
let ret = ret.unwrap();
tracing::warn!(
?peer_node_id,
?ret,
"notified that peer conn is closed",
);
if let Some((_, conn)) = conns_copy.remove(&ret) {
global_ctx_copy.issue_event(GlobalCtxEvent::PeerConnRemoved(
conn.lock().await.get_conn_info(),
));
}
}
_ = shutdown_notifier_copy.notified() => {
close_event_receiver.close();
tracing::warn!(?peer_node_id, "peer close event listener notified");
}
}
}
tracing::info!("peer {} close event listener exit", peer_node_id);
}
.instrument(tracing::info_span!(
"peer_close_event_listener",
?peer_node_id,
)),
);
Peer {
peer_node_id,
conns: conns.clone(),
packet_recv_chan,
global_ctx,
close_event_sender,
close_event_listener,
shutdown_notifier,
}
}
pub async fn add_peer_conn(&self, mut conn: PeerConn) {
conn.set_close_event_sender(self.close_event_sender.clone());
conn.start_recv_loop(self.packet_recv_chan.clone());
self.global_ctx
.issue_event(GlobalCtxEvent::PeerConnAdded(conn.get_conn_info()));
self.conns
.insert(conn.get_conn_id(), Arc::new(Mutex::new(conn)));
}
pub async fn send_msg(&self, msg: Bytes) -> Result<(), Error> {
let Some(conn) = self.conns.iter().next() else {
return Err(Error::PeerNoConnectionError(self.peer_node_id));
};
let conn_clone = conn.clone();
drop(conn);
conn_clone.lock().await.send_msg(msg).await?;
Ok(())
}
pub async fn close_peer_conn(&self, conn_id: &Uuid) -> Result<(), Error> {
let has_key = self.conns.contains_key(conn_id);
if !has_key {
return Err(Error::NotFound);
}
self.close_event_sender.send(conn_id.clone()).await.unwrap();
Ok(())
}
pub async fn list_peer_conns(&self) -> Vec<PeerConnInfo> {
let mut conns = vec![];
for conn in self.conns.iter() {
// do not lock here, otherwise it will cause dashmap deadlock
conns.push(conn.clone());
}
let mut ret = Vec::new();
for conn in conns {
ret.push(conn.lock().await.get_conn_info());
}
ret
}
}
// pritn on drop
impl Drop for Peer {
fn drop(&mut self) {
self.shutdown_notifier.notify_one();
tracing::info!("peer {} drop", self.peer_node_id);
}
}
#[cfg(test)]
mod tests {
use std::sync::Arc;
use tokio::{sync::mpsc, time::timeout};
use crate::{
common::{config_fs::ConfigFs, global_ctx::GlobalCtx, netns::NetNS},
peers::peer_conn::PeerConn,
tunnels::ring_tunnel::create_ring_tunnel_pair,
};
use super::Peer;
#[tokio::test]
async fn close_peer() {
let (local_packet_send, _local_packet_recv) = mpsc::channel(10);
let (remote_packet_send, _remote_packet_recv) = mpsc::channel(10);
let global_ctx = Arc::new(GlobalCtx::new(
"test",
ConfigFs::new("/tmp/easytier-test"),
NetNS::new(None),
));
let local_peer = Peer::new(uuid::Uuid::new_v4(), local_packet_send, global_ctx.clone());
let remote_peer = Peer::new(uuid::Uuid::new_v4(), remote_packet_send, global_ctx.clone());
let (local_tunnel, remote_tunnel) = create_ring_tunnel_pair();
let mut local_peer_conn =
PeerConn::new(local_peer.peer_node_id, global_ctx.clone(), local_tunnel);
let mut remote_peer_conn =
PeerConn::new(remote_peer.peer_node_id, global_ctx.clone(), remote_tunnel);
assert!(!local_peer_conn.handshake_done());
assert!(!remote_peer_conn.handshake_done());
let (a, b) = tokio::join!(
local_peer_conn.do_handshake_as_client(),
remote_peer_conn.do_handshake_as_server()
);
a.unwrap();
b.unwrap();
let local_conn_id = local_peer_conn.get_conn_id();
local_peer.add_peer_conn(local_peer_conn).await;
remote_peer.add_peer_conn(remote_peer_conn).await;
assert_eq!(local_peer.list_peer_conns().await.len(), 1);
assert_eq!(remote_peer.list_peer_conns().await.len(), 1);
let close_handler =
tokio::spawn(async move { local_peer.close_peer_conn(&local_conn_id).await });
// wait for remote peer conn close
timeout(std::time::Duration::from_secs(5), async {
while (&remote_peer).list_peer_conns().await.len() != 0 {
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
}
})
.await
.unwrap();
println!("wait for close handler");
close_handler.await.unwrap().unwrap();
}
}
+484
View File
@@ -0,0 +1,484 @@
use std::{pin::Pin, sync::Arc};
use easytier_rpc::{PeerConnInfo, PeerConnStats};
use futures::{SinkExt, StreamExt};
use pnet::datalink::NetworkInterface;
use tokio::{
sync::{broadcast, mpsc},
task::JoinSet,
time::{timeout, Duration},
};
use tokio_util::{
bytes::{Bytes, BytesMut},
sync::PollSender,
};
use tracing::Instrument;
use crate::{
common::global_ctx::ArcGlobalCtx,
define_tunnel_filter_chain,
tunnels::{
stats::{Throughput, WindowLatency},
tunnel_filter::StatsRecorderTunnelFilter,
DatagramSink, Tunnel, TunnelError,
},
};
use super::packet::{self, ArchivedCtrlPacketBody, ArchivedHandShake, Packet};
pub type PacketRecvChan = mpsc::Sender<Bytes>;
macro_rules! wait_response {
($stream: ident, $out_var:ident, $pattern:pat_param => $value:expr) => {
let rsp_vec = timeout(Duration::from_secs(1), $stream.next()).await;
if rsp_vec.is_err() {
return Err(TunnelError::WaitRespError(
"wait handshake response timeout".to_owned(),
));
}
let rsp_vec = rsp_vec.unwrap().unwrap()?;
let $out_var;
let rsp_bytes = Packet::decode(&rsp_vec);
match &rsp_bytes.body {
$pattern => $out_var = $value,
_ => {
log::error!(
"unexpected packet: {:?}, pattern: {:?}",
rsp_bytes,
stringify!($pattern)
);
return Err(TunnelError::WaitRespError("unexpected packet".to_owned()));
}
}
};
}
pub struct PeerInfo {
magic: u32,
pub my_peer_id: uuid::Uuid,
version: u32,
pub features: Vec<String>,
pub interfaces: Vec<NetworkInterface>,
}
impl<'a> From<&ArchivedHandShake> for PeerInfo {
fn from(hs: &ArchivedHandShake) -> Self {
PeerInfo {
magic: hs.magic.into(),
my_peer_id: hs.my_peer_id.to_uuid(),
version: hs.version.into(),
features: hs.features.iter().map(|x| x.to_string()).collect(),
interfaces: Vec::new(),
}
}
}
define_tunnel_filter_chain!(PeerConnTunnel, stats = StatsRecorderTunnelFilter);
pub struct PeerConn {
conn_id: uuid::Uuid,
my_node_id: uuid::Uuid,
global_ctx: ArcGlobalCtx,
sink: Pin<Box<dyn DatagramSink>>,
tunnel: Box<dyn Tunnel>,
tasks: JoinSet<Result<(), TunnelError>>,
info: Option<PeerInfo>,
close_event_sender: Option<mpsc::Sender<uuid::Uuid>>,
ctrl_resp_sender: broadcast::Sender<Bytes>,
latency_stats: Arc<WindowLatency>,
throughput: Arc<Throughput>,
}
enum PeerConnPacketType {
Data(Bytes),
CtrlReq(Bytes),
CtrlResp(Bytes),
}
static CTRL_REQ_PACKET_PREFIX: &[u8] = &[0x12, 0x34, 0x56, 0x78, 0x9a, 0xbc, 0xde, 0xf0];
static CTRL_RESP_PACKET_PREFIX: &[u8] = &[0x12, 0x34, 0x56, 0x78, 0x9a, 0xbc, 0xde, 0xf1];
impl PeerConn {
pub fn new(node_id: uuid::Uuid, global_ctx: ArcGlobalCtx, tunnel: Box<dyn Tunnel>) -> Self {
let (ctrl_sender, _ctrl_receiver) = broadcast::channel(100);
let peer_conn_tunnel = PeerConnTunnel::new();
let tunnel = peer_conn_tunnel.wrap_tunnel(tunnel);
PeerConn {
conn_id: uuid::Uuid::new_v4(),
my_node_id: node_id,
global_ctx,
sink: tunnel.pin_sink(),
tunnel: Box::new(tunnel),
tasks: JoinSet::new(),
info: None,
close_event_sender: None,
ctrl_resp_sender: ctrl_sender,
latency_stats: Arc::new(WindowLatency::new(15)),
throughput: peer_conn_tunnel.stats.get_throughput().clone(),
}
}
pub fn get_conn_id(&self) -> uuid::Uuid {
self.conn_id
}
pub async fn do_handshake_as_server(&mut self) -> Result<(), TunnelError> {
let mut stream = self.tunnel.pin_stream();
let mut sink = self.tunnel.pin_sink();
wait_response!(stream, hs_req, packet::ArchivedPacketBody::Ctrl(ArchivedCtrlPacketBody::HandShake(x)) => x);
self.info = Some(PeerInfo::from(hs_req));
log::info!("handshake request: {:?}", hs_req);
let hs_req = self
.global_ctx
.net_ns
.run(|| packet::Packet::new_handshake(self.my_node_id));
sink.send(hs_req.into()).await?;
Ok(())
}
pub async fn do_handshake_as_client(&mut self) -> Result<(), TunnelError> {
let mut stream = self.tunnel.pin_stream();
let mut sink = self.tunnel.pin_sink();
let hs_req = self
.global_ctx
.net_ns
.run(|| packet::Packet::new_handshake(self.my_node_id));
sink.send(hs_req.into()).await?;
wait_response!(stream, hs_rsp, packet::ArchivedPacketBody::Ctrl(ArchivedCtrlPacketBody::HandShake(x)) => x);
self.info = Some(PeerInfo::from(hs_rsp));
log::info!("handshake response: {:?}", hs_rsp);
Ok(())
}
pub fn handshake_done(&self) -> bool {
self.info.is_some()
}
async fn do_pingpong_once(
my_node_id: uuid::Uuid,
peer_id: uuid::Uuid,
sink: &mut Pin<Box<dyn DatagramSink>>,
receiver: &mut broadcast::Receiver<Bytes>,
) -> Result<u128, TunnelError> {
// should add seq here. so latency can be calculated more accurately
let req = Self::build_ctrl_msg(
packet::Packet::new_ping_packet(my_node_id, peer_id).into(),
true,
);
log::trace!("send ping packet: {:?}", req);
sink.send(req).await?;
let now = std::time::Instant::now();
// wait until we get a pong packet in ctrl_resp_receiver
let resp = timeout(Duration::from_secs(4), async {
loop {
match receiver.recv().await {
Ok(p) => {
if let packet::ArchivedPacketBody::Ctrl(
packet::ArchivedCtrlPacketBody::Pong,
) = &Packet::decode(&p).body
{
break;
}
}
Err(e) => {
log::warn!("recv pong resp error: {:?}", e);
return Err(TunnelError::WaitRespError(
"recv pong resp error".to_owned(),
));
}
}
}
Ok(())
})
.await;
if resp.is_err() {
return Err(TunnelError::WaitRespError(
"wait ping response timeout".to_owned(),
));
}
if resp.as_ref().unwrap().is_err() {
return Err(resp.unwrap().err().unwrap());
}
Ok(now.elapsed().as_micros())
}
fn start_pingpong(&mut self) {
let mut sink = self.tunnel.pin_sink();
let my_node_id = self.my_node_id;
let peer_id = self.get_peer_id();
let receiver = self.ctrl_resp_sender.subscribe();
let close_event_sender = self.close_event_sender.clone().unwrap();
let conn_id = self.conn_id;
let latency_stats = self.latency_stats.clone();
self.tasks.spawn(async move {
//sleep 1s
tokio::time::sleep(Duration::from_secs(1)).await;
loop {
let mut receiver = receiver.resubscribe();
if let Ok(lat) =
Self::do_pingpong_once(my_node_id, peer_id, &mut sink, &mut receiver).await
{
log::trace!(
"pingpong latency: {}us, my_node_id: {}, peer_id: {}",
lat,
my_node_id,
peer_id
);
latency_stats.record_latency(lat as u64);
tokio::time::sleep(Duration::from_secs(1)).await;
} else {
break;
}
}
log::warn!(
"pingpong task exit, my_node_id: {}, peer_id: {}",
my_node_id,
peer_id,
);
if let Err(e) = close_event_sender.send(conn_id).await {
log::warn!("close event sender error: {:?}", e);
}
Ok(())
});
}
fn get_packet_type(mut bytes_item: Bytes) -> PeerConnPacketType {
if bytes_item.starts_with(CTRL_REQ_PACKET_PREFIX) {
PeerConnPacketType::CtrlReq(bytes_item.split_off(CTRL_REQ_PACKET_PREFIX.len()))
} else if bytes_item.starts_with(CTRL_RESP_PACKET_PREFIX) {
PeerConnPacketType::CtrlResp(bytes_item.split_off(CTRL_RESP_PACKET_PREFIX.len()))
} else {
PeerConnPacketType::Data(bytes_item)
}
}
fn handle_ctrl_req_packet(
bytes_item: Bytes,
conn_info: &PeerConnInfo,
) -> Result<Bytes, TunnelError> {
let packet = Packet::decode(&bytes_item);
match packet.body {
packet::ArchivedPacketBody::Ctrl(packet::ArchivedCtrlPacketBody::Ping) => {
log::trace!("recv ping packet: {:?}", packet);
Ok(Self::build_ctrl_msg(
packet::Packet::new_pong_packet(
conn_info.my_node_id.parse().unwrap(),
conn_info.peer_id.parse().unwrap(),
)
.into(),
false,
))
}
_ => {
log::error!("unexpected packet: {:?}", packet);
Err(TunnelError::CommonError("unexpected packet".to_owned()))
}
}
}
pub fn start_recv_loop(&mut self, packet_recv_chan: PacketRecvChan) {
let mut stream = self.tunnel.pin_stream();
let mut sink = self.tunnel.pin_sink();
let mut sender = PollSender::new(packet_recv_chan.clone());
let close_event_sender = self.close_event_sender.clone().unwrap();
let conn_id = self.conn_id;
let ctrl_sender = self.ctrl_resp_sender.clone();
let conn_info = self.get_conn_info();
let conn_info_for_instrument = self.get_conn_info();
self.tasks.spawn(
async move {
tracing::info!("start recving peer conn packet");
while let Some(ret) = stream.next().await {
if ret.is_err() {
tracing::error!(error = ?ret, "peer conn recv error");
if let Err(close_ret) = sink.close().await {
tracing::error!(error = ?close_ret, "peer conn sink close error, ignore it");
}
if let Err(e) = close_event_sender.send(conn_id).await {
tracing::error!(error = ?e, "peer conn close event send error");
}
return Err(ret.err().unwrap());
}
match Self::get_packet_type(ret.unwrap().into()) {
PeerConnPacketType::Data(item) => sender.send(item).await.unwrap(),
PeerConnPacketType::CtrlReq(item) => {
let ret = Self::handle_ctrl_req_packet(item, &conn_info).unwrap();
if let Err(e) = sink.send(ret).await {
tracing::error!(?e, "peer conn send req error");
}
}
PeerConnPacketType::CtrlResp(item) => {
if let Err(e) = ctrl_sender.send(item) {
tracing::error!(?e, "peer conn send ctrl resp error");
}
}
}
}
tracing::info!("end recving peer conn packet");
Ok(())
}
.instrument(
tracing::info_span!("peer conn recv loop", conn_info = ?conn_info_for_instrument),
),
);
self.start_pingpong();
}
pub async fn send_msg(&mut self, msg: Bytes) -> Result<(), TunnelError> {
self.sink.send(msg).await
}
fn build_ctrl_msg(msg: Bytes, is_req: bool) -> Bytes {
let prefix: &'static [u8] = if is_req {
CTRL_REQ_PACKET_PREFIX
} else {
CTRL_RESP_PACKET_PREFIX
};
let mut new_msg = BytesMut::new();
new_msg.reserve(prefix.len() + msg.len());
new_msg.extend_from_slice(prefix);
new_msg.extend_from_slice(&msg);
new_msg.into()
}
pub fn get_peer_id(&self) -> uuid::Uuid {
self.info.as_ref().unwrap().my_peer_id
}
pub fn set_close_event_sender(&mut self, sender: mpsc::Sender<uuid::Uuid>) {
self.close_event_sender = Some(sender);
}
pub fn get_stats(&self) -> PeerConnStats {
PeerConnStats {
latency_us: self.latency_stats.get_latency_us(),
tx_bytes: self.throughput.tx_bytes(),
rx_bytes: self.throughput.rx_bytes(),
tx_packets: self.throughput.tx_packets(),
rx_packets: self.throughput.rx_packets(),
}
}
pub fn get_conn_info(&self) -> PeerConnInfo {
PeerConnInfo {
conn_id: self.conn_id.to_string(),
my_node_id: self.my_node_id.to_string(),
peer_id: self.get_peer_id().to_string(),
features: self.info.as_ref().unwrap().features.clone(),
tunnel: self.tunnel.info(),
stats: Some(self.get_stats()),
}
}
}
impl Drop for PeerConn {
fn drop(&mut self) {
let mut sink = self.tunnel.pin_sink();
tokio::spawn(async move {
let ret = sink.close().await;
tracing::info!(error = ?ret, "peer conn tunnel closed.");
});
log::info!("peer conn {:?} drop", self.conn_id);
}
}
#[cfg(test)]
mod tests {
use std::sync::Arc;
use super::*;
use crate::common::config_fs::ConfigFs;
use crate::common::global_ctx::GlobalCtx;
use crate::common::netns::NetNS;
use crate::tunnels::tunnel_filter::{PacketRecorderTunnelFilter, TunnelWithFilter};
#[tokio::test]
async fn peer_conn_handshake() {
use crate::tunnels::ring_tunnel::create_ring_tunnel_pair;
let (c, s) = create_ring_tunnel_pair();
let c_recorder = Arc::new(PacketRecorderTunnelFilter::new());
let s_recorder = Arc::new(PacketRecorderTunnelFilter::new());
let c = TunnelWithFilter::new(c, c_recorder.clone());
let s = TunnelWithFilter::new(s, s_recorder.clone());
let c_uuid = uuid::Uuid::new_v4();
let s_uuid = uuid::Uuid::new_v4();
let mut c_peer = PeerConn::new(
c_uuid,
Arc::new(GlobalCtx::new(
"c",
ConfigFs::new_with_dir("c", "/tmp"),
NetNS::new(None),
)),
Box::new(c),
);
let mut s_peer = PeerConn::new(
s_uuid,
Arc::new(GlobalCtx::new(
"c",
ConfigFs::new_with_dir("c", "/tmp"),
NetNS::new(None),
)),
Box::new(s),
);
let (c_ret, s_ret) = tokio::join!(
c_peer.do_handshake_as_client(),
s_peer.do_handshake_as_server()
);
c_ret.unwrap();
s_ret.unwrap();
assert_eq!(c_recorder.sent.lock().unwrap().len(), 1);
assert_eq!(c_recorder.received.lock().unwrap().len(), 1);
assert_eq!(s_recorder.sent.lock().unwrap().len(), 1);
assert_eq!(s_recorder.received.lock().unwrap().len(), 1);
assert_eq!(c_peer.get_peer_id(), s_uuid);
assert_eq!(s_peer.get_peer_id(), c_uuid);
}
}
+539
View File
@@ -0,0 +1,539 @@
use std::{
fmt::Debug,
net::Ipv4Addr,
sync::{atomic::AtomicU8, Arc},
};
use async_trait::async_trait;
use futures::{StreamExt, TryFutureExt};
use tokio::{
sync::{
mpsc::{self, UnboundedReceiver, UnboundedSender},
Mutex, RwLock,
},
task::JoinSet,
};
use tokio_stream::wrappers::ReceiverStream;
use tokio_util::bytes::{Bytes, BytesMut};
use uuid::Uuid;
use crate::{
common::{error::Error, global_ctx::ArcGlobalCtx, rkyv_util::extract_bytes_from_archived_vec},
peers::{
packet::{self},
peer_conn::PeerConn,
peer_rpc::PeerRpcManagerTransport,
route_trait::RouteInterface,
},
tunnels::{SinkItem, Tunnel, TunnelConnector},
};
use super::{
peer_map::PeerMap,
peer_rpc::PeerRpcManager,
route_trait::{ArcRoute, Route},
PeerId,
};
struct RpcTransport {
my_peer_id: uuid::Uuid,
peers: Arc<PeerMap>,
packet_recv: Mutex<UnboundedReceiver<Bytes>>,
peer_rpc_tspt_sender: UnboundedSender<Bytes>,
route: Arc<Mutex<Option<ArcRoute>>>,
}
#[async_trait::async_trait]
impl PeerRpcManagerTransport for RpcTransport {
fn my_peer_id(&self) -> Uuid {
self.my_peer_id
}
async fn send(&self, msg: Bytes, dst_peer_id: &uuid::Uuid) -> Result<(), Error> {
let route = self.route.lock().await;
if route.is_none() {
log::error!("no route info when send rpc msg");
return Err(Error::RouteError("No route info".to_string()));
}
self.peers
.send_msg(msg, dst_peer_id, route.as_ref().unwrap().clone())
.map_err(|e| e.into())
.await
}
async fn recv(&self) -> Result<Bytes, Error> {
if let Some(o) = self.packet_recv.lock().await.recv().await {
Ok(o)
} else {
Err(Error::Unknown)
}
}
}
#[async_trait::async_trait]
#[auto_impl::auto_impl(Arc)]
pub trait PeerPacketFilter {
async fn try_process_packet_from_peer(
&self,
packet: &packet::ArchivedPacket,
data: &Bytes,
) -> Option<()>;
}
#[async_trait::async_trait]
#[auto_impl::auto_impl(Arc)]
pub trait NicPacketFilter {
async fn try_process_packet_from_nic(&self, data: BytesMut) -> BytesMut;
}
type BoxPeerPacketFilter = Box<dyn PeerPacketFilter + Send + Sync>;
type BoxNicPacketFilter = Box<dyn NicPacketFilter + Send + Sync>;
pub struct PeerManager {
my_node_id: uuid::Uuid,
global_ctx: ArcGlobalCtx,
nic_channel: mpsc::Sender<SinkItem>,
tasks: Arc<Mutex<JoinSet<()>>>,
packet_recv: Arc<Mutex<Option<mpsc::Receiver<Bytes>>>>,
peers: Arc<PeerMap>,
route: Arc<Mutex<Option<ArcRoute>>>,
cur_route_id: AtomicU8,
peer_rpc_mgr: Arc<PeerRpcManager>,
peer_rpc_tspt: Arc<RpcTransport>,
peer_packet_process_pipeline: Arc<RwLock<Vec<BoxPeerPacketFilter>>>,
nic_packet_process_pipeline: Arc<RwLock<Vec<BoxNicPacketFilter>>>,
}
impl Debug for PeerManager {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("PeerManager")
.field("my_node_id", &self.my_node_id)
.field("instance_name", &self.global_ctx.inst_name)
.field("net_ns", &self.global_ctx.net_ns.name())
.field("cur_route_id", &self.cur_route_id)
.finish()
}
}
impl PeerManager {
pub fn new(global_ctx: ArcGlobalCtx, nic_channel: mpsc::Sender<SinkItem>) -> Self {
let (packet_send, packet_recv) = mpsc::channel(100);
let peers = Arc::new(PeerMap::new(packet_send.clone()));
// TODO: remove these because we have impl pipeline processor.
let (peer_rpc_tspt_sender, peer_rpc_tspt_recv) = mpsc::unbounded_channel();
let rpc_tspt = Arc::new(RpcTransport {
my_peer_id: global_ctx.get_id(),
peers: peers.clone(),
packet_recv: Mutex::new(peer_rpc_tspt_recv),
peer_rpc_tspt_sender,
route: Arc::new(Mutex::new(None)),
});
PeerManager {
my_node_id: global_ctx.get_id(),
global_ctx,
nic_channel,
tasks: Arc::new(Mutex::new(JoinSet::new())),
packet_recv: Arc::new(Mutex::new(Some(packet_recv))),
peers: peers.clone(),
route: Arc::new(Mutex::new(None)),
cur_route_id: AtomicU8::new(0),
peer_rpc_mgr: Arc::new(PeerRpcManager::new(rpc_tspt.clone())),
peer_rpc_tspt: rpc_tspt,
peer_packet_process_pipeline: Arc::new(RwLock::new(Vec::new())),
nic_packet_process_pipeline: Arc::new(RwLock::new(Vec::new())),
}
}
pub async fn add_client_tunnel(&self, tunnel: Box<dyn Tunnel>) -> Result<(Uuid, Uuid), Error> {
let mut peer = PeerConn::new(self.my_node_id, self.global_ctx.clone(), tunnel);
peer.do_handshake_as_client().await?;
let conn_id = peer.get_conn_id();
let peer_id = peer.get_peer_id();
self.peers
.add_new_peer_conn(peer, self.global_ctx.clone())
.await;
Ok((peer_id, conn_id))
}
#[tracing::instrument]
pub async fn try_connect<C>(&self, mut connector: C) -> Result<(Uuid, Uuid), Error>
where
C: TunnelConnector + Debug,
{
let ns = self.global_ctx.net_ns.clone();
let t = ns
.run_async(|| async move { connector.connect().await })
.await?;
self.add_client_tunnel(t).await
}
#[tracing::instrument]
pub async fn add_tunnel_as_server(&self, tunnel: Box<dyn Tunnel>) -> Result<(), Error> {
tracing::info!("add tunnel as server start");
let mut peer = PeerConn::new(self.my_node_id, self.global_ctx.clone(), tunnel);
peer.do_handshake_as_server().await?;
self.peers
.add_new_peer_conn(peer, self.global_ctx.clone())
.await;
tracing::info!("add tunnel as server done");
Ok(())
}
async fn start_peer_recv(&self) {
let mut recv = ReceiverStream::new(self.packet_recv.lock().await.take().unwrap());
let my_node_id = self.my_node_id;
let peers = self.peers.clone();
let arc_route = self.route.clone();
let pipe_line = self.peer_packet_process_pipeline.clone();
self.tasks.lock().await.spawn(async move {
log::trace!("start_peer_recv");
while let Some(ret) = recv.next().await {
log::trace!("peer recv a packet...: {:?}", ret);
let packet = packet::Packet::decode(&ret);
let from_peer_uuid = packet.from_peer.to_uuid();
let to_peer_uuid = packet.to_peer.as_ref().unwrap().to_uuid();
if to_peer_uuid != my_node_id {
let locked_arc_route = arc_route.lock().await;
if locked_arc_route.is_none() {
log::error!("no route info after recv a packet");
continue;
}
let route = locked_arc_route.as_ref().unwrap().clone();
drop(locked_arc_route);
log::trace!(
"need forward: to_peer_uuid: {:?}, my_uuid: {:?}",
to_peer_uuid,
my_node_id
);
let ret = peers
.send_msg(ret.clone(), &to_peer_uuid, route.clone())
.await;
if ret.is_err() {
log::error!(
"forward packet error: {:?}, dst: {:?}, from: {:?}",
ret,
to_peer_uuid,
from_peer_uuid
);
}
} else {
let mut processed = false;
for pipeline in pipe_line.read().await.iter().rev() {
if let Some(_) = pipeline.try_process_packet_from_peer(&packet, &ret).await
{
processed = true;
break;
}
}
if !processed {
tracing::error!("unexpected packet: {:?}", ret);
}
}
}
panic!("done_peer_recv");
});
}
pub async fn add_packet_process_pipeline(&self, pipeline: BoxPeerPacketFilter) {
// newest pipeline will be executed first
self.peer_packet_process_pipeline
.write()
.await
.push(pipeline);
}
pub async fn add_nic_packet_process_pipeline(&self, pipeline: BoxNicPacketFilter) {
// newest pipeline will be executed first
self.nic_packet_process_pipeline
.write()
.await
.push(pipeline);
}
async fn init_packet_process_pipeline(&self) {
use packet::ArchivedPacketBody;
// for tun/tap ip/eth packet.
struct NicPacketProcessor {
nic_channel: mpsc::Sender<SinkItem>,
}
#[async_trait::async_trait]
impl PeerPacketFilter for NicPacketProcessor {
async fn try_process_packet_from_peer(
&self,
packet: &packet::ArchivedPacket,
data: &Bytes,
) -> Option<()> {
if let packet::ArchivedPacketBody::Data(x) = &packet.body {
// TODO: use a function to get the body ref directly for zero copy
self.nic_channel
.send(extract_bytes_from_archived_vec(&data, &x.data))
.await
.unwrap();
Some(())
} else {
None
}
}
}
self.add_packet_process_pipeline(Box::new(NicPacketProcessor {
nic_channel: self.nic_channel.clone(),
}))
.await;
// for peer manager router packet
struct RoutePacketProcessor {
route: Arc<Mutex<Option<ArcRoute>>>,
}
#[async_trait::async_trait]
impl PeerPacketFilter for RoutePacketProcessor {
async fn try_process_packet_from_peer(
&self,
packet: &packet::ArchivedPacket,
data: &Bytes,
) -> Option<()> {
if let ArchivedPacketBody::Ctrl(packet::ArchivedCtrlPacketBody::RoutePacket(
route_packet,
)) = &packet.body
{
let r = self.route.lock().await;
match r.as_ref() {
Some(x) => {
let x = x.clone();
drop(r);
x.handle_route_packet(
packet.from_peer.to_uuid(),
extract_bytes_from_archived_vec(&data, &route_packet.body),
)
.await;
}
None => {
log::error!("no route info when handle route packet");
}
}
Some(())
} else {
None
}
}
}
self.add_packet_process_pipeline(Box::new(RoutePacketProcessor {
route: self.route.clone(),
}))
.await;
// for peer rpc packet
struct PeerRpcPacketProcessor {
peer_rpc_tspt_sender: UnboundedSender<Bytes>,
}
#[async_trait::async_trait]
impl PeerPacketFilter for PeerRpcPacketProcessor {
async fn try_process_packet_from_peer(
&self,
packet: &packet::ArchivedPacket,
data: &Bytes,
) -> Option<()> {
if let ArchivedPacketBody::Ctrl(packet::ArchivedCtrlPacketBody::TaRpc(..)) =
&packet.body
{
self.peer_rpc_tspt_sender.send(data.clone()).unwrap();
Some(())
} else {
None
}
}
}
self.add_packet_process_pipeline(Box::new(PeerRpcPacketProcessor {
peer_rpc_tspt_sender: self.peer_rpc_tspt.peer_rpc_tspt_sender.clone(),
}))
.await;
}
pub async fn set_route<T>(&self, route: T)
where
T: Route + Send + Sync + 'static,
{
struct Interface {
my_node_id: uuid::Uuid,
peers: Arc<PeerMap>,
}
#[async_trait]
impl RouteInterface for Interface {
async fn list_peers(&self) -> Vec<PeerId> {
self.peers.list_peers_with_conn().await
}
async fn send_route_packet(
&self,
msg: Bytes,
route_id: u8,
dst_peer_id: &PeerId,
) -> Result<(), Error> {
self.peers
.send_msg_directly(
packet::Packet::new_route_packet(
self.my_node_id,
*dst_peer_id,
route_id,
&msg,
)
.into(),
dst_peer_id,
)
.await
}
}
let my_node_id = self.my_node_id;
let route_id = route
.open(Box::new(Interface {
my_node_id,
peers: self.peers.clone(),
}))
.await
.unwrap();
self.cur_route_id
.store(route_id, std::sync::atomic::Ordering::Relaxed);
let arc_route: ArcRoute = Arc::new(Box::new(route));
self.route.lock().await.replace(arc_route.clone());
self.peer_rpc_tspt
.route
.lock()
.await
.replace(arc_route.clone());
}
pub async fn list_routes(&self) -> Vec<easytier_rpc::Route> {
let route_info = self.route.lock().await;
if route_info.is_none() {
return Vec::new();
}
let route = route_info.as_ref().unwrap().clone();
drop(route_info);
route.list_routes().await
}
async fn run_nic_packet_process_pipeline(&self, mut data: BytesMut) -> BytesMut {
for pipeline in self.nic_packet_process_pipeline.read().await.iter().rev() {
data = pipeline.try_process_packet_from_nic(data).await;
}
data
}
pub async fn send_msg(&self, msg: Bytes, dst_peer_id: &PeerId) -> Result<(), Error> {
self.peer_rpc_tspt.send(msg, dst_peer_id).await
}
pub async fn send_msg_ipv4(&self, msg: BytesMut, ipv4_addr: Ipv4Addr) -> Result<(), Error> {
let route_info = self.route.lock().await;
if route_info.is_none() {
log::error!("no route info");
return Err(Error::RouteError("No route info".to_string()));
}
let route = route_info.as_ref().unwrap().clone();
drop(route_info);
log::trace!(
"do send_msg in peer manager, msg: {:?}, ipv4_addr: {}",
msg,
ipv4_addr
);
match route.get_peer_id_by_ipv4(&ipv4_addr).await {
Some(peer_id) => {
let msg = self.run_nic_packet_process_pipeline(msg).await;
self.peers
.send_msg(
packet::Packet::new_data_packet(self.my_node_id, peer_id, &msg).into(),
&peer_id,
route.clone(),
)
.await?;
log::trace!(
"do send_msg in peer manager done, dst_peer_id: {:?}",
peer_id
);
}
None => {
log::trace!("no peer id for ipv4: {}", ipv4_addr);
return Ok(());
}
}
Ok(())
}
async fn run_clean_peer_without_conn_routine(&self) {
let peer_map = self.peers.clone();
self.tasks.lock().await.spawn(async move {
loop {
let mut to_remove = vec![];
for peer_id in peer_map.list_peers().await {
let conns = peer_map.list_peer_conns(&peer_id).await;
if conns.is_none() || conns.as_ref().unwrap().is_empty() {
to_remove.push(peer_id);
}
}
for peer_id in to_remove {
peer_map.close_peer(&peer_id).await.unwrap();
}
tokio::time::sleep(std::time::Duration::from_secs(10)).await;
}
});
}
pub async fn run(&self) -> Result<(), Error> {
self.init_packet_process_pipeline().await;
self.start_peer_recv().await;
self.peer_rpc_mgr.run();
self.run_clean_peer_without_conn_routine().await;
Ok(())
}
pub fn get_peer_map(&self) -> Arc<PeerMap> {
self.peers.clone()
}
pub fn get_peer_rpc_mgr(&self) -> Arc<PeerRpcManager> {
self.peer_rpc_mgr.clone()
}
pub fn my_node_id(&self) -> uuid::Uuid {
self.my_node_id
}
pub fn get_global_ctx(&self) -> ArcGlobalCtx {
self.global_ctx.clone()
}
pub fn get_nic_channel(&self) -> mpsc::Sender<SinkItem> {
self.nic_channel.clone()
}
}
+140
View File
@@ -0,0 +1,140 @@
use std::sync::Arc;
use dashmap::DashMap;
use easytier_rpc::PeerConnInfo;
use tokio::sync::mpsc;
use tokio_util::bytes::Bytes;
use crate::{
common::{error::Error, global_ctx::ArcGlobalCtx},
tunnels::TunnelError,
};
use super::{peer::Peer, peer_conn::PeerConn, route_trait::ArcRoute, PeerId};
pub struct PeerMap {
peer_map: DashMap<PeerId, Arc<Peer>>,
packet_send: mpsc::Sender<Bytes>,
}
impl PeerMap {
pub fn new(packet_send: mpsc::Sender<Bytes>) -> Self {
PeerMap {
peer_map: DashMap::new(),
packet_send,
}
}
async fn add_new_peer(&self, peer: Peer) {
self.peer_map.insert(peer.peer_node_id, Arc::new(peer));
}
pub async fn add_new_peer_conn(&self, peer_conn: PeerConn, global_ctx: ArcGlobalCtx) {
let peer_id = peer_conn.get_peer_id();
let no_entry = self.peer_map.get(&peer_id).is_none();
if no_entry {
let new_peer = Peer::new(peer_id, self.packet_send.clone(), global_ctx);
new_peer.add_peer_conn(peer_conn).await;
self.add_new_peer(new_peer).await;
} else {
let peer = self.peer_map.get(&peer_id).unwrap().clone();
peer.add_peer_conn(peer_conn).await;
}
}
fn get_peer_by_id(&self, peer_id: &PeerId) -> Option<Arc<Peer>> {
self.peer_map.get(peer_id).map(|v| v.clone())
}
pub async fn send_msg_directly(
&self,
msg: Bytes,
dst_peer_id: &uuid::Uuid,
) -> Result<(), Error> {
match self.get_peer_by_id(dst_peer_id) {
Some(peer) => {
peer.send_msg(msg).await?;
}
None => {
log::error!("no peer for dst_peer_id: {}", dst_peer_id);
return Ok(());
}
}
Ok(())
}
pub async fn send_msg(
&self,
msg: Bytes,
dst_peer_id: &uuid::Uuid,
route: ArcRoute,
) -> Result<(), Error> {
// get route info
let gateway_peer_id = route.get_next_hop(dst_peer_id).await;
if gateway_peer_id.is_none() {
log::error!("no gateway for dst_peer_id: {}", dst_peer_id);
return Ok(());
}
let gateway_peer_id = gateway_peer_id.unwrap();
self.send_msg_directly(msg, &gateway_peer_id).await?;
Ok(())
}
pub async fn list_peers(&self) -> Vec<PeerId> {
let mut ret = Vec::new();
for item in self.peer_map.iter() {
let peer_id = item.key();
ret.push(*peer_id);
}
ret
}
pub async fn list_peers_with_conn(&self) -> Vec<PeerId> {
let mut ret = Vec::new();
let peers = self.list_peers().await;
for peer_id in peers.iter() {
let Some(peer) = self.get_peer_by_id(peer_id) else {
continue;
};
if peer.list_peer_conns().await.len() > 0 {
ret.push(*peer_id);
}
}
ret
}
pub async fn list_peer_conns(&self, peer_id: &PeerId) -> Option<Vec<PeerConnInfo>> {
if let Some(p) = self.get_peer_by_id(peer_id) {
Some(p.list_peer_conns().await)
} else {
return None;
}
}
pub async fn close_peer_conn(
&self,
peer_id: &PeerId,
conn_id: &uuid::Uuid,
) -> Result<(), Error> {
if let Some(p) = self.get_peer_by_id(peer_id) {
p.close_peer_conn(conn_id).await
} else {
return Err(Error::NotFound);
}
}
pub async fn close_peer(&self, peer_id: &PeerId) -> Result<(), TunnelError> {
let remove_ret = self.peer_map.remove(peer_id);
tracing::info!(
?peer_id,
has_old_value = ?remove_ret.is_some(),
peer_ref_counter = ?remove_ret.map(|v| Arc::strong_count(&v.1)),
"peer is closed"
);
Ok(())
}
}
+509
View File
@@ -0,0 +1,509 @@
use std::sync::Arc;
use dashmap::DashMap;
use futures::{SinkExt, StreamExt};
use rkyv::Deserialize;
use tarpc::{server::Channel, transport::channel::UnboundedChannel};
use tokio::{
sync::mpsc::{self, UnboundedSender},
task::JoinSet,
};
use tokio_util::bytes::Bytes;
use tracing::Instrument;
use crate::{common::error::Error, peers::packet::Packet};
use super::packet::{CtrlPacketBody, PacketBody};
type PeerRpcServiceId = u32;
#[async_trait::async_trait]
#[auto_impl::auto_impl(Arc)]
pub trait PeerRpcManagerTransport: Send + Sync + 'static {
fn my_peer_id(&self) -> uuid::Uuid;
async fn send(&self, msg: Bytes, dst_peer_id: &uuid::Uuid) -> Result<(), Error>;
async fn recv(&self) -> Result<Bytes, Error>;
}
type PacketSender = UnboundedSender<Packet>;
struct PeerRpcEndPoint {
peer_id: uuid::Uuid,
packet_sender: PacketSender,
tasks: JoinSet<()>,
}
type PeerRpcEndPointCreator = Box<dyn Fn(uuid::Uuid) -> PeerRpcEndPoint + Send + Sync + 'static>;
#[derive(Hash, Eq, PartialEq, Clone)]
struct PeerRpcClientCtxKey(uuid::Uuid, PeerRpcServiceId);
// handle rpc request from one peer
pub struct PeerRpcManager {
service_map: Arc<DashMap<PeerRpcServiceId, PacketSender>>,
tasks: JoinSet<()>,
tspt: Arc<Box<dyn PeerRpcManagerTransport>>,
service_registry: Arc<DashMap<PeerRpcServiceId, PeerRpcEndPointCreator>>,
peer_rpc_endpoints: Arc<DashMap<(uuid::Uuid, PeerRpcServiceId), PeerRpcEndPoint>>,
client_resp_receivers: Arc<DashMap<PeerRpcClientCtxKey, PacketSender>>,
}
impl std::fmt::Debug for PeerRpcManager {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("PeerRpcManager")
.field("node_id", &self.tspt.my_peer_id())
.finish()
}
}
#[derive(Debug)]
struct TaRpcPacketInfo {
from_peer: uuid::Uuid,
to_peer: uuid::Uuid,
service_id: PeerRpcServiceId,
is_req: bool,
content: Vec<u8>,
}
impl PeerRpcManager {
pub fn new(tspt: impl PeerRpcManagerTransport) -> Self {
Self {
service_map: Arc::new(DashMap::new()),
tasks: JoinSet::new(),
tspt: Arc::new(Box::new(tspt)),
service_registry: Arc::new(DashMap::new()),
peer_rpc_endpoints: Arc::new(DashMap::new()),
client_resp_receivers: Arc::new(DashMap::new()),
}
}
pub fn run_service<S, Req>(self: &Self, service_id: PeerRpcServiceId, s: S) -> ()
where
S: tarpc::server::Serve<Req> + Clone + Send + Sync + 'static,
Req: Send + 'static + serde::Serialize + for<'a> serde::Deserialize<'a>,
S::Resp:
Send + std::fmt::Debug + 'static + serde::Serialize + for<'a> serde::Deserialize<'a>,
S::Fut: Send + 'static,
{
let tspt = self.tspt.clone();
let creator = Box::new(move |peer_id: uuid::Uuid| {
let mut tasks = JoinSet::new();
let (packet_sender, mut packet_receiver) = mpsc::unbounded_channel::<Packet>();
let (mut client_transport, server_transport) = tarpc::transport::channel::unbounded();
let server = tarpc::server::BaseChannel::with_defaults(server_transport);
let my_peer_id_clone = tspt.my_peer_id();
let peer_id_clone = peer_id.clone();
let o = server.execute(s.clone());
tasks.spawn(o);
let tspt = tspt.clone();
tasks.spawn(async move {
let mut cur_req_uuid = None;
loop {
tokio::select! {
Some(resp) = client_transport.next() => {
tracing::trace!(resp = ?resp, "recv packet from client");
if resp.is_err() {
tracing::warn!(err = ?resp.err(),
"[PEER RPC MGR] client_transport in server side got channel error, ignore it.");
continue;
}
let resp = resp.unwrap();
if cur_req_uuid.is_none() {
tracing::error!("[PEER RPC MGR] cur_req_uuid is none, ignore this resp");
continue;
}
let serialized_resp = bincode::serialize(&resp);
if serialized_resp.is_err() {
tracing::error!(error = ?serialized_resp.err(), "serialize resp failed");
continue;
}
let msg = Packet::new_tarpc_packet(
tspt.my_peer_id(),
cur_req_uuid.take().unwrap(),
service_id,
false,
serialized_resp.unwrap(),
);
if let Err(e) = tspt.send(msg.into(), &peer_id).await {
tracing::error!(error = ?e, peer_id = ?peer_id, service_id = ?service_id, "send resp to peer failed");
}
}
Some(packet) = packet_receiver.recv() => {
let info = Self::parse_rpc_packet(&packet);
if let Err(e) = info {
tracing::error!(error = ?e, packet = ?packet, "parse rpc packet failed");
continue;
}
let info = info.unwrap();
assert_eq!(info.service_id, service_id);
cur_req_uuid = Some(packet.from_peer.clone().into());
tracing::trace!("recv packet from peer, packet: {:?}", packet);
let decoded_ret = bincode::deserialize(&info.content.as_slice());
if let Err(e) = decoded_ret {
tracing::error!(error = ?e, "decode rpc packet failed");
continue;
}
let decoded: tarpc::ClientMessage<Req> = decoded_ret.unwrap();
if let Err(e) = client_transport.send(decoded).await {
tracing::error!(error = ?e, "send to req to client transport failed");
}
}
else => {
tracing::warn!("[PEER RPC MGR] service runner destroy, peer_id: {}, service_id: {}", peer_id, service_id);
}
}
}
}.instrument(tracing::info_span!("service_runner", my_id = ?my_peer_id_clone, peer_id = ?peer_id_clone, service_id = ?service_id)));
tracing::info!(
"[PEER RPC MGR] create new service endpoint for peer {}, service {}",
peer_id,
service_id
);
return PeerRpcEndPoint {
peer_id,
packet_sender,
tasks,
};
// let resp = client_transport.next().await;
});
if let Some(_) = self.service_registry.insert(service_id, creator) {
panic!(
"[PEER RPC MGR] service {} is already registered",
service_id
);
}
log::info!(
"[PEER RPC MGR] register service {} succeed, my_node_id {}",
service_id,
self.tspt.my_peer_id()
)
}
fn parse_rpc_packet(packet: &Packet) -> Result<TaRpcPacketInfo, Error> {
match &packet.body {
PacketBody::Ctrl(CtrlPacketBody::TaRpc(id, is_req, body)) => Ok(TaRpcPacketInfo {
from_peer: packet.from_peer.clone().into(),
to_peer: packet.to_peer.clone().unwrap().into(),
service_id: *id,
is_req: *is_req,
content: body.clone(),
}),
_ => Err(Error::ShellCommandError("invalid packet".to_owned())),
}
}
pub fn run(&self) {
let tspt = self.tspt.clone();
let service_registry = self.service_registry.clone();
let peer_rpc_endpoints = self.peer_rpc_endpoints.clone();
let client_resp_receivers = self.client_resp_receivers.clone();
tokio::spawn(async move {
loop {
let o = tspt.recv().await.unwrap();
let packet = Packet::decode(&o);
let packet: Packet = packet.deserialize(&mut rkyv::Infallible).unwrap();
let info = Self::parse_rpc_packet(&packet).unwrap();
if info.is_req {
if !service_registry.contains_key(&info.service_id) {
log::warn!(
"service {} not found, my_node_id: {}",
info.service_id,
tspt.my_peer_id()
);
continue;
}
let endpoint = peer_rpc_endpoints
.entry((info.to_peer, info.service_id))
.or_insert_with(|| {
service_registry.get(&info.service_id).unwrap()(info.from_peer)
});
endpoint.packet_sender.send(packet).unwrap();
} else {
if let Some(a) = client_resp_receivers
.get(&PeerRpcClientCtxKey(info.from_peer, info.service_id))
{
log::trace!("recv resp: {:?}", packet);
if let Err(e) = a.send(packet) {
tracing::error!(error = ?e, "send resp to client failed");
}
} else {
log::warn!("client resp receiver not found, info: {:?}", info);
}
}
}
});
}
#[tracing::instrument(skip(f))]
pub async fn do_client_rpc_scoped<CM, Req, RpcRet, Fut>(
&self,
service_id: PeerRpcServiceId,
dst_peer_id: uuid::Uuid,
f: impl FnOnce(UnboundedChannel<CM, Req>) -> Fut,
) -> RpcRet
where
CM: serde::Serialize + for<'a> serde::Deserialize<'a> + Send + Sync + 'static,
Req: serde::Serialize + for<'a> serde::Deserialize<'a> + Send + Sync + 'static,
Fut: std::future::Future<Output = RpcRet>,
{
let mut tasks = JoinSet::new();
let (packet_sender, mut packet_receiver) = mpsc::unbounded_channel::<Packet>();
let (client_transport, server_transport) =
tarpc::transport::channel::unbounded::<CM, Req>();
let (mut server_s, mut server_r) = server_transport.split();
let tspt = self.tspt.clone();
tasks.spawn(async move {
while let Some(a) = server_r.next().await {
if a.is_err() {
tracing::error!(error = ?a.err(), "channel error");
continue;
}
let a = bincode::serialize(&a.unwrap());
if a.is_err() {
tracing::error!(error = ?a.err(), "bincode serialize failed");
continue;
}
let a = Packet::new_tarpc_packet(
tspt.my_peer_id(),
dst_peer_id,
service_id,
true,
a.unwrap(),
);
if let Err(e) = tspt.send(a.into(), &dst_peer_id).await {
tracing::error!(error = ?e, dst_peer_id = ?dst_peer_id, "send to peer failed");
}
}
tracing::warn!("[PEER RPC MGR] server trasport read aborted");
});
tasks.spawn(async move {
while let Some(packet) = packet_receiver.recv().await {
tracing::trace!("tunnel recv: {:?}", packet);
let info = PeerRpcManager::parse_rpc_packet(&packet);
if let Err(e) = info {
tracing::error!(error = ?e, "parse rpc packet failed");
continue;
}
let decoded = bincode::deserialize(&info.unwrap().content.as_slice());
if let Err(e) = decoded {
tracing::error!(error = ?e, "decode rpc packet failed");
continue;
}
if let Err(e) = server_s.send(decoded.unwrap()).await {
tracing::error!(error = ?e, "send to rpc server channel failed");
}
}
tracing::warn!("[PEER RPC MGR] server packet read aborted");
});
let _insert_ret = self
.client_resp_receivers
.insert(PeerRpcClientCtxKey(dst_peer_id, service_id), packet_sender);
f(client_transport).await
}
pub fn my_peer_id(&self) -> uuid::Uuid {
self.tspt.my_peer_id()
}
}
#[cfg(test)]
mod tests {
use futures::{SinkExt, StreamExt};
use tokio_util::bytes::Bytes;
use crate::{
common::error::Error,
peers::{
peer_rpc::PeerRpcManager,
tests::{connect_peer_manager, create_mock_peer_manager, wait_route_appear},
},
tunnels::{self, ring_tunnel::create_ring_tunnel_pair},
};
use super::PeerRpcManagerTransport;
#[tarpc::service]
pub trait TestRpcService {
async fn hello(s: String) -> String;
}
#[derive(Clone)]
struct MockService {
prefix: String,
}
#[tarpc::server]
impl TestRpcService for MockService {
async fn hello(self, _: tarpc::context::Context, s: String) -> String {
format!("{} {}", self.prefix, s)
}
}
#[tokio::test]
async fn peer_rpc_basic_test() {
struct MockTransport {
tunnel: Box<dyn tunnels::Tunnel>,
my_peer_id: uuid::Uuid,
}
#[async_trait::async_trait]
impl PeerRpcManagerTransport for MockTransport {
fn my_peer_id(&self) -> uuid::Uuid {
self.my_peer_id
}
async fn send(&self, msg: Bytes, _dst_peer_id: &uuid::Uuid) -> Result<(), Error> {
println!("rpc mgr send: {:?}", msg);
self.tunnel.pin_sink().send(msg).await.unwrap();
Ok(())
}
async fn recv(&self) -> Result<Bytes, Error> {
let ret = self.tunnel.pin_stream().next().await.unwrap();
println!("rpc mgr recv: {:?}", ret);
return ret.map(|v| v.freeze()).map_err(|_| Error::Unknown);
}
}
let (ct, st) = create_ring_tunnel_pair();
let server_rpc_mgr = PeerRpcManager::new(MockTransport {
tunnel: st,
my_peer_id: uuid::Uuid::new_v4(),
});
server_rpc_mgr.run();
let s = MockService {
prefix: "hello".to_owned(),
};
server_rpc_mgr.run_service(1, s.serve());
let client_rpc_mgr = PeerRpcManager::new(MockTransport {
tunnel: ct,
my_peer_id: uuid::Uuid::new_v4(),
});
client_rpc_mgr.run();
let ret = client_rpc_mgr
.do_client_rpc_scoped(1, server_rpc_mgr.my_peer_id(), |c| async {
let c = TestRpcServiceClient::new(tarpc::client::Config::default(), c).spawn();
let ret = c.hello(tarpc::context::current(), "abc".to_owned()).await;
ret
})
.await;
println!("ret: {:?}", ret);
assert_eq!(ret.unwrap(), "hello abc");
}
#[tokio::test]
async fn test_rpc_with_peer_manager() {
let peer_mgr_a = create_mock_peer_manager().await;
let peer_mgr_b = create_mock_peer_manager().await;
connect_peer_manager(peer_mgr_a.clone(), peer_mgr_b.clone()).await;
wait_route_appear(peer_mgr_a.clone(), peer_mgr_b.my_node_id())
.await
.unwrap();
assert_eq!(peer_mgr_a.get_peer_map().list_peers().await.len(), 1);
assert_eq!(
peer_mgr_a.get_peer_map().list_peers().await[0],
peer_mgr_b.my_node_id()
);
let s = MockService {
prefix: "hello".to_owned(),
};
peer_mgr_b.get_peer_rpc_mgr().run_service(1, s.serve());
let ip_list = peer_mgr_a
.get_peer_rpc_mgr()
.do_client_rpc_scoped(1, peer_mgr_b.my_node_id(), |c| async {
let c = TestRpcServiceClient::new(tarpc::client::Config::default(), c).spawn();
let ret = c.hello(tarpc::context::current(), "abc".to_owned()).await;
ret
})
.await;
println!("ip_list: {:?}", ip_list);
assert_eq!(ip_list.as_ref().unwrap(), "hello abc");
}
#[tokio::test]
async fn test_multi_service_with_peer_manager() {
let peer_mgr_a = create_mock_peer_manager().await;
let peer_mgr_b = create_mock_peer_manager().await;
connect_peer_manager(peer_mgr_a.clone(), peer_mgr_b.clone()).await;
wait_route_appear(peer_mgr_a.clone(), peer_mgr_b.my_node_id())
.await
.unwrap();
assert_eq!(peer_mgr_a.get_peer_map().list_peers().await.len(), 1);
assert_eq!(
peer_mgr_a.get_peer_map().list_peers().await[0],
peer_mgr_b.my_node_id()
);
let s = MockService {
prefix: "hello_a".to_owned(),
};
peer_mgr_b.get_peer_rpc_mgr().run_service(1, s.serve());
let b = MockService {
prefix: "hello_b".to_owned(),
};
peer_mgr_b.get_peer_rpc_mgr().run_service(2, b.serve());
let ip_list = peer_mgr_a
.get_peer_rpc_mgr()
.do_client_rpc_scoped(1, peer_mgr_b.my_node_id(), |c| async {
let c = TestRpcServiceClient::new(tarpc::client::Config::default(), c).spawn();
let ret = c.hello(tarpc::context::current(), "abc".to_owned()).await;
ret
})
.await;
assert_eq!(ip_list.as_ref().unwrap(), "hello_a abc");
let ip_list = peer_mgr_a
.get_peer_rpc_mgr()
.do_client_rpc_scoped(2, peer_mgr_b.my_node_id(), |c| async {
let c = TestRpcServiceClient::new(tarpc::client::Config::default(), c).spawn();
let ret = c.hello(tarpc::context::current(), "abc".to_owned()).await;
ret
})
.await;
assert_eq!(ip_list.as_ref().unwrap(), "hello_b abc");
}
}
+480
View File
@@ -0,0 +1,480 @@
use std::{net::Ipv4Addr, sync::Arc, time::Duration};
use async_trait::async_trait;
use dashmap::DashMap;
use easytier_rpc::{NatType, StunInfo};
use rkyv::{Archive, Deserialize, Serialize};
use tokio::{sync::Mutex, task::JoinSet};
use tokio_util::bytes::Bytes;
use tracing::Instrument;
use uuid::Uuid;
use crate::{
common::{
error::Error,
global_ctx::ArcGlobalCtx,
rkyv_util::{decode_from_bytes, encode_to_bytes},
stun::StunInfoCollectorTrait,
},
peers::{
packet::{self, UUID},
route_trait::{Route, RouteInterfaceBox},
PeerId,
},
};
#[derive(Archive, Deserialize, Serialize, Clone, Debug)]
#[archive(compare(PartialEq), check_bytes)]
// Derives can be passed through to the generated type:
#[archive_attr(derive(Debug))]
pub struct SyncPeerInfo {
// means next hop in route table.
pub peer_id: UUID,
pub cost: u32,
pub ipv4_addr: Option<Ipv4Addr>,
pub proxy_cidrs: Vec<String>,
pub hostname: Option<String>,
pub udp_stun_info: i8,
}
impl SyncPeerInfo {
pub fn new_self(from_peer: UUID, global_ctx: &ArcGlobalCtx) -> Self {
SyncPeerInfo {
peer_id: from_peer,
cost: 0,
ipv4_addr: global_ctx.get_ipv4(),
proxy_cidrs: global_ctx
.get_proxy_cidrs()
.iter()
.map(|x| x.to_string())
.collect(),
hostname: global_ctx.get_hostname(),
udp_stun_info: global_ctx
.get_stun_info_collector()
.get_stun_info()
.udp_nat_type as i8,
}
}
pub fn clone_for_route_table(&self, next_hop: &UUID, cost: u32, from: &Self) -> Self {
SyncPeerInfo {
peer_id: next_hop.clone(),
cost,
ipv4_addr: from.ipv4_addr.clone(),
proxy_cidrs: from.proxy_cidrs.clone(),
hostname: from.hostname.clone(),
udp_stun_info: from.udp_stun_info,
}
}
}
#[derive(Archive, Deserialize, Serialize, Clone, Debug)]
#[archive(compare(PartialEq), check_bytes)]
// Derives can be passed through to the generated type:
#[archive_attr(derive(Debug))]
pub struct SyncPeer {
pub myself: SyncPeerInfo,
pub neighbors: Vec<SyncPeerInfo>,
}
impl SyncPeer {
pub fn new(
from_peer: UUID,
_to_peer: UUID,
neighbors: Vec<SyncPeerInfo>,
global_ctx: ArcGlobalCtx,
) -> Self {
SyncPeer {
myself: SyncPeerInfo::new_self(from_peer, &global_ctx),
neighbors,
}
}
}
struct SyncPeerFromRemote {
packet: SyncPeer,
last_update: std::time::Instant,
}
type SyncPeerFromRemoteMap = Arc<DashMap<uuid::Uuid, SyncPeerFromRemote>>;
#[derive(Clone, Debug)]
struct RouteTable {
route_info: DashMap<uuid::Uuid, SyncPeerInfo>,
ipv4_peer_id_map: DashMap<Ipv4Addr, uuid::Uuid>,
cidr_peer_id_map: DashMap<cidr::IpCidr, uuid::Uuid>,
}
impl RouteTable {
fn new() -> Self {
RouteTable {
route_info: DashMap::new(),
ipv4_peer_id_map: DashMap::new(),
cidr_peer_id_map: DashMap::new(),
}
}
fn copy_from(&self, other: &Self) {
self.route_info.clear();
for item in other.route_info.iter() {
let (k, v) = item.pair();
self.route_info.insert(*k, v.clone());
}
self.ipv4_peer_id_map.clear();
for item in other.ipv4_peer_id_map.iter() {
let (k, v) = item.pair();
self.ipv4_peer_id_map.insert(*k, *v);
}
self.cidr_peer_id_map.clear();
for item in other.cidr_peer_id_map.iter() {
let (k, v) = item.pair();
self.cidr_peer_id_map.insert(*k, *v);
}
}
}
pub struct BasicRoute {
my_peer_id: packet::UUID,
global_ctx: ArcGlobalCtx,
interface: Arc<Mutex<Option<RouteInterfaceBox>>>,
route_table: Arc<RouteTable>,
sync_peer_from_remote: SyncPeerFromRemoteMap,
tasks: Mutex<JoinSet<()>>,
need_sync_notifier: Arc<tokio::sync::Notify>,
}
impl BasicRoute {
pub fn new(my_peer_id: Uuid, global_ctx: ArcGlobalCtx) -> Self {
BasicRoute {
my_peer_id: my_peer_id.into(),
global_ctx,
interface: Arc::new(Mutex::new(None)),
route_table: Arc::new(RouteTable::new()),
sync_peer_from_remote: Arc::new(DashMap::new()),
tasks: Mutex::new(JoinSet::new()),
need_sync_notifier: Arc::new(tokio::sync::Notify::new()),
}
}
fn update_route_table(
my_id: packet::UUID,
sync_peer_reqs: SyncPeerFromRemoteMap,
route_table: Arc<RouteTable>,
) {
tracing::trace!(my_id = ?my_id, route_table = ?route_table, "update route table");
let new_route_table = Arc::new(RouteTable::new());
for item in sync_peer_reqs.iter() {
Self::update_route_table_with_req(
my_id.clone(),
&item.value().packet,
new_route_table.clone(),
);
}
route_table.copy_from(&new_route_table);
}
fn update_route_table_with_req(
my_id: packet::UUID,
packet: &SyncPeer,
route_table: Arc<RouteTable>,
) {
let peer_id = packet.myself.peer_id.clone();
let update = |cost: u32, peer_info: &SyncPeerInfo| {
let node_id: uuid::Uuid = peer_info.peer_id.clone().into();
let ret = route_table
.route_info
.entry(node_id.clone().into())
.and_modify(|info| {
if info.cost > cost {
*info = info.clone_for_route_table(&peer_id, cost, &peer_info);
}
})
.or_insert(
peer_info
.clone()
.clone_for_route_table(&peer_id, cost, &peer_info),
)
.value()
.clone();
if ret.cost > 32 {
log::error!(
"cost too large: {}, may lost connection, remove it",
ret.cost
);
route_table.route_info.remove(&node_id);
}
log::trace!(
"update route info, to: {:?}, gateway: {:?}, cost: {}, peer: {:?}",
node_id,
peer_id,
cost,
&peer_info
);
if let Some(ipv4) = peer_info.ipv4_addr {
route_table
.ipv4_peer_id_map
.insert(ipv4.clone(), node_id.clone().into());
}
for cidr in peer_info.proxy_cidrs.iter() {
let cidr: cidr::IpCidr = cidr.parse().unwrap();
route_table
.cidr_peer_id_map
.insert(cidr, node_id.clone().into());
}
};
for neighbor in packet.neighbors.iter() {
if neighbor.peer_id == my_id {
continue;
}
update(neighbor.cost + 1, &neighbor);
log::trace!("route info: {:?}", neighbor);
}
// add the sender peer to route info
update(1, &packet.myself);
log::trace!("my_id: {:?}, current route table: {:?}", my_id, route_table);
}
async fn send_sync_peer_request(
interface: &RouteInterfaceBox,
my_peer_id: packet::UUID,
global_ctx: ArcGlobalCtx,
peer_id: PeerId,
route_table: Arc<RouteTable>,
) -> Result<(), Error> {
let mut route_info_copy: Vec<SyncPeerInfo> = Vec::new();
// copy the route info
for item in route_table.route_info.iter() {
let (k, v) = item.pair();
route_info_copy.push(v.clone().clone_for_route_table(&(*k).into(), v.cost, &v));
}
let msg = SyncPeer::new(my_peer_id, peer_id.into(), route_info_copy, global_ctx);
// TODO: this may exceed the MTU of the tunnel
interface
.send_route_packet(encode_to_bytes::<_, 4096>(&msg), 1, &peer_id)
.await
}
async fn sync_peer_periodically(&self) {
let route_table = self.route_table.clone();
let global_ctx = self.global_ctx.clone();
let my_peer_id = self.my_peer_id.clone();
let interface = self.interface.clone();
let notifier = self.need_sync_notifier.clone();
self.tasks.lock().await.spawn(
async move {
loop {
let lockd_interface = interface.lock().await;
let interface = lockd_interface.as_ref().unwrap();
let peers = interface.list_peers().await;
for peer in peers.iter() {
let ret = Self::send_sync_peer_request(
interface,
my_peer_id.clone(),
global_ctx.clone(),
*peer,
route_table.clone(),
)
.await;
match &ret {
Ok(_) => {
log::trace!("send sync peer request to peer: {}", peer);
}
Err(Error::PeerNoConnectionError(_)) => {
log::trace!("peer {} no connection", peer);
}
Err(e) => {
log::error!(
"send sync peer request to peer: {} error: {:?}",
peer,
e
);
}
};
}
tokio::select! {
_ = notifier.notified() => {
log::trace!("sync peer request triggered by notifier");
}
_ = tokio::time::sleep(Duration::from_secs(1)) => {
log::trace!("sync peer request triggered by timeout");
}
}
}
}
.instrument(
tracing::info_span!("sync_peer_periodically", my_id = ?self.my_peer_id, global_ctx = ?self.global_ctx),
),
);
}
async fn check_expired_sync_peer_from_remote(&self) {
let route_table = self.route_table.clone();
let my_peer_id = self.my_peer_id.clone();
let sync_peer_from_remote = self.sync_peer_from_remote.clone();
let notifier = self.need_sync_notifier.clone();
self.tasks.lock().await.spawn(async move {
loop {
let mut need_update_route = false;
let now = std::time::Instant::now();
let mut need_remove = Vec::new();
for item in sync_peer_from_remote.iter() {
let (k, v) = item.pair();
if now.duration_since(v.last_update).as_secs() > 5 {
need_update_route = true;
need_remove.insert(0, k.clone());
}
}
for k in need_remove.iter() {
log::warn!("remove expired sync peer: {:?}", k);
sync_peer_from_remote.remove(k);
}
if need_update_route {
Self::update_route_table(
my_peer_id.clone(),
sync_peer_from_remote.clone(),
route_table.clone(),
);
notifier.notify_one();
}
tokio::time::sleep(Duration::from_secs(1)).await;
}
});
}
fn get_peer_id_for_proxy(&self, ipv4: &Ipv4Addr) -> Option<PeerId> {
let ipv4 = std::net::IpAddr::V4(*ipv4);
for item in self.route_table.cidr_peer_id_map.iter() {
let (k, v) = item.pair();
if k.contains(&ipv4) {
return Some(*v);
}
}
None
}
}
#[async_trait]
impl Route for BasicRoute {
async fn open(&self, interface: RouteInterfaceBox) -> Result<u8, ()> {
*self.interface.lock().await = Some(interface);
self.sync_peer_periodically().await;
self.check_expired_sync_peer_from_remote().await;
Ok(1)
}
async fn close(&self) {}
#[tracing::instrument(skip(self, packet), fields(my_id = ?self.my_peer_id, ctx = ?self.global_ctx))]
async fn handle_route_packet(&self, src_peer_id: uuid::Uuid, packet: Bytes) {
let packet = decode_from_bytes::<SyncPeer>(&packet).unwrap();
let p: SyncPeer = packet.deserialize(&mut rkyv::Infallible).unwrap();
let mut updated = true;
assert_eq!(packet.myself.peer_id.to_uuid(), src_peer_id);
self.sync_peer_from_remote
.entry(packet.myself.peer_id.to_uuid())
.and_modify(|v| {
if v.packet == *packet {
updated = false;
} else {
v.packet = p.clone();
}
v.last_update = std::time::Instant::now();
})
.or_insert(SyncPeerFromRemote {
packet: p.clone(),
last_update: std::time::Instant::now(),
});
if updated {
Self::update_route_table(
self.my_peer_id.clone(),
self.sync_peer_from_remote.clone(),
self.route_table.clone(),
);
self.need_sync_notifier.notify_one();
}
}
async fn get_peer_id_by_ipv4(&self, ipv4_addr: &Ipv4Addr) -> Option<PeerId> {
if let Some(peer_id) = self.route_table.ipv4_peer_id_map.get(ipv4_addr) {
return Some(*peer_id);
}
if let Some(peer_id) = self.get_peer_id_for_proxy(ipv4_addr) {
return Some(peer_id);
}
log::info!("no peer id for ipv4: {}", ipv4_addr);
return None;
}
async fn get_next_hop(&self, dst_peer_id: &PeerId) -> Option<PeerId> {
match self.route_table.route_info.get(dst_peer_id) {
Some(info) => {
return Some(info.peer_id.clone().into());
}
None => {
log::error!("no route info for dst_peer_id: {}", dst_peer_id);
return None;
}
}
}
async fn list_routes(&self) -> Vec<easytier_rpc::Route> {
let mut routes = Vec::new();
let parse_route_info = |real_peer_id: &Uuid, route_info: &SyncPeerInfo| {
let mut route = easytier_rpc::Route::default();
route.ipv4_addr = if let Some(ipv4_addr) = route_info.ipv4_addr {
ipv4_addr.to_string()
} else {
"".to_string()
};
route.peer_id = real_peer_id.to_string();
route.next_hop_peer_id = Uuid::from(route_info.peer_id.clone()).to_string();
route.cost = route_info.cost as i32;
route.proxy_cidrs = route_info.proxy_cidrs.clone();
route.hostname = if let Some(hostname) = &route_info.hostname {
hostname.clone()
} else {
"".to_string()
};
let mut stun_info = StunInfo::default();
if let Ok(udp_nat_type) = NatType::try_from(route_info.udp_stun_info as i32) {
stun_info.set_udp_nat_type(udp_nat_type);
}
route.stun_info = Some(stun_info);
route
};
self.route_table.route_info.iter().for_each(|item| {
routes.push(parse_route_info(item.key(), item.value()));
});
routes
}
}
+36
View File
@@ -0,0 +1,36 @@
use std::{net::Ipv4Addr, sync::Arc};
use async_trait::async_trait;
use tokio_util::bytes::Bytes;
use crate::common::error::Error;
use super::PeerId;
#[async_trait]
pub trait RouteInterface {
async fn list_peers(&self) -> Vec<PeerId>;
async fn send_route_packet(
&self,
msg: Bytes,
route_id: u8,
dst_peer_id: &PeerId,
) -> Result<(), Error>;
}
pub type RouteInterfaceBox = Box<dyn RouteInterface + Send + Sync>;
#[async_trait]
pub trait Route {
async fn open(&self, interface: RouteInterfaceBox) -> Result<u8, ()>;
async fn close(&self);
async fn get_peer_id_by_ipv4(&self, ipv4: &Ipv4Addr) -> Option<PeerId>;
async fn get_next_hop(&self, peer_id: &PeerId) -> Option<PeerId>;
async fn handle_route_packet(&self, src_peer_id: PeerId, packet: Bytes);
async fn list_routes(&self) -> Vec<easytier_rpc::Route>;
}
pub type ArcRoute = Arc<Box<dyn Route + Send + Sync>>;
+55
View File
@@ -0,0 +1,55 @@
use std::sync::Arc;
use easytier_rpc::peer_manage_rpc_server::PeerManageRpc;
use easytier_rpc::{ListPeerRequest, ListPeerResponse, ListRouteRequest, ListRouteResponse};
use tonic::{Request, Response, Status};
use super::peer_manager::PeerManager;
pub struct PeerManagerRpcService {
peer_manager: Arc<PeerManager>,
}
impl PeerManagerRpcService {
pub fn new(peer_manager: Arc<PeerManager>) -> Self {
PeerManagerRpcService { peer_manager }
}
}
#[tonic::async_trait]
impl PeerManageRpc for PeerManagerRpcService {
async fn list_peer(
&self,
_request: Request<ListPeerRequest>, // Accept request of type HelloRequest
) -> Result<Response<ListPeerResponse>, Status> {
let mut reply = ListPeerResponse::default();
let peers = self.peer_manager.get_peer_map().list_peers().await;
for peer in peers {
let mut peer_info = easytier_rpc::PeerInfo::default();
peer_info.peer_id = peer.to_string();
if let Some(conns) = self
.peer_manager
.get_peer_map()
.list_peer_conns(&peer)
.await
{
peer_info.conns = conns;
}
reply.peer_infos.push(peer_info);
}
Ok(Response::new(reply))
}
async fn list_route(
&self,
_request: Request<ListRouteRequest>, // Accept request of type HelloRequest
) -> Result<Response<ListRouteResponse>, Status> {
let mut reply = ListRouteResponse::default();
reply.routes = self.peer_manager.list_routes().await;
Ok(Response::new(reply))
}
}
+60
View File
@@ -0,0 +1,60 @@
use std::sync::Arc;
use crate::{
common::{error::Error, global_ctx::tests::get_mock_global_ctx},
peers::rip_route::BasicRoute,
tunnels::ring_tunnel::create_ring_tunnel_pair,
};
use super::peer_manager::PeerManager;
pub async fn create_mock_peer_manager() -> Arc<PeerManager> {
let (s, _r) = tokio::sync::mpsc::channel(1000);
let peer_mgr = Arc::new(PeerManager::new(get_mock_global_ctx(), s));
peer_mgr
.set_route(BasicRoute::new(
peer_mgr.my_node_id(),
peer_mgr.get_global_ctx(),
))
.await;
peer_mgr.run().await.unwrap();
peer_mgr
}
pub async fn connect_peer_manager(client: Arc<PeerManager>, server: Arc<PeerManager>) {
let (a_ring, b_ring) = create_ring_tunnel_pair();
let a_mgr_copy = client.clone();
tokio::spawn(async move {
a_mgr_copy.add_client_tunnel(a_ring).await.unwrap();
});
let b_mgr_copy = server.clone();
tokio::spawn(async move {
b_mgr_copy.add_tunnel_as_server(b_ring).await.unwrap();
});
}
pub async fn wait_route_appear_with_cost(
peer_mgr: Arc<PeerManager>,
node_id: uuid::Uuid,
cost: Option<i32>,
) -> Result<(), Error> {
let now = std::time::Instant::now();
while now.elapsed().as_secs() < 5 {
let route = peer_mgr.list_routes().await;
if route.iter().any(|r| {
r.peer_id.clone().parse::<uuid::Uuid>().unwrap() == node_id
&& (cost.is_none() || r.cost == cost.unwrap())
}) {
return Ok(());
}
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
}
return Err(Error::NotFound);
}
pub async fn wait_route_appear(
peer_mgr: Arc<PeerManager>,
node_id: uuid::Uuid,
) -> Result<(), Error> {
wait_route_appear_with_cost(peer_mgr, node_id, None).await
}