forward original peer info in ospf route (#589)

prost doesn't support unknown field, and these info may be lost when
they go through a old version node.
This commit is contained in:
Sijie.Sun
2025-01-27 20:38:22 +08:00
committed by GitHub
parent 08546925cc
commit 4aea0821dd
10 changed files with 318 additions and 37 deletions
Generated
+101 -4
View File
@@ -670,6 +670,12 @@ version = "1.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8c3c1a368f70d6cf7302d78f8f7093da241fb8e8807c05cc9e51a125895a6d5b" checksum = "8c3c1a368f70d6cf7302d78f8f7093da241fb8e8807c05cc9e51a125895a6d5b"
[[package]]
name = "beef"
version = "0.5.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3a8241f3ebb85c056b509d4327ad0358fbbba6ffb340bf388f26350aeda225b1"
[[package]] [[package]]
name = "bigdecimal" name = "bigdecimal"
version = "0.4.6" version = "0.4.6"
@@ -693,7 +699,7 @@ dependencies = [
"bitflags 2.8.0", "bitflags 2.8.0",
"cexpr", "cexpr",
"clang-sys", "clang-sys",
"itertools 0.11.0", "itertools 0.12.1",
"proc-macro2", "proc-macro2",
"quote", "quote",
"regex", "regex",
@@ -1917,6 +1923,8 @@ dependencies = [
"pnet", "pnet",
"prost", "prost",
"prost-build", "prost-build",
"prost-reflect",
"prost-reflect-build",
"prost-types", "prost-types",
"quinn", "quinn",
"rand 0.8.5", "rand 0.8.5",
@@ -3709,6 +3717,39 @@ version = "0.4.22"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24" checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24"
[[package]]
name = "logos"
version = "0.14.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7251356ef8cb7aec833ddf598c6cb24d17b689d20b993f9d11a3d764e34e6458"
dependencies = [
"logos-derive",
]
[[package]]
name = "logos-codegen"
version = "0.14.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "59f80069600c0d66734f5ff52cc42f2dabd6b29d205f333d61fd7832e9e9963f"
dependencies = [
"beef",
"fnv",
"lazy_static",
"proc-macro2",
"quote",
"regex-syntax 0.8.4",
"syn 2.0.87",
]
[[package]]
name = "logos-derive"
version = "0.14.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "24fb722b06a9dc12adb0963ed585f19fc61dc5413e6a9be9422ef92c091e731d"
dependencies = [
"logos-codegen",
]
[[package]] [[package]]
name = "loom" name = "loom"
version = "0.5.6" version = "0.5.6"
@@ -4578,6 +4619,15 @@ version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "04744f49eae99ab78e0d5c0b603ab218f515ea8cfe5a456d7629ad883a3b6e7d" checksum = "04744f49eae99ab78e0d5c0b603ab218f515ea8cfe5a456d7629ad883a3b6e7d"
[[package]]
name = "ordered-float"
version = "2.10.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "68f19d67e5a2795c94e73e0bb1cc1a7edeb2e28efd39e2e1c9b7a40c1108b11c"
dependencies = [
"num-traits",
]
[[package]] [[package]]
name = "ordered-float" name = "ordered-float"
version = "3.9.2" version = "3.9.2"
@@ -5340,7 +5390,7 @@ checksum = "f8650aabb6c35b860610e9cff5dc1af886c9e25073b7b1712a68972af4281302"
dependencies = [ dependencies = [
"bytes", "bytes",
"heck 0.5.0", "heck 0.5.0",
"itertools 0.11.0", "itertools 0.12.1",
"log", "log",
"multimap", "multimap",
"once_cell", "once_cell",
@@ -5360,7 +5410,44 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "acf0c195eebb4af52c752bec4f52f645da98b6e92077a04110c7f349477ae5ac" checksum = "acf0c195eebb4af52c752bec4f52f645da98b6e92077a04110c7f349477ae5ac"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"itertools 0.11.0", "itertools 0.12.1",
"proc-macro2",
"quote",
"syn 2.0.87",
]
[[package]]
name = "prost-reflect"
version = "0.14.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e92b959d24e05a3e2da1d0beb55b48bc8a97059b8336ea617780bd6addbbfb5a"
dependencies = [
"base64 0.22.1",
"logos",
"once_cell",
"prost",
"prost-reflect-derive",
"prost-types",
"serde",
"serde-value",
]
[[package]]
name = "prost-reflect-build"
version = "0.14.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "50e2537231d94dd2778920c2ada37dd9eb1ac0325bb3ee3ee651bd44c1134123"
dependencies = [
"prost-build",
"prost-reflect",
]
[[package]]
name = "prost-reflect-derive"
version = "0.14.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f4fce6b22f15cc8d8d400a2b98ad29202b33bd56c7d9ddd815bc803a807ecb65"
dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.87", "syn 2.0.87",
@@ -6311,7 +6398,7 @@ dependencies = [
"bigdecimal", "bigdecimal",
"chrono", "chrono",
"inherent", "inherent",
"ordered-float", "ordered-float 3.9.2",
"rust_decimal", "rust_decimal",
"sea-query-derive", "sea-query-derive",
"serde_json", "serde_json",
@@ -6451,6 +6538,16 @@ dependencies = [
"typeid", "typeid",
] ]
[[package]]
name = "serde-value"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f3a1a3341211875ef120e117ea7fd5228530ae7e7036a779fdc9117be6b3282c"
dependencies = [
"ordered-float 2.10.1",
"serde",
]
[[package]] [[package]]
name = "serde_derive" name = "serde_derive"
version = "1.0.207" version = "1.0.207"
+7
View File
@@ -188,6 +188,12 @@ async-compression = { version = "0.4.17", default-features = false, features = [
kcp-sys = { git = "https://github.com/EasyTier/kcp-sys" } kcp-sys = { git = "https://github.com/EasyTier/kcp-sys" }
prost-reflect = { version = "0.14.5", features = [
"serde",
"derive",
"text-format"
] }
[target.'cfg(any(target_os = "linux", target_os = "macos", target_os = "windows", target_os = "freebsd"))'.dependencies] [target.'cfg(any(target_os = "linux", target_os = "macos", target_os = "windows", target_os = "freebsd"))'.dependencies]
machine-uid = "0.5.3" machine-uid = "0.5.3"
@@ -208,6 +214,7 @@ globwalk = "0.8.1"
regex = "1" regex = "1"
prost-build = "0.13.2" prost-build = "0.13.2"
rpc_build = { package = "easytier-rpc-build", version = "0.1.0", features = ["internal-namespace"] } rpc_build = { package = "easytier-rpc-build", version = "0.1.0", features = ["internal-namespace"] }
prost-reflect-build = { version = "0.14.0" }
[target.'cfg(windows)'.build-dependencies] [target.'cfg(windows)'.build-dependencies]
reqwest = { version = "0.11", features = ["blocking"] } reqwest = { version = "0.11", features = ["blocking"] }
+7 -4
View File
@@ -141,7 +141,8 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
println!("cargo:rerun-if-changed={}", proto_file); println!("cargo:rerun-if-changed={}", proto_file);
} }
prost_build::Config::new() let mut config = prost_build::Config::new();
config
.protoc_arg("--experimental_allow_proto3_optional") .protoc_arg("--experimental_allow_proto3_optional")
.type_attribute(".common", "#[derive(serde::Serialize, serde::Deserialize)]") .type_attribute(".common", "#[derive(serde::Serialize, serde::Deserialize)]")
.type_attribute(".error", "#[derive(serde::Serialize, serde::Deserialize)]") .type_attribute(".error", "#[derive(serde::Serialize, serde::Deserialize)]")
@@ -156,9 +157,11 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
.type_attribute("peer_rpc.ForeignNetworkRouteInfoKey", "#[derive(Hash, Eq)]") .type_attribute("peer_rpc.ForeignNetworkRouteInfoKey", "#[derive(Hash, Eq)]")
.type_attribute("common.RpcDescriptor", "#[derive(Hash, Eq)]") .type_attribute("common.RpcDescriptor", "#[derive(Hash, Eq)]")
.service_generator(Box::new(rpc_build::ServiceGenerator::new())) .service_generator(Box::new(rpc_build::ServiceGenerator::new()))
.btree_map(["."]) .btree_map(["."]);
.compile_protos(&proto_files, &["src/proto/"])
.unwrap(); prost_reflect_build::Builder::new()
.file_descriptor_set_bytes("crate::proto::DESCRIPTOR_POOL_BYTES")
.compile_protos_with_config(config, &proto_files, &["src/proto/"])?;
check_locale(); check_locale();
Ok(()) Ok(())
@@ -284,6 +284,7 @@ impl PunchSymToConeHoleClient {
BaseController { BaseController {
timeout_ms: 4000, timeout_ms: 4000,
trace_id: 0, trace_id: 0,
..Default::default()
}, },
req, req,
) )
@@ -314,6 +315,7 @@ impl PunchSymToConeHoleClient {
BaseController { BaseController {
timeout_ms: 4000, timeout_ms: 4000,
trace_id: 0, trace_id: 0,
..Default::default()
}, },
req, req,
) )
-1
View File
@@ -1,5 +1,4 @@
use std::{ use std::{
fmt::Debug,
sync::{ sync::{
atomic::{AtomicU32, Ordering}, atomic::{AtomicU32, Ordering},
Arc, Arc,
+133 -15
View File
@@ -16,6 +16,8 @@ use petgraph::{
graph::NodeIndex, graph::NodeIndex,
Directed, Graph, Directed, Graph,
}; };
use prost::Message;
use prost_reflect::{DynamicMessage, ReflectMessage};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use tokio::{ use tokio::{
select, select,
@@ -283,6 +285,8 @@ type Error = SyncRouteInfoError;
#[derive(Debug)] #[derive(Debug)]
struct SyncedRouteInfo { struct SyncedRouteInfo {
peer_infos: DashMap<PeerId, RoutePeerInfo>, peer_infos: DashMap<PeerId, RoutePeerInfo>,
// prost doesn't support unknown fields, so we use DynamicMessage to store raw infos and progate them to other peers.
raw_peer_infos: DashMap<PeerId, DynamicMessage>,
conn_map: DashMap<PeerId, (BTreeSet<PeerId>, AtomicVersion)>, conn_map: DashMap<PeerId, (BTreeSet<PeerId>, AtomicVersion)>,
foreign_network: DashMap<ForeignNetworkRouteInfoKey, ForeignNetworkRouteInfoEntry>, foreign_network: DashMap<ForeignNetworkRouteInfoKey, ForeignNetworkRouteInfoEntry>,
} }
@@ -297,6 +301,7 @@ impl SyncedRouteInfo {
fn remove_peer(&self, peer_id: PeerId) { fn remove_peer(&self, peer_id: PeerId) {
tracing::warn!(?peer_id, "remove_peer from synced_route_info"); tracing::warn!(?peer_id, "remove_peer from synced_route_info");
self.peer_infos.remove(&peer_id); self.peer_infos.remove(&peer_id);
self.raw_peer_infos.remove(&peer_id);
self.conn_map.remove(&peer_id); self.conn_map.remove(&peer_id);
self.foreign_network.retain(|k, _| k.peer_id != peer_id); self.foreign_network.retain(|k, _| k.peer_id != peer_id);
} }
@@ -369,8 +374,11 @@ impl SyncedRouteInfo {
my_peer_route_id: u64, my_peer_route_id: u64,
dst_peer_id: PeerId, dst_peer_id: PeerId,
peer_infos: &Vec<RoutePeerInfo>, peer_infos: &Vec<RoutePeerInfo>,
raw_peer_infos: &Vec<DynamicMessage>,
) -> Result<(), Error> { ) -> Result<(), Error> {
for mut route_info in peer_infos.iter().map(Clone::clone) { for (idx, route_info) in peer_infos.iter().enumerate() {
let mut route_info = route_info.clone();
let raw_route_info = &raw_peer_infos[idx];
self.check_duplicate_peer_id( self.check_duplicate_peer_id(
my_peer_id, my_peer_id,
my_peer_route_id, my_peer_route_id,
@@ -391,10 +399,16 @@ impl SyncedRouteInfo {
.entry(route_info.peer_id) .entry(route_info.peer_id)
.and_modify(|old_entry| { .and_modify(|old_entry| {
if route_info.version > old_entry.version { if route_info.version > old_entry.version {
self.raw_peer_infos
.insert(route_info.peer_id, raw_route_info.clone());
*old_entry = route_info.clone(); *old_entry = route_info.clone();
} }
}) })
.or_insert_with(|| route_info.clone()); .or_insert_with(|| {
self.raw_peer_infos
.insert(route_info.peer_id, raw_route_info.clone());
route_info.clone()
});
} }
Ok(()) Ok(())
} }
@@ -1047,6 +1061,7 @@ impl PeerRouteServiceImpl {
synced_route_info: SyncedRouteInfo { synced_route_info: SyncedRouteInfo {
peer_infos: DashMap::new(), peer_infos: DashMap::new(),
raw_peer_infos: DashMap::new(),
conn_map: DashMap::new(), conn_map: DashMap::new(),
foreign_network: DashMap::new(), foreign_network: DashMap::new(),
}, },
@@ -1381,6 +1396,39 @@ impl PeerRouteServiceImpl {
} }
} }
fn build_sync_route_raw_req(
req: &SyncRouteInfoRequest,
raw_peer_infos: &DashMap<PeerId, DynamicMessage>,
) -> DynamicMessage {
use prost_reflect::Value;
let mut req_dynamic_msg = DynamicMessage::new(SyncRouteInfoRequest::default().descriptor());
req_dynamic_msg.transcode_from(req).unwrap();
let peer_infos = req.peer_infos.as_ref().map(|x| &x.items);
if let Some(peer_infos) = peer_infos {
let mut peer_info_raws = Vec::new();
for peer_info in peer_infos.iter() {
if let Some(info) = raw_peer_infos.get(&peer_info.peer_id) {
peer_info_raws.push(Value::Message(info.clone()));
} else {
let mut p = DynamicMessage::new(RoutePeerInfo::default().descriptor());
p.transcode_from(peer_info).unwrap();
peer_info_raws.push(Value::Message(p));
}
}
let mut peer_infos = DynamicMessage::new(RoutePeerInfos::default().descriptor());
peer_infos.set_field_by_name("items", Value::List(peer_info_raws));
req_dynamic_msg.set_field_by_name("peer_infos", Value::Message(peer_infos));
}
tracing::trace!(?req_dynamic_msg, "build_sync_route_raw_req");
req_dynamic_msg
}
async fn sync_route_with_peer( async fn sync_route_with_peer(
&self, &self,
dst_peer_id: PeerId, dst_peer_id: PeerId,
@@ -1419,20 +1467,27 @@ impl PeerRouteServiceImpl {
self.global_ctx.get_network_name(), self.global_ctx.get_network_name(),
); );
let sync_route_info_req = SyncRouteInfoRequest {
my_peer_id,
my_session_id: session.my_session_id.load(Ordering::Relaxed),
is_initiator: session.we_are_initiator.load(Ordering::Relaxed),
peer_infos: peer_infos.clone().map(|x| RoutePeerInfos { items: x }),
conn_bitmap: conn_bitmap.clone().map(Into::into),
foreign_network_infos: foreign_network.clone(),
};
let mut ctrl = BaseController::default(); let mut ctrl = BaseController::default();
ctrl.set_timeout_ms(3000); ctrl.set_timeout_ms(3000);
let ret = rpc_stub ctrl.set_raw_input(
.sync_route_info( Self::build_sync_route_raw_req(
ctrl, &sync_route_info_req,
SyncRouteInfoRequest { &self.synced_route_info.raw_peer_infos,
my_peer_id,
my_session_id: session.my_session_id.load(Ordering::Relaxed),
is_initiator: session.we_are_initiator.load(Ordering::Relaxed),
peer_infos: peer_infos.clone().map(|x| RoutePeerInfos { items: x }),
conn_bitmap: conn_bitmap.clone().map(Into::into),
foreign_network_infos: foreign_network.clone(),
},
) )
.encode_to_vec()
.into(),
);
let ret = rpc_stub
.sync_route_info(ctrl, SyncRouteInfoRequest::default())
.await; .await;
if let Err(e) = &ret { if let Err(e) = &ret {
@@ -1508,12 +1563,30 @@ impl Debug for RouteSessionManager {
} }
} }
fn get_raw_peer_infos(req_raw_input: &mut bytes::Bytes) -> Option<Vec<DynamicMessage>> {
let sync_req_dynamic_msg =
DynamicMessage::decode(SyncRouteInfoRequest::default().descriptor(), req_raw_input)
.unwrap();
let peer_infos = sync_req_dynamic_msg.get_field_by_name("peer_infos")?;
let infos = peer_infos
.as_message()?
.get_field_by_name("items")?
.as_list()?
.iter()
.map(|x| x.as_message().unwrap().clone())
.collect();
Some(infos)
}
#[async_trait::async_trait] #[async_trait::async_trait]
impl OspfRouteRpc for RouteSessionManager { impl OspfRouteRpc for RouteSessionManager {
type Controller = BaseController; type Controller = BaseController;
async fn sync_route_info( async fn sync_route_info(
&self, &self,
_ctrl: BaseController, ctrl: BaseController,
request: SyncRouteInfoRequest, request: SyncRouteInfoRequest,
) -> Result<SyncRouteInfoResponse, rpc_types::error::Error> { ) -> Result<SyncRouteInfoResponse, rpc_types::error::Error> {
let from_peer_id = request.my_peer_id; let from_peer_id = request.my_peer_id;
@@ -1522,6 +1595,13 @@ impl OspfRouteRpc for RouteSessionManager {
let peer_infos = request.peer_infos.map(|x| x.items); let peer_infos = request.peer_infos.map(|x| x.items);
let conn_bitmap = request.conn_bitmap.map(Into::into); let conn_bitmap = request.conn_bitmap.map(Into::into);
let foreign_network = request.foreign_network_infos; let foreign_network = request.foreign_network_infos;
let raw_peer_infos = if peer_infos.is_some() {
let r = get_raw_peer_infos(&mut ctrl.get_raw_input().unwrap()).unwrap();
assert_eq!(r.len(), peer_infos.as_ref().unwrap().len());
Some(r)
} else {
None
};
let ret = self let ret = self
.do_sync_route_info( .do_sync_route_info(
@@ -1529,6 +1609,7 @@ impl OspfRouteRpc for RouteSessionManager {
from_session_id, from_session_id,
is_initiator, is_initiator,
peer_infos, peer_infos,
raw_peer_infos,
conn_bitmap, conn_bitmap,
foreign_network, foreign_network,
) )
@@ -1783,6 +1864,7 @@ impl RouteSessionManager {
from_session_id: SessionId, from_session_id: SessionId,
is_initiator: bool, is_initiator: bool,
peer_infos: Option<Vec<RoutePeerInfo>>, peer_infos: Option<Vec<RoutePeerInfo>>,
raw_peer_infos: Option<Vec<DynamicMessage>>,
conn_bitmap: Option<RouteConnBitmap>, conn_bitmap: Option<RouteConnBitmap>,
foreign_network: Option<RouteForeignNetworkInfos>, foreign_network: Option<RouteForeignNetworkInfos>,
) -> Result<SyncRouteInfoResponse, Error> { ) -> Result<SyncRouteInfoResponse, Error> {
@@ -1805,6 +1887,7 @@ impl RouteSessionManager {
service_impl.my_peer_route_id, service_impl.my_peer_route_id,
from_peer_id, from_peer_id,
peer_infos, peer_infos,
raw_peer_infos.as_ref().unwrap(),
)?; )?;
session.update_dst_saved_peer_info_version(peer_infos); session.update_dst_saved_peer_info_version(peer_infos);
need_update_route_table = true; need_update_route_table = true;
@@ -2123,18 +2206,26 @@ mod tests {
time::Duration, time::Duration,
}; };
use dashmap::DashMap;
use prost_reflect::{DynamicMessage, ReflectMessage};
use crate::{ use crate::{
common::{global_ctx::tests::get_mock_global_ctx, PeerId}, common::{global_ctx::tests::get_mock_global_ctx, PeerId},
connector::udp_hole_punch::tests::replace_stun_info_collector, connector::udp_hole_punch::tests::replace_stun_info_collector,
peers::{ peers::{
create_packet_recv_chan, create_packet_recv_chan,
peer_manager::{PeerManager, RouteAlgoType}, peer_manager::{PeerManager, RouteAlgoType},
peer_ospf_route::PeerRouteServiceImpl,
route_trait::{NextHopPolicy, Route, RouteCostCalculatorInterface}, route_trait::{NextHopPolicy, Route, RouteCostCalculatorInterface},
tests::connect_peer_manager, tests::connect_peer_manager,
}, },
proto::common::NatType, proto::{
common::NatType,
peer_rpc::{RoutePeerInfo, RoutePeerInfos, SyncRouteInfoRequest},
},
tunnel::common::tests::wait_for_condition, tunnel::common::tests::wait_for_condition,
}; };
use prost::Message;
use super::PeerRoute; use super::PeerRoute;
@@ -2554,4 +2645,31 @@ mod tests {
) )
.await; .await;
} }
#[tokio::test]
async fn test_raw_peer_info() {
let mut req = SyncRouteInfoRequest::default();
let raw_info_map: DashMap<PeerId, DynamicMessage> = DashMap::new();
req.peer_infos = Some(RoutePeerInfos {
items: vec![RoutePeerInfo {
peer_id: 1,
..Default::default()
}],
});
let mut raw_req = DynamicMessage::new(RoutePeerInfo::default().descriptor());
raw_req
.transcode_from(&req.peer_infos.as_ref().unwrap().items[0])
.unwrap();
raw_info_map.insert(1, raw_req);
let out = PeerRouteServiceImpl::build_sync_route_raw_req(&req, &raw_info_map);
let out_bytes = out.encode_to_vec();
let req2 = SyncRouteInfoRequest::decode(out_bytes.as_slice()).unwrap();
assert_eq!(req, req2);
}
} }
+3
View File
@@ -9,3 +9,6 @@ pub mod web;
#[cfg(test)] #[cfg(test)]
pub mod tests; pub mod tests;
const DESCRIPTOR_POOL_BYTES: &[u8] =
include_bytes!(concat!(env!("OUT_DIR"), "/file_descriptor_set.bin"));
+10 -3
View File
@@ -192,7 +192,7 @@ impl Client {
async fn call( async fn call(
&self, &self,
ctrl: Self::Controller, mut ctrl: Self::Controller,
method: <Self::Descriptor as ServiceDescriptor>::Method, method: <Self::Descriptor as ServiceDescriptor>::Method,
input: bytes::Bytes, input: bytes::Bytes,
) -> Result<bytes::Bytes> { ) -> Result<bytes::Bytes> {
@@ -224,7 +224,11 @@ impl Client {
}; };
let rpc_req = RpcRequest { let rpc_req = RpcRequest {
request: input.into(), request: if let Some(raw_input) = ctrl.get_raw_input() {
raw_input.into()
} else {
input.into()
},
timeout_ms: ctrl.timeout_ms(), timeout_ms: ctrl.timeout_ms(),
..Default::default() ..Default::default()
}; };
@@ -280,7 +284,10 @@ impl Client {
return Err(err.into()); return Err(err.into());
} }
Ok(bytes::Bytes::from(rpc_resp.response)) let raw_output = Bytes::from(rpc_resp.response.clone());
ctrl.set_raw_output(raw_output.clone());
Ok(raw_output)
} }
} }
+12 -9
View File
@@ -13,7 +13,7 @@ use crate::{
common::{join_joinset_background, PeerId}, common::{join_joinset_background, PeerId},
proto::{ proto::{
common::{self, CompressionAlgoPb, RpcCompressionInfo, RpcPacket, RpcRequest, RpcResponse}, common::{self, CompressionAlgoPb, RpcCompressionInfo, RpcPacket, RpcRequest, RpcResponse},
rpc_types::error::Result, rpc_types::{controller::Controller, error::Result},
}, },
tunnel::{ tunnel::{
mpsc::{MpscTunnel, MpscTunnelSender}, mpsc::{MpscTunnel, MpscTunnelSender},
@@ -155,16 +155,19 @@ impl Server {
}; };
let rpc_request = RpcRequest::decode(Bytes::from(body))?; let rpc_request = RpcRequest::decode(Bytes::from(body))?;
let timeout_duration = std::time::Duration::from_millis(rpc_request.timeout_ms as u64); let timeout_duration = std::time::Duration::from_millis(rpc_request.timeout_ms as u64);
let ctrl = RpcController::default(); let mut ctrl = RpcController::default();
Ok(timeout( let raw_req = Bytes::from(rpc_request.request);
ctrl.set_raw_input(raw_req.clone());
let ret = timeout(
timeout_duration, timeout_duration,
reg.call_method( reg.call_method(packet.descriptor.unwrap(), ctrl.clone(), raw_req),
packet.descriptor.unwrap(),
ctrl,
Bytes::from(rpc_request.request),
),
) )
.await??) .await??;
if let Some(raw_output) = ctrl.get_raw_output() {
Ok(raw_output)
} else {
Ok(ret)
}
} }
async fn handle_rpc(sender: MpscTunnelSender, packet: RpcPacket, reg: Arc<ServiceRegistry>) { async fn handle_rpc(sender: MpscTunnelSender, packet: RpcPacket, reg: Arc<ServiceRegistry>) {
+43 -1
View File
@@ -1,4 +1,9 @@
pub trait Controller: Send + Sync + 'static { use std::sync::{Arc, Mutex};
use bytes::Bytes;
// Controller must impl clone and all cloned controllers share the same data
pub trait Controller: Send + Sync + Clone + 'static {
fn timeout_ms(&self) -> i32 { fn timeout_ms(&self) -> i32 {
5000 5000
} }
@@ -10,12 +15,29 @@ pub trait Controller: Send + Sync + 'static {
fn trace_id(&self) -> i32 { fn trace_id(&self) -> i32 {
0 0
} }
fn set_raw_input(&mut self, _raw_input: Bytes) {}
fn get_raw_input(&self) -> Option<Bytes> {
None
}
fn set_raw_output(&mut self, _raw_output: Bytes) {}
fn get_raw_output(&self) -> Option<Bytes> {
None
}
} }
#[derive(Debug)] #[derive(Debug)]
pub struct BaseControllerRawData {
pub raw_input: Option<Bytes>,
pub raw_output: Option<Bytes>,
}
#[derive(Debug, Clone)]
pub struct BaseController { pub struct BaseController {
pub timeout_ms: i32, pub timeout_ms: i32,
pub trace_id: i32, pub trace_id: i32,
pub raw_data: Arc<Mutex<BaseControllerRawData>>,
} }
impl Controller for BaseController { impl Controller for BaseController {
@@ -34,6 +56,22 @@ impl Controller for BaseController {
fn trace_id(&self) -> i32 { fn trace_id(&self) -> i32 {
self.trace_id self.trace_id
} }
fn set_raw_input(&mut self, raw_input: Bytes) {
self.raw_data.lock().unwrap().raw_input = Some(raw_input);
}
fn get_raw_input(&self) -> Option<Bytes> {
self.raw_data.lock().unwrap().raw_input.clone()
}
fn set_raw_output(&mut self, raw_output: Bytes) {
self.raw_data.lock().unwrap().raw_output = Some(raw_output);
}
fn get_raw_output(&self) -> Option<Bytes> {
self.raw_data.lock().unwrap().raw_output.clone()
}
} }
impl Default for BaseController { impl Default for BaseController {
@@ -41,6 +79,10 @@ impl Default for BaseController {
Self { Self {
timeout_ms: 5000, timeout_ms: 5000,
trace_id: 0, trace_id: 0,
raw_data: Arc::new(Mutex::new(BaseControllerRawData {
raw_input: None,
raw_output: None,
})),
} }
} }
} }