Feat/web (PatchSet 1) (#436)

* move rpc-build out of easytier dir and make it an independent project
* easytier core use launcher
* fix flags not printed on launch
* allow launcher not fetch node info
* abstract out peer rpc impl
* fix arm gui ci. see https://github.com/actions/runner-images/pull/10807
* add easytier-web crate
* fix manual_connector test case
This commit is contained in:
Sijie.Sun
2024-10-19 18:10:02 +08:00
committed by GitHub
parent 2134bc9139
commit 0bf42c53cc
29 changed files with 575 additions and 353 deletions
+25 -2
View File
@@ -164,7 +164,7 @@ pub struct Flags {
pub enable_ipv6: bool,
#[derivative(Default(value = "1380"))]
pub mtu: u16,
#[derivative(Default(value = "true"))]
#[derivative(Default(value = "false"))]
pub latency_first: bool,
#[derivative(Default(value = "false"))]
pub enable_exit_node: bool,
@@ -182,6 +182,8 @@ pub struct Flags {
pub disable_udp_hole_punching: bool,
#[derivative(Default(value = "\"udp://[::]:0\".to_string()"))]
pub ipv6_listener: String,
#[derivative(Default(value = "false"))]
pub multi_thread: bool,
}
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
@@ -529,7 +531,28 @@ impl ConfigLoader for TomlConfigLoader {
}
fn dump(&self) -> String {
toml::to_string_pretty(&*self.config.lock().unwrap()).unwrap()
let default_flags_json = serde_json::to_string(&Flags::default()).unwrap();
let default_flags_hashmap =
serde_json::from_str::<serde_json::Map<String, serde_json::Value>>(&default_flags_json)
.unwrap();
let cur_flags_json = serde_json::to_string(&self.get_flags()).unwrap();
let cur_flags_hashmap =
serde_json::from_str::<serde_json::Map<String, serde_json::Value>>(&cur_flags_json)
.unwrap();
let mut flag_map: serde_json::Map<String, serde_json::Value> = Default::default();
for (key, value) in default_flags_hashmap {
if let Some(v) = cur_flags_hashmap.get(&key) {
if *v != value {
flag_map.insert(key, v.clone());
}
}
}
let mut config = self.config.lock().unwrap().clone();
config.flags = Some(flag_map);
toml::to_string_pretty(&config).unwrap()
}
fn get_routes(&self) -> Option<Vec<cidr::Ipv4Cidr>> {
+2 -2
View File
@@ -44,8 +44,8 @@ pub enum GlobalCtxEvent {
DhcpIpv4Conflicted(Option<cidr::Ipv4Inet>),
}
type EventBus = tokio::sync::broadcast::Sender<GlobalCtxEvent>;
type EventBusSubscriber = tokio::sync::broadcast::Receiver<GlobalCtxEvent>;
pub type EventBus = tokio::sync::broadcast::Sender<GlobalCtxEvent>;
pub type EventBusSubscriber = tokio::sync::broadcast::Receiver<GlobalCtxEvent>;
pub struct GlobalCtx {
pub inst_name: String,
+19 -33
View File
@@ -19,6 +19,7 @@ mod common;
mod connector;
mod gateway;
mod instance;
mod launcher;
mod peer_center;
mod peers;
mod proto;
@@ -29,8 +30,9 @@ mod vpn_portal;
use common::{
config::{ConsoleLoggerConfig, FileLoggerConfig, NetworkIdentity, PeerConfig, VpnPortalConfig},
constants::EASYTIER_VERSION,
global_ctx::EventBusSubscriber,
scoped_task::ScopedTask,
};
use instance::instance::Instance;
use tokio::net::TcpSocket;
use utils::setup_panic_handler;
@@ -525,6 +527,7 @@ impl From<Cli> for TomlConfigLoader {
.with_context(|| format!("failed to parse ipv6 listener: {}", ipv6_listener))
.unwrap();
}
f.multi_thread = cli.multi_thread;
cfg.set_flags(f);
cfg.set_exit_nodes(cli.exit_nodes.clone());
@@ -549,13 +552,7 @@ fn peer_conn_info_to_string(p: crate::proto::cli::PeerConnInfo) -> String {
}
#[tracing::instrument]
pub async fn async_main(cli: Cli) {
let cfg: TomlConfigLoader = cli.into();
init_logger(&cfg, false).unwrap();
let mut inst = Instance::new(cfg.clone());
let mut events = inst.get_global_ctx().subscribe();
pub fn handle_event(mut events: EventBusSubscriber) -> tokio::task::JoinHandle<()> {
tokio::spawn(async move {
while let Ok(e) = events.recv().await {
match e {
@@ -658,39 +655,28 @@ pub async fn async_main(cli: Cli) {
}
}
}
});
println!("Starting easytier with config:");
println!("############### TOML ###############\n");
println!("{}", cfg.dump());
println!("-----------------------------------");
inst.run().await.unwrap();
inst.wait().await;
})
}
fn main() {
#[tokio::main]
async fn main() {
setup_panic_handler();
let locale = sys_locale::get_locale().unwrap_or_else(|| String::from("en-US"));
rust_i18n::set_locale(&locale);
let cli = Cli::parse();
tracing::info!(cli = ?cli, "cli args parsed");
let cfg = TomlConfigLoader::from(cli);
init_logger(&cfg, false).unwrap();
if cli.multi_thread {
tokio::runtime::Builder::new_multi_thread()
.worker_threads(2)
.enable_all()
.build()
.unwrap()
.block_on(async move { async_main(cli).await })
} else {
tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.unwrap()
.block_on(async move { async_main(cli).await })
println!("Starting easytier with config:");
println!("############### TOML ###############\n");
println!("{}", cfg.dump());
println!("-----------------------------------");
let mut l = launcher::NetworkInstance::new(cfg).set_fetch_node_info(false);
let _t = ScopedTask::from(handle_event(l.start().unwrap()));
if let Some(e) = l.wait().await {
panic!("launcher error: {:?}", e);
}
}
+137 -69
View File
@@ -7,7 +7,7 @@ use crate::{
common::{
config::{ConfigLoader, TomlConfigLoader},
constants::EASYTIER_VERSION,
global_ctx::GlobalCtxEvent,
global_ctx::{EventBusSubscriber, GlobalCtxEvent},
stun::StunInfoCollectorTrait,
},
instance::instance::Instance,
@@ -21,7 +21,7 @@ use crate::{
};
use chrono::{DateTime, Local};
use serde::{Deserialize, Serialize};
use tokio::task::JoinSet;
use tokio::{sync::broadcast, task::JoinSet};
#[derive(Default, Clone, Debug, Serialize, Deserialize)]
pub struct MyNodeInfo {
@@ -34,14 +34,31 @@ pub struct MyNodeInfo {
pub vpn_portal_cfg: Option<String>,
}
#[derive(Default, Clone)]
struct EasyTierData {
events: Arc<RwLock<VecDeque<(DateTime<Local>, GlobalCtxEvent)>>>,
node_info: Arc<RwLock<MyNodeInfo>>,
routes: Arc<RwLock<Vec<Route>>>,
peers: Arc<RwLock<Vec<PeerInfo>>>,
events: RwLock<VecDeque<(DateTime<Local>, GlobalCtxEvent)>>,
node_info: RwLock<MyNodeInfo>,
routes: RwLock<Vec<Route>>,
peers: RwLock<Vec<PeerInfo>>,
tun_fd: Arc<RwLock<Option<i32>>>,
tun_dev_name: Arc<RwLock<String>>,
tun_dev_name: RwLock<String>,
event_subscriber: RwLock<broadcast::Sender<GlobalCtxEvent>>,
instance_stop_notifier: Arc<tokio::sync::Notify>,
}
// Manual `Default` impl: `broadcast::channel` has no `Default`, so the
// event bus (capacity 100) must be constructed explicitly.
impl Default for EasyTierData {
    fn default() -> Self {
        // The initial receiver is dropped; consumers attach later through
        // `event_subscriber.read().unwrap().subscribe()`.
        let (tx, _) = broadcast::channel(100);
        Self {
            event_subscriber: RwLock::new(tx),
            events: RwLock::new(VecDeque::new()),
            node_info: RwLock::new(MyNodeInfo::default()),
            routes: RwLock::new(Vec::new()),
            peers: RwLock::new(Vec::new()),
            tun_fd: Arc::new(RwLock::new(None)),
            tun_dev_name: RwLock::new(String::new()),
            instance_stop_notifier: Arc::new(tokio::sync::Notify::new()),
        }
    }
}
pub struct EasyTierLauncher {
@@ -49,27 +66,30 @@ pub struct EasyTierLauncher {
stop_flag: Arc<AtomicBool>,
thread_handle: Option<std::thread::JoinHandle<()>>,
running_cfg: String,
fetch_node_info: bool,
error_msg: Arc<RwLock<Option<String>>>,
data: EasyTierData,
data: Arc<EasyTierData>,
}
impl EasyTierLauncher {
pub fn new() -> Self {
pub fn new(fetch_node_info: bool) -> Self {
let instance_alive = Arc::new(AtomicBool::new(false));
Self {
instance_alive,
thread_handle: None,
error_msg: Arc::new(RwLock::new(None)),
running_cfg: String::new(),
fetch_node_info,
stop_flag: Arc::new(AtomicBool::new(false)),
data: EasyTierData::default(),
data: Arc::new(EasyTierData::default()),
}
}
async fn handle_easytier_event(event: GlobalCtxEvent, data: EasyTierData) {
async fn handle_easytier_event(event: GlobalCtxEvent, data: &EasyTierData) {
let mut events = data.events.write().unwrap();
let _ = data.event_subscriber.read().unwrap().send(event.clone());
events.push_back((chrono::Local::now(), event));
if events.len() > 100 {
events.pop_front();
@@ -113,7 +133,8 @@ impl EasyTierLauncher {
async fn easytier_routine(
cfg: TomlConfigLoader,
stop_signal: Arc<tokio::sync::Notify>,
data: EasyTierData,
data: Arc<EasyTierData>,
fetch_node_info: bool,
) -> Result<(), anyhow::Error> {
let mut instance = Instance::new(cfg);
let peer_mgr = instance.get_peer_manager();
@@ -126,50 +147,53 @@ impl EasyTierLauncher {
tasks.spawn(async move {
let mut receiver = global_ctx.subscribe();
while let Ok(event) = receiver.recv().await {
Self::handle_easytier_event(event, data_c.clone()).await;
Self::handle_easytier_event(event, &data_c).await;
}
});
// update my node info
let data_c = data.clone();
let global_ctx_c = instance.get_global_ctx();
let peer_mgr_c = peer_mgr.clone();
let vpn_portal = instance.get_vpn_portal_inst();
tasks.spawn(async move {
loop {
// Update TUN Device Name
*data_c.tun_dev_name.write().unwrap() = global_ctx_c.get_flags().dev_name.clone();
if fetch_node_info {
let data_c = data.clone();
let global_ctx_c = instance.get_global_ctx();
let peer_mgr_c = peer_mgr.clone();
let vpn_portal = instance.get_vpn_portal_inst();
tasks.spawn(async move {
loop {
// Update TUN Device Name
*data_c.tun_dev_name.write().unwrap() =
global_ctx_c.get_flags().dev_name.clone();
let node_info = MyNodeInfo {
virtual_ipv4: global_ctx_c
.get_ipv4()
.map(|x| x.to_string())
.unwrap_or_default(),
hostname: global_ctx_c.get_hostname(),
version: EASYTIER_VERSION.to_string(),
ips: global_ctx_c.get_ip_collector().collect_ip_addrs().await,
stun_info: global_ctx_c.get_stun_info_collector().get_stun_info(),
listeners: global_ctx_c
.get_running_listeners()
.iter()
.map(|x| x.to_string())
.collect(),
vpn_portal_cfg: Some(
vpn_portal
.lock()
.await
.dump_client_config(peer_mgr_c.clone())
.await,
),
};
*data_c.node_info.write().unwrap() = node_info.clone();
*data_c.routes.write().unwrap() = peer_mgr_c.list_routes().await;
*data_c.peers.write().unwrap() = PeerManagerRpcService::new(peer_mgr_c.clone())
.list_peers()
.await;
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
}
});
let node_info = MyNodeInfo {
virtual_ipv4: global_ctx_c
.get_ipv4()
.map(|x| x.to_string())
.unwrap_or_default(),
hostname: global_ctx_c.get_hostname(),
version: EASYTIER_VERSION.to_string(),
ips: global_ctx_c.get_ip_collector().collect_ip_addrs().await,
stun_info: global_ctx_c.get_stun_info_collector().get_stun_info(),
listeners: global_ctx_c
.get_running_listeners()
.iter()
.map(|x| x.to_string())
.collect(),
vpn_portal_cfg: Some(
vpn_portal
.lock()
.await
.dump_client_config(peer_mgr_c.clone())
.await,
),
};
*data_c.node_info.write().unwrap() = node_info.clone();
*data_c.routes.write().unwrap() = peer_mgr_c.list_routes().await;
*data_c.peers.write().unwrap() = PeerManagerRpcService::new(peer_mgr_c.clone())
.list_peers()
.await;
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
}
});
}
#[cfg(target_os = "android")]
Self::run_routine_for_android(&instance, &data, &mut tasks).await;
@@ -188,13 +212,15 @@ impl EasyTierLauncher {
F: FnOnce() -> Result<TomlConfigLoader, anyhow::Error> + Send + Sync,
{
let error_msg = self.error_msg.clone();
let cfg = cfg_generator();
if let Err(e) = cfg {
error_msg.write().unwrap().replace(e.to_string());
return;
}
let cfg = match cfg_generator() {
Err(e) => {
error_msg.write().unwrap().replace(e.to_string());
return;
}
Ok(cfg) => cfg,
};
self.running_cfg = cfg.as_ref().unwrap().dump();
self.running_cfg = cfg.dump();
let stop_flag = self.stop_flag.clone();
@@ -202,12 +228,21 @@ impl EasyTierLauncher {
instance_alive.store(true, std::sync::atomic::Ordering::Relaxed);
let data = self.data.clone();
let fetch_node_info = self.fetch_node_info;
self.thread_handle = Some(std::thread::spawn(move || {
let rt = tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.unwrap();
let rt = if cfg.get_flags().multi_thread {
tokio::runtime::Builder::new_multi_thread()
.worker_threads(2)
.enable_all()
.build()
} else {
tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
}
.unwrap();
let stop_notifier = Arc::new(tokio::sync::Notify::new());
let stop_notifier_clone = stop_notifier.clone();
@@ -218,15 +253,18 @@ impl EasyTierLauncher {
stop_notifier_clone.notify_one();
});
let notifier = data.instance_stop_notifier.clone();
let ret = rt.block_on(Self::easytier_routine(
cfg.unwrap(),
cfg,
stop_notifier.clone(),
data,
fetch_node_info,
));
if let Err(e) = ret {
error_msg.write().unwrap().replace(e.to_string());
}
instance_alive.store(false, std::sync::atomic::Ordering::Relaxed);
notifier.notify_one();
}));
}
@@ -289,6 +327,8 @@ pub struct NetworkInstanceRunningInfo {
pub struct NetworkInstance {
config: TomlConfigLoader,
launcher: Option<EasyTierLauncher>,
fetch_node_info: bool,
}
impl NetworkInstance {
@@ -296,9 +336,15 @@ impl NetworkInstance {
Self {
config,
launcher: None,
fetch_node_info: true,
}
}
pub fn set_fetch_node_info(mut self, fetch_node_info: bool) -> Self {
self.fetch_node_info = fetch_node_info;
self
}
pub fn is_easytier_running(&self) -> bool {
self.launcher.is_some() && self.launcher.as_ref().unwrap().running()
}
@@ -333,15 +379,37 @@ impl NetworkInstance {
}
}
pub fn start(&mut self) -> Result<(), anyhow::Error> {
pub fn start(&mut self) -> Result<EventBusSubscriber, anyhow::Error> {
if self.is_easytier_running() {
return Ok(());
return Ok(self.subscribe_event().unwrap());
}
let mut launcher = EasyTierLauncher::new();
launcher.start(|| Ok(self.config.clone()));
let launcher = EasyTierLauncher::new(self.fetch_node_info);
self.launcher = Some(launcher);
Ok(())
let ev = self.subscribe_event().unwrap();
self.launcher
.as_mut()
.unwrap()
.start(|| Ok(self.config.clone()));
Ok(ev)
}
/// Attach a fresh receiver to the launcher's event bus, or `None` when no
/// launcher has been started yet.
fn subscribe_event(&self) -> Option<broadcast::Receiver<GlobalCtxEvent>> {
    self.launcher
        .as_ref()
        .map(|l| l.data.event_subscriber.read().unwrap().subscribe())
}
/// Wait until the running instance stops, then return its recorded error
/// message (if any). Returns `None` immediately when nothing was started.
pub async fn wait(&self) -> Option<String> {
    let launcher = self.launcher.as_ref()?;
    launcher.data.instance_stop_notifier.notified().await;
    launcher.error_msg.read().unwrap().clone()
}
}
+1 -1
View File
@@ -6,11 +6,11 @@ mod gateway;
mod instance;
mod peer_center;
mod peers;
mod proto;
mod vpn_portal;
pub mod common;
pub mod launcher;
pub mod proto;
pub mod tunnel;
pub mod utils;
+14 -53
View File
@@ -1,12 +1,12 @@
use std::sync::{Arc, Mutex};
use futures::StreamExt;
use futures::{SinkExt as _, StreamExt};
use tokio::task::JoinSet;
use crate::{
common::{error::Error, PeerId},
proto::rpc_impl,
tunnel::packet_def::{PacketType, ZCPacket},
proto::rpc_impl::{self, bidirect::BidirectRpcManager},
tunnel::packet_def::ZCPacket,
};
const RPC_PACKET_CONTENT_MTU: usize = 1300;
@@ -25,9 +25,7 @@ pub trait PeerRpcManagerTransport: Send + Sync + 'static {
// handle rpc request from one peer
pub struct PeerRpcManager {
tspt: Arc<Box<dyn PeerRpcManagerTransport>>,
rpc_client: rpc_impl::client::Client,
rpc_server: rpc_impl::server::Server,
bidirect_rpc: BidirectRpcManager,
tasks: Arc<Mutex<JoinSet<()>>>,
}
@@ -43,78 +41,41 @@ impl PeerRpcManager {
pub fn new(tspt: impl PeerRpcManagerTransport) -> Self {
Self {
tspt: Arc::new(Box::new(tspt)),
rpc_client: rpc_impl::client::Client::new(),
rpc_server: rpc_impl::server::Server::new(),
bidirect_rpc: BidirectRpcManager::new(),
tasks: Arc::new(Mutex::new(JoinSet::new())),
}
}
pub fn run(&self) {
self.rpc_client.run();
self.rpc_server.run();
let (server_tx, mut server_rx) = (
self.rpc_server.get_transport_sink(),
self.rpc_server.get_transport_stream(),
);
let (client_tx, mut client_rx) = (
self.rpc_client.get_transport_sink(),
self.rpc_client.get_transport_stream(),
);
let ret = self.bidirect_rpc.run_and_create_tunnel();
let (mut rx, mut tx) = ret.split();
let tspt = self.tspt.clone();
self.tasks.lock().unwrap().spawn(async move {
loop {
let packet = tokio::select! {
Some(Ok(packet)) = server_rx.next() => {
tracing::trace!(?packet, "recv rpc packet from server");
packet
}
Some(Ok(packet)) = client_rx.next() => {
tracing::trace!(?packet, "recv rpc packet from client");
packet
}
else => {
tracing::warn!("rpc transport read aborted, exiting");
break;
}
};
while let Some(Ok(packet)) = rx.next().await {
let dst_peer_id = packet.peer_manager_header().unwrap().to_peer_id.into();
if let Err(e) = tspt.send(packet, dst_peer_id).await {
tracing::error!(error = ?e, dst_peer_id = ?dst_peer_id, "send to peer failed");
tracing::error!("send to rpc tspt error: {:?}", e);
}
}
});
let tspt = self.tspt.clone();
self.tasks.lock().unwrap().spawn(async move {
loop {
let Ok(o) = tspt.recv().await else {
tracing::warn!("peer rpc transport read aborted, exiting");
break;
};
if o.peer_manager_header().unwrap().packet_type == PacketType::RpcReq as u8 {
server_tx.send(o).await.unwrap();
continue;
} else if o.peer_manager_header().unwrap().packet_type == PacketType::RpcResp as u8
{
client_tx.send(o).await.unwrap();
continue;
while let Ok(packet) = tspt.recv().await {
if let Err(e) = tx.send(packet).await {
tracing::error!("send to rpc tspt error: {:?}", e);
}
}
});
}
pub fn rpc_client(&self) -> &rpc_impl::client::Client {
&self.rpc_client
self.bidirect_rpc.rpc_client()
}
pub fn rpc_server(&self) -> &rpc_impl::server::Server {
&self.rpc_server
self.bidirect_rpc.rpc_server()
}
pub fn my_peer_id(&self) -> PeerId {
-8
View File
@@ -1,8 +0,0 @@
[package]
name = "rpc_build"
version = "0.1.0"
edition = "2021"
[dependencies]
heck = "0.5"
prost-build = "0.13"
-383
View File
@@ -1,383 +0,0 @@
extern crate heck;
extern crate prost_build;
use std::fmt;
const NAMESPACE: &str = "crate::proto::rpc_types";
/// The service generator to be used with `prost-build` to generate RPC implementations for
/// `prost-simple-rpc`.
///
/// See the crate-level documentation for more info.
#[allow(missing_copy_implementations)]
#[derive(Clone, Debug)]
pub struct ServiceGenerator {
    // Private zero-sized field: prevents construction via a struct literal
    // outside this module, forcing callers through `ServiceGenerator::new()`.
    _private: (),
}
impl ServiceGenerator {
/// Create a new `ServiceGenerator` instance with the default options set.
pub fn new() -> ServiceGenerator {
ServiceGenerator { _private: () }
}
}
impl prost_build::ServiceGenerator for ServiceGenerator {
fn generate(&mut self, service: prost_build::Service, mut buf: &mut String) {
use std::fmt::Write;
let descriptor_name = format!("{}Descriptor", service.name);
let server_name = format!("{}Server", service.name);
let client_name = format!("{}Client", service.name);
let method_descriptor_name = format!("{}MethodDescriptor", service.name);
let mut trait_methods = String::new();
let mut enum_methods = String::new();
let mut list_enum_methods = String::new();
let mut client_methods = String::new();
let mut client_own_methods = String::new();
let mut match_name_methods = String::new();
let mut match_proto_name_methods = String::new();
let mut match_input_type_methods = String::new();
let mut match_input_proto_type_methods = String::new();
let mut match_output_type_methods = String::new();
let mut match_output_proto_type_methods = String::new();
let mut match_handle_methods = String::new();
let mut match_method_try_from = String::new();
for (idx, method) in service.methods.iter().enumerate() {
assert!(
!method.client_streaming,
"Client streaming not yet supported for method {}",
method.proto_name
);
assert!(
!method.server_streaming,
"Server streaming not yet supported for method {}",
method.proto_name
);
ServiceGenerator::write_comments(&mut trait_methods, 4, &method.comments).unwrap();
writeln!(
trait_methods,
r#" async fn {name}(&self, ctrl: Self::Controller, input: {input_type}) -> {namespace}::error::Result<{output_type}>;"#,
name = method.name,
input_type = method.input_type,
output_type = method.output_type,
namespace = NAMESPACE,
)
.unwrap();
ServiceGenerator::write_comments(&mut enum_methods, 4, &method.comments).unwrap();
writeln!(
enum_methods,
" {name} = {index},",
name = method.proto_name,
index = format!("{}", idx + 1)
)
.unwrap();
writeln!(
match_method_try_from,
" {index} => Ok({service_name}MethodDescriptor::{name}),",
service_name = service.name,
name = method.proto_name,
index = format!("{}", idx + 1),
)
.unwrap();
writeln!(
list_enum_methods,
" {service_name}MethodDescriptor::{name},",
service_name = service.name,
name = method.proto_name
)
.unwrap();
writeln!(
client_methods,
r#" async fn {name}(&self, ctrl: H::Controller, input: {input_type}) -> {namespace}::error::Result<{output_type}> {{
{client_name}::{name}_inner(self.0.clone(), ctrl, input).await
}}"#,
name = method.name,
input_type = method.input_type,
output_type = method.output_type,
client_name = format!("{}Client", service.name),
namespace = NAMESPACE,
)
.unwrap();
writeln!(
client_own_methods,
r#" async fn {name}_inner(handler: H, ctrl: H::Controller, input: {input_type}) -> {namespace}::error::Result<{output_type}> {{
{namespace}::__rt::call_method(handler, ctrl, {method_descriptor_name}::{proto_name}, input).await
}}"#,
name = method.name,
method_descriptor_name = method_descriptor_name,
proto_name = method.proto_name,
input_type = method.input_type,
output_type = method.output_type,
namespace = NAMESPACE,
).unwrap();
let case = format!(
" {service_name}MethodDescriptor::{proto_name} => ",
service_name = service.name,
proto_name = method.proto_name
);
writeln!(match_name_methods, "{}{:?},", case, method.name).unwrap();
writeln!(match_proto_name_methods, "{}{:?},", case, method.proto_name).unwrap();
writeln!(
match_input_type_methods,
"{}::std::any::TypeId::of::<{}>(),",
case, method.input_type
)
.unwrap();
writeln!(
match_input_proto_type_methods,
"{}{:?},",
case, method.input_proto_type
)
.unwrap();
writeln!(
match_output_type_methods,
"{}::std::any::TypeId::of::<{}>(),",
case, method.output_type
)
.unwrap();
writeln!(
match_output_proto_type_methods,
"{}{:?},",
case, method.output_proto_type
)
.unwrap();
write!(
match_handle_methods,
r#"{} {{
let decoded: {input_type} = {namespace}::__rt::decode(input)?;
let ret = service.{name}(ctrl, decoded).await?;
{namespace}::__rt::encode(ret)
}}
"#,
case,
input_type = method.input_type,
name = method.name,
namespace = NAMESPACE,
)
.unwrap();
}
ServiceGenerator::write_comments(&mut buf, 0, &service.comments).unwrap();
write!(
buf,
r#"
#[async_trait::async_trait]
#[auto_impl::auto_impl(&, Arc, Box)]
pub trait {name} {{
type Controller: {namespace}::controller::Controller;
{trait_methods}
}}
/// A service descriptor for a `{name}`.
#[derive(Clone, Debug, Eq, Ord, PartialEq, PartialOrd, Default)]
pub struct {descriptor_name};
/// Methods available on a `{name}`.
///
/// This can be used as a key when routing requests for servers/clients of a `{name}`.
#[derive(Clone, Copy, Debug, Eq, Ord, PartialEq, PartialOrd)]
#[repr(u8)]
pub enum {method_descriptor_name} {{
{enum_methods}
}}
impl std::convert::TryFrom<u8> for {method_descriptor_name} {{
type Error = {namespace}::error::Error;
fn try_from(value: u8) -> {namespace}::error::Result<Self> {{
match value {{
{match_method_try_from}
_ => Err({namespace}::error::Error::InvalidMethodIndex(value, "{name}".to_string())),
}}
}}
}}
/// A client for a `{name}`.
///
/// This implements the `{name}` trait by dispatching all method calls to the supplied `Handler`.
#[derive(Clone, Debug)]
pub struct {client_name}<H>(H) where H: {namespace}::handler::Handler;
impl<H> {client_name}<H> where H: {namespace}::handler::Handler<Descriptor = {descriptor_name}> {{
/// Creates a new client instance that delegates all method calls to the supplied handler.
pub fn new(handler: H) -> {client_name}<H> {{
{client_name}(handler)
}}
}}
impl<H> {client_name}<H> where H: {namespace}::handler::Handler<Descriptor = {descriptor_name}> {{
{client_own_methods}
}}
#[async_trait::async_trait]
impl<H> {name} for {client_name}<H> where H: {namespace}::handler::Handler<Descriptor = {descriptor_name}> {{
type Controller = H::Controller;
{client_methods}
}}
pub struct {client_name}Factory<C: {namespace}::controller::Controller>(std::marker::PhantomData<C>);
impl<C: {namespace}::controller::Controller> Clone for {client_name}Factory<C> {{
fn clone(&self) -> Self {{
Self(std::marker::PhantomData)
}}
}}
impl<C> {namespace}::__rt::RpcClientFactory for {client_name}Factory<C> where C: {namespace}::controller::Controller {{
type Descriptor = {descriptor_name};
type ClientImpl = Box<dyn {name}<Controller = C> + Send + 'static>;
type Controller = C;
fn new(handler: impl {namespace}::handler::Handler<Descriptor = Self::Descriptor, Controller = Self::Controller>) -> Self::ClientImpl {{
Box::new({client_name}::new(handler))
}}
}}
/// A server for a `{name}`.
///
/// This implements the `Server` trait by handling requests and dispatch them to methods on the
/// supplied `{name}`.
#[derive(Clone, Debug)]
pub struct {server_name}<A>(A) where A: {name} + Clone + Send + 'static;
impl<A> {server_name}<A> where A: {name} + Clone + Send + 'static {{
/// Creates a new server instance that dispatches all calls to the supplied service.
pub fn new(service: A) -> {server_name}<A> {{
{server_name}(service)
}}
async fn call_inner(
service: A,
method: {method_descriptor_name},
ctrl: A::Controller,
input: ::bytes::Bytes)
-> {namespace}::error::Result<::bytes::Bytes> {{
match method {{
{match_handle_methods}
}}
}}
}}
impl {namespace}::descriptor::ServiceDescriptor for {descriptor_name} {{
type Method = {method_descriptor_name};
fn name(&self) -> &'static str {{ {name:?} }}
fn proto_name(&self) -> &'static str {{ {proto_name:?} }}
fn package(&self) -> &'static str {{ {package:?} }}
fn methods(&self) -> &'static [Self::Method] {{
&[ {list_enum_methods} ]
}}
}}
#[async_trait::async_trait]
impl<A> {namespace}::handler::Handler for {server_name}<A>
where
A: {name} + Clone + Send + Sync + 'static {{
type Descriptor = {descriptor_name};
type Controller = A::Controller;
async fn call(
&self,
ctrl: A::Controller,
method: {method_descriptor_name},
input: ::bytes::Bytes)
-> {namespace}::error::Result<::bytes::Bytes> {{
{server_name}::call_inner(self.0.clone(), method, ctrl, input).await
}}
}}
impl {namespace}::descriptor::MethodDescriptor for {method_descriptor_name} {{
fn name(&self) -> &'static str {{
match *self {{
{match_name_methods}
}}
}}
fn proto_name(&self) -> &'static str {{
match *self {{
{match_proto_name_methods}
}}
}}
fn input_type(&self) -> ::std::any::TypeId {{
match *self {{
{match_input_type_methods}
}}
}}
fn input_proto_type(&self) -> &'static str {{
match *self {{
{match_input_proto_type_methods}
}}
}}
fn output_type(&self) -> ::std::any::TypeId {{
match *self {{
{match_output_type_methods}
}}
}}
fn output_proto_type(&self) -> &'static str {{
match *self {{
{match_output_proto_type_methods}
}}
}}
fn index(&self) -> u8 {{
*self as u8
}}
}}
"#,
name = service.name,
descriptor_name = descriptor_name,
server_name = server_name,
client_name = client_name,
method_descriptor_name = method_descriptor_name,
proto_name = service.proto_name,
package = service.package,
trait_methods = trait_methods,
enum_methods = enum_methods,
list_enum_methods = list_enum_methods,
client_own_methods = client_own_methods,
client_methods = client_methods,
match_name_methods = match_name_methods,
match_proto_name_methods = match_proto_name_methods,
match_input_type_methods = match_input_type_methods,
match_input_proto_type_methods = match_input_proto_type_methods,
match_output_type_methods = match_output_type_methods,
match_output_proto_type_methods = match_output_proto_type_methods,
match_handle_methods = match_handle_methods,
namespace = NAMESPACE,
).unwrap();
}
}
impl ServiceGenerator {
    /// Emit the leading proto comments as `///` doc comments, each prefixed
    /// with `indent` spaces; empty lines are dropped.
    fn write_comments<W>(
        mut write: W,
        indent: usize,
        comments: &prost_build::Comments,
    ) -> fmt::Result
    where
        W: fmt::Write,
    {
        let pad = " ".repeat(indent);
        for line in comments
            .leading
            .iter()
            .flat_map(|comment| comment.lines())
            .filter(|line| !line.is_empty())
        {
            writeln!(write, "{}///{}", pad, line)?;
        }
        Ok(())
    }
}
+164
View File
@@ -0,0 +1,164 @@
use std::sync::{Arc, Mutex};
use futures::{SinkExt as _, StreamExt};
use tokio::{task::JoinSet, time::timeout};
use crate::{
proto::rpc_types::error::Error,
tunnel::{packet_def::PacketType, ring::create_ring_tunnel_pair, Tunnel},
};
use super::{client::Client, server::Server};
// Runs an RPC client and an RPC server over one shared tunnel,
// demultiplexing inbound packets by packet type (RpcReq vs RpcResp).
pub struct BidirectRpcManager {
    rpc_client: Client,
    rpc_server: Server,
    // Optional deadline for each read from the tunnel; `None` waits forever.
    rx_timeout: Option<std::time::Duration>,
    // First transport error observed by the forwarding tasks; retrieved
    // (and cleared) via `take_error`.
    error: Arc<Mutex<Option<Error>>>,
    // The tunnel handed to `run_with_tunnel` — stored here, presumably to
    // keep the endpoint alive while the manager runs (TODO confirm).
    tunnel: Mutex<Option<Box<dyn Tunnel>>>,
    // The two forwarding tasks; consumed by `stop`/`wait`.
    tasks: Mutex<Option<JoinSet<()>>>,
}
impl BidirectRpcManager {
    pub fn new() -> Self {
        Self {
            rpc_client: Client::new(),
            rpc_server: Server::new(),
            rx_timeout: None,
            error: Arc::new(Mutex::new(None)),
            tunnel: Mutex::new(None),
            tasks: Mutex::new(None),
        }
    }

    /// Builder-style setter: fail the receive loop (recording the error)
    /// when no packet arrives within `timeout`. `None` disables the limit.
    pub fn set_rx_timeout(mut self, timeout: Option<std::time::Duration>) -> Self {
        self.rx_timeout = timeout;
        self
    }

    /// Create a ring-tunnel pair, drive the RPC transports over one end and
    /// return the other end to the caller.
    pub fn run_and_create_tunnel(&self) -> Box<dyn Tunnel> {
        let (ret, inner) = create_ring_tunnel_pair();
        self.run_with_tunnel(inner);
        ret
    }

    /// Spawn the two forwarding tasks that bridge the client/server
    /// transports with `inner`:
    ///  - outbound: merge packets coming from both the server and the
    ///    client transport streams and write them to the tunnel;
    ///  - inbound: read from the tunnel (optionally bounded by
    ///    `rx_timeout`) and route by packet type — RpcReq to the server,
    ///    RpcResp to the client.
    /// The first transport error is stored in `self.error`.
    pub fn run_with_tunnel(&self, inner: Box<dyn Tunnel>) {
        let mut tasks = JoinSet::new();

        self.rpc_client.run();
        self.rpc_server.run();

        let (server_tx, mut server_rx) = (
            self.rpc_server.get_transport_sink(),
            self.rpc_server.get_transport_stream(),
        );
        let (client_tx, mut client_rx) = (
            self.rpc_client.get_transport_sink(),
            self.rpc_client.get_transport_stream(),
        );

        let (mut inner_rx, mut inner_tx) = inner.split();
        // Retain the tunnel object for the lifetime of the manager.
        self.tunnel.lock().unwrap().replace(inner);

        let e_clone = self.error.clone();
        tasks.spawn(async move {
            loop {
                // Merge both outbound streams; exit when both have ended.
                let packet = tokio::select! {
                    Some(Ok(packet)) = server_rx.next() => {
                        tracing::trace!(?packet, "recv rpc packet from server");
                        packet
                    }
                    Some(Ok(packet)) = client_rx.next() => {
                        tracing::trace!(?packet, "recv rpc packet from client");
                        packet
                    }
                    else => {
                        tracing::warn!("rpc transport read aborted, exiting");
                        break;
                    }
                };

                // NOTE(review): a send error is recorded but the loop keeps
                // running — confirm this best-effort behavior is intended.
                if let Err(e) = inner_tx.send(packet).await {
                    tracing::error!(error = ?e, "send to peer failed");
                    e_clone.lock().unwrap().replace(Error::from(e));
                }
            }
        });

        let recv_timeout = self.rx_timeout;
        let e_clone = self.error.clone();
        tasks.spawn(async move {
            loop {
                // Apply the optional receive deadline; hitting it is treated
                // as a fatal error and terminates the inbound loop.
                let ret = if let Some(recv_timeout) = recv_timeout {
                    match timeout(recv_timeout, inner_rx.next()).await {
                        Ok(ret) => ret,
                        Err(e) => {
                            e_clone.lock().unwrap().replace(e.into());
                            break;
                        }
                    }
                } else {
                    inner_rx.next().await
                };

                let o = match ret {
                    Some(Ok(o)) => o,
                    Some(Err(e)) => {
                        tracing::error!(error = ?e, "recv from peer failed");
                        e_clone.lock().unwrap().replace(Error::from(e));
                        break;
                    }
                    None => {
                        // Stream ended: peer side closed the tunnel.
                        tracing::warn!("peer rpc transport read aborted, exiting");
                        e_clone.lock().unwrap().replace(Error::Shutdown);
                        break;
                    }
                };

                // Dispatch by packet type; packets that are neither RpcReq
                // nor RpcResp fall through and are dropped.
                if o.peer_manager_header().unwrap().packet_type == PacketType::RpcReq as u8 {
                    server_tx.send(o).await.unwrap();
                    continue;
                } else if o.peer_manager_header().unwrap().packet_type == PacketType::RpcResp as u8
                {
                    client_tx.send(o).await.unwrap();
                    continue;
                }
            }
        });

        self.tasks.lock().unwrap().replace(tasks);
    }

    pub fn rpc_client(&self) -> &Client {
        &self.rpc_client
    }

    pub fn rpc_server(&self) -> &Server {
        &self.rpc_server
    }

    /// Abort both forwarding tasks and wait for them to finish. Idempotent:
    /// a second call finds `tasks` already taken and returns immediately.
    pub async fn stop(&self) {
        let Some(mut tasks) = self.tasks.lock().unwrap().take() else {
            return;
        };
        tasks.abort_all();
        while let Some(_) = tasks.join_next().await {}
    }

    /// Take (and clear) the first transport error recorded by the
    /// forwarding tasks, if any.
    pub fn take_error(&self) -> Option<Error> {
        self.error.lock().unwrap().take()
    }

    /// Wait until either forwarding task exits, then abort and drain the
    /// remaining one.
    pub async fn wait(&self) {
        let Some(mut tasks) = self.tasks.lock().unwrap().take() else {
            return;
        };
        while let Some(_) = tasks.join_next().await {
            // when any task is done, abort all tasks
            tasks.abort_all();
        }
    }
}
+1
View File
@@ -2,6 +2,7 @@ use crate::tunnel::{mpsc::MpscTunnel, Tunnel};
pub type RpcController = super::rpc_types::controller::BaseController;
pub mod bidirect;
pub mod client;
pub mod packet;
pub mod server;
@@ -59,6 +59,14 @@ impl ServiceRegistry {
}
}
// Replace this registry's contents with a copy of `registry`'s entries;
// all existing registrations are discarded first.
// NOTE(review): not atomic — the table is cleared and then rebuilt entry by
// entry, so concurrent lookups may observe a partially populated table;
// confirm callers tolerate this.
pub fn replace_registry(&self, registry: &ServiceRegistry) {
    self.table.clear();
    for item in registry.table.iter() {
        let (k, v) = item.pair();
        self.table.insert(k.clone(), v.clone());
    }
}
pub fn register<H: Handler<Controller = RpcController>>(&self, h: H, domain_name: &str) {
let desc = h.service_descriptor();
let key = ServiceKey {
+16 -135
View File
@@ -4,66 +4,18 @@ use std::{
};
use anyhow::Context as _;
use futures::{SinkExt as _, StreamExt};
use tokio::task::JoinSet;
use crate::{
common::join_joinset_background,
proto::rpc_types::{__rt::RpcClientFactory, error::Error},
proto::{
rpc_impl::bidirect::BidirectRpcManager,
rpc_types::{__rt::RpcClientFactory, error::Error},
},
tunnel::{Tunnel, TunnelConnector, TunnelListener},
};
use super::{client::Client, server::Server, service_registry::ServiceRegistry};
// Serves RPC requests arriving over one accepted tunnel: owns the
// transport and a per-tunnel Server that shares the global registry.
struct StandAloneServerOneTunnel {
    // Raw packet transport for this connection.
    tunnel: Box<dyn Tunnel>,
    // Server instance that decodes requests and dispatches to services.
    rpc_server: Server,
}
impl StandAloneServerOneTunnel {
    /// Builds a per-tunnel server whose Server shares `registry`.
    pub fn new(tunnel: Box<dyn Tunnel>, registry: Arc<ServiceRegistry>) -> Self {
        let rpc_server = Server::new_with_registry(registry);
        StandAloneServerOneTunnel { tunnel, rpc_server }
    }

    /// Pumps packets between the tunnel and the RPC server until either
    /// direction stops (error, EOF, or 60s of receive inactivity), then
    /// closes the server and waits for both pump tasks to finish.
    pub async fn run(self) {
        use tokio_stream::StreamExt as _;

        let (tunnel_rx, tunnel_tx) = self.tunnel.split();
        // Server-side stream/sink pair carrying encoded RPC packets.
        let (rpc_rx, rpc_tx) = (
            self.rpc_server.get_transport_stream(),
            self.rpc_server.get_transport_sink(),
        );

        let mut tasks = JoinSet::new();

        // Inbound pump: tunnel -> rpc server. The timeout wrapper makes
        // the loop exit if no packet arrives for 60 seconds, which in
        // turn tears the whole connection down below.
        tasks.spawn(async move {
            let ret = tunnel_rx.timeout(Duration::from_secs(60));
            tokio::pin!(ret);
            while let Ok(Some(Ok(p))) = ret.try_next().await {
                if let Err(e) = rpc_tx.send(p).await {
                    tracing::error!("tunnel_rx send to rpc_tx error: {:?}", e);
                    break;
                }
            }
            tracing::info!("forward tunnel_rx to rpc_tx done");
        });

        // Outbound pump: rpc server -> tunnel.
        tasks.spawn(async move {
            let ret = rpc_rx.forward(tunnel_tx).await;
            tracing::info!("rpc_rx forward tunnel_tx done: {:?}", ret);
        });

        self.rpc_server.run();

        // First pump to finish closes the server; then drain the rest.
        while let Some(ret) = tasks.join_next().await {
            self.rpc_server.close();
            tracing::info!("task done: {:?}", ret);
        }
        tracing::info!("all tasks done");
    }
}
use super::service_registry::ServiceRegistry;
pub struct StandAloneServer<L> {
registry: Arc<ServiceRegistry>,
@@ -102,11 +54,15 @@ impl<L: TunnelListener + 'static> StandAloneServer<L> {
self.tasks.lock().unwrap().spawn(async move {
while let Ok(tunnel) = listener.accept().await {
let server = StandAloneServerOneTunnel::new(tunnel, registry.clone());
let registry = registry.clone();
let inflight_server = inflight_server.clone();
inflight_server.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
tasks.lock().unwrap().spawn(async move {
server.run().await;
let server =
BidirectRpcManager::new().set_rx_timeout(Some(Duration::from_secs(60)));
server.rpc_server().registry().replace_registry(&registry);
server.run_with_tunnel(tunnel);
server.wait().await;
inflight_server.fetch_sub(1, std::sync::atomic::Ordering::Relaxed);
});
}
@@ -122,86 +78,9 @@ impl<L: TunnelListener + 'static> StandAloneServer<L> {
}
}
// Client-side counterpart of the standalone server: one RPC client
// bound to one connected tunnel, plus its background pump tasks.
struct StandAloneClientOneTunnel {
    // RPC client whose transport is bridged to the tunnel below.
    rpc_client: Client,
    // Background forwarding tasks (rpc<->tunnel, both directions).
    tasks: Arc<Mutex<JoinSet<()>>>,
    // First error observed by either pump; read via take_error().
    error: Arc<Mutex<Option<Error>>>,
}
impl StandAloneClientOneTunnel {
    /// Wires a fresh RPC client to `tunnel` by spawning two pump tasks
    /// (client->tunnel and tunnel->client). Any send/receive failure —
    /// including either stream simply ending — is recorded in the shared
    /// error slot so the owner can detect the dead connection.
    pub fn new(tunnel: Box<dyn Tunnel>) -> Self {
        let rpc_client = Client::new();
        let (mut rpc_rx, rpc_tx) = (
            rpc_client.get_transport_stream(),
            rpc_client.get_transport_sink(),
        );
        let tasks = Arc::new(Mutex::new(JoinSet::new()));
        let (mut tunnel_rx, mut tunnel_tx) = tunnel.split();
        let error_store = Arc::new(Mutex::new(None));

        // Outbound pump: rpc client -> tunnel.
        let error = error_store.clone();
        tasks.lock().unwrap().spawn(async move {
            while let Some(p) = rpc_rx.next().await {
                match p {
                    Ok(p) => {
                        if let Err(e) = tunnel_tx
                            .send(p)
                            .await
                            .with_context(|| "failed to send packet")
                        {
                            *error.lock().unwrap() = Some(e.into());
                        }
                    }
                    Err(e) => {
                        *error.lock().unwrap() = Some(anyhow::Error::from(e).into());
                    }
                }
            }
            // Stream ended: treat as an error so the owner reconnects.
            *error.lock().unwrap() = Some(anyhow::anyhow!("rpc_rx next exit").into());
        });

        // Inbound pump: tunnel -> rpc client.
        let error = error_store.clone();
        tasks.lock().unwrap().spawn(async move {
            while let Some(p) = tunnel_rx.next().await {
                match p {
                    Ok(p) => {
                        if let Err(e) = rpc_tx
                            .send(p)
                            .await
                            .with_context(|| "failed to send packet")
                        {
                            *error.lock().unwrap() = Some(e.into());
                        }
                    }
                    Err(e) => {
                        *error.lock().unwrap() = Some(anyhow::Error::from(e).into());
                    }
                }
            }
            *error.lock().unwrap() = Some(anyhow::anyhow!("tunnel_rx next exit").into());
        });

        rpc_client.run();

        StandAloneClientOneTunnel {
            rpc_client,
            tasks,
            error: error_store,
        }
    }

    /// Removes and returns the first pump error, if any was recorded.
    pub fn take_error(&self) -> Option<Error> {
        self.error.lock().unwrap().take()
    }
}
pub struct StandAloneClient<C: TunnelConnector> {
connector: C,
client: Option<StandAloneClientOneTunnel>,
client: Option<BidirectRpcManager>,
}
impl<C: TunnelConnector> StandAloneClient<C> {
@@ -230,7 +109,9 @@ impl<C: TunnelConnector> StandAloneClient<C> {
if c.is_none() || error.is_some() {
tracing::info!("reconnect due to error: {:?}", error);
let tunnel = self.connect().await?;
c = Some(StandAloneClientOneTunnel::new(tunnel));
let mgr = BidirectRpcManager::new().set_rx_timeout(Some(Duration::from_secs(60)));
mgr.run_with_tunnel(tunnel);
c = Some(mgr);
}
self.client = c;
@@ -239,7 +120,7 @@ impl<C: TunnelConnector> StandAloneClient<C> {
.client
.as_ref()
.unwrap()
.rpc_client
.rpc_client()
.scoped_client::<F>(1, 1, domain_name))
}
}
+3
View File
@@ -29,6 +29,9 @@ pub enum Error {
#[error("Tunnel error: {0}")]
TunnelError(#[from] crate::tunnel::TunnelError),
#[error("Shutdown")]
Shutdown,
}
pub type Result<T> = result::Result<T, Error>;
+72
View File
@@ -300,3 +300,75 @@ async fn standalone_rpc_test() {
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
assert_eq!(0, server.inflight_server());
}
#[tokio::test]
async fn test_bidirect_rpc_manager() {
    use crate::common::scoped_task::ScopedTask;
    use crate::proto::rpc_impl::bidirect::BidirectRpcManager;
    use crate::tunnel::tcp::{TcpTunnelConnector, TcpTunnelListener};
    use crate::tunnel::{TunnelConnector, TunnelListener};

    // Two managers, one per end of the TCP tunnel. Each registers its
    // own Greeting service so RPCs can be issued in BOTH directions
    // over the single connection.
    let c = BidirectRpcManager::new();
    let s = BidirectRpcManager::new();

    let service = GreetingServer::new(GreetingService {
        delay_ms: 0,
        prefix: "Hello Client".to_string(),
    });
    c.rpc_server().registry().register(service, "test");

    let service = GreetingServer::new(GreetingService {
        delay_ms: 0,
        prefix: "Hello Server".to_string(),
    });
    s.rpc_server().registry().register(service, "test");

    // NOTE(review): fixed port 55443 — could collide with a concurrent
    // test or occupied port on the CI host.
    let mut tcp_listener = TcpTunnelListener::new("tcp://0.0.0.0:55443".parse().unwrap());

    // Server task: accept one tunnel, then call BACK into the client's
    // service ("Hello Client" prefix) to prove server->client RPC works.
    let s_task: ScopedTask<()> = tokio::spawn(async move {
        tcp_listener.listen().await.unwrap();
        let tunnel = tcp_listener.accept().await.unwrap();
        s.run_with_tunnel(tunnel);
        let s_c = s
            .rpc_client()
            .scoped_client::<GreetingClientFactory<RpcController>>(1, 1, "test".to_string());
        let ret = s_c
            .say_hello(
                RpcController::default(),
                SayHelloRequest {
                    name: "world".to_string(),
                },
            )
            .await
            .unwrap();
        assert_eq!(ret.greeting, "Hello Client world!");
        println!("server done, {:?}", ret);
        // Runs until the client drops its manager and the tunnel dies.
        s.wait().await;
    })
    .into();

    // Give the listener time to start before connecting.
    tokio::time::sleep(std::time::Duration::from_secs(1)).await;

    let mut tcp_connector = TcpTunnelConnector::new("tcp://0.0.0.0:55443".parse().unwrap());
    let c_tunnel = tcp_connector.connect().await.unwrap();
    c.run_with_tunnel(c_tunnel);

    // Client -> server RPC ("Hello Server" prefix).
    let c_c = c
        .rpc_client()
        .scoped_client::<GreetingClientFactory<RpcController>>(1, 1, "test".to_string());
    let ret = c_c
        .say_hello(
            RpcController::default(),
            SayHelloRequest {
                name: "world".to_string(),
            },
        )
        .await
        .unwrap();
    assert_eq!(ret.greeting, "Hello Server world!");
    println!("client done, {:?}", ret);

    // Dropping the client manager closes the tunnel, which lets the
    // server task's `s.wait()` return.
    drop(c);
    s_task.await.unwrap();
}
+1 -1
View File
@@ -789,7 +789,7 @@ pub async fn manual_reconnector(#[values(true, false)] is_foreign: bool) {
.await
.unwrap();
assert_eq!(1, conns.len());
assert!(conns.len() >= 1);
wait_for_condition(
|| async { ping_test("net_b", "10.144.145.2", None).await },