Compare commits

..

13 Commits

Author SHA1 Message Date
Sijie.Sun
1eec27b5ff bump version to 2.4.2 (#1218) 2025-08-11 09:03:13 +08:00
Sijie.Sun
1de7777a71 fix quic transport panic (#1216) 2025-08-11 08:30:59 +08:00
Sijie.Sun
975ca8bd9c Update docker workflow (#1217)
1. push all supported platform
2. support unstable tag
2025-08-10 23:36:50 +08:00
Sijie.Sun
e43537939a clippy all codes (#1214)
1. clippy code
2. add fmt and clippy check in ci
2025-08-10 22:56:41 +08:00
CyiceK
0087ac3ffc feat(encrypt): Add XOR and ChaCha20 encryption with low-end device optimization and openssl support. (#1186)
Add ChaCha20 XOR algorithm, extend AES-GCM-256 capabilities, and integrate OpenSSL support.

---------

Co-authored-by: Sijie.Sun <sunsijie@buaa.edu.cn>
2025-08-09 18:53:55 +08:00
21paradox
7de4b33dd1 add FOREGROUND_SERVICE for no_tun mode, not using vpn service (#1203)
1. add FOREGROUND_SERVICE related code, connection not to be **blocked by android system** when apps running in background
2. no_tun mode not enabling vpnservice, makeing other app to use vpnservice

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2025-08-09 18:34:45 +08:00
Sijie.Sun
8ffc2f12e4 optimize the condition of enabling kcp (#1210) 2025-08-09 16:16:09 +08:00
FuturePrayer
37b24164b6 add portforward config to gui (#1198)
* Added port forwarding to the GUI interface
* Separated port forwarding into a separate drop-down menu
2025-08-09 09:50:09 +08:00
Sijie.Sun
8cdb27d43d add stats metrics (#1207)
support new cli command `easytier-cli stats`

It's useful to find out which components are consuming bandwidth.
2025-08-09 00:06:35 +08:00
Sijie.Sun
efa17a7c10 fix dead loop in direct connecto if disable-p2p is enabled in dst (#1206) 2025-08-08 22:30:39 +08:00
Sijie.Sun
6d14e9e441 fix jemalloc prof feature (#1201) 2025-08-08 17:54:39 +08:00
fanyang
e3e406dcde cli: sort peers by IPv4 and hostname (#1191)
* cli: sort entries by IPv4 and hostname

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2025-08-04 21:18:49 +08:00
sijie.sun
d0a6c93c2c fix ipv6 packet routing and avoid route looping
properly handle ipv6 link local address and exit node.
2025-08-03 18:10:27 +08:00
170 changed files with 4154 additions and 1626 deletions

View File

@@ -10,8 +10,24 @@ RUN ARTIFACT_ARCH=""; \
ARTIFACT_ARCH="x86_64"; \
elif [ "$TARGETPLATFORM" = "linux/arm64" ]; then \
ARTIFACT_ARCH="aarch64"; \
elif [ "$TARGETPLATFORM" = "linux/riscv64" ]; then \
ARTIFACT_ARCH="riscv64"; \
elif [ "$TARGETPLATFORM" = "linux/mips" ]; then \
ARTIFACT_ARCH="mips"; \
elif [ "$TARGETPLATFORM" = "linux/mipsel" ]; then \
ARTIFACT_ARCH="mipsel"; \
elif [ "$TARGETPLATFORM" = "linux/arm/v7" ]; then \
ARTIFACT_ARCH="armv7hf"; \
elif [ "$TARGETPLATFORM" = "linux/arm/v6" ]; then \
ARTIFACT_ARCH="armhf"; \
elif [ "$TARGETPLATFORM" = "linux/arm/v5" ]; then \
ARTIFACT_ARCH="arm"; \
elif [ "$TARGETPLATFORM" = "linux/arm" ]; then \
ARTIFACT_ARCH="armv7"; \
elif [ "$TARGETPLATFORM" = "linux/loong64" ]; then \
ARTIFACT_ARCH="loongarch64"; \
else \
echo "Unsupported architecture: $TARGETARCH"; \
echo "Unsupported architecture: $TARGETPLATFORM"; \
exit 1; \
fi; \
cp /tmp/artifacts/easytier-linux-${ARTIFACT_ARCH}/* /tmp/output;

View File

@@ -229,8 +229,8 @@ jobs:
rustup set auto-self-update disable
rustup install 1.87
rustup default 1.87
rustup install 1.89
rustup default 1.89
export CC=clang
export CXX=clang++

View File

@@ -11,13 +11,18 @@ on:
image_tag:
description: 'Tag for this image build'
type: string
default: 'v2.4.1'
default: 'v2.4.2'
required: true
mark_latest:
description: 'Mark this image as latest'
type: boolean
default: false
required: true
mark_unstable:
description: 'Mark this image as unstable'
type: boolean
default: false
required: true
jobs:
docker:
@@ -27,6 +32,13 @@ jobs:
-
name: Checkout
uses: actions/checkout@v4
-
name: Validate inputs
run: |
if [[ "${{ inputs.mark_latest }}" == "true" && "${{ inputs.mark_unstable }}" == "true" ]]; then
echo "Error: mark_latest and mark_unstable cannot both be true"
exit 1
fi
-
name: Set up QEMU
uses: docker/setup-qemu-action@v3
@@ -56,14 +68,36 @@ jobs:
- name: List files
run: |
ls -l -R .
- name: Prepare Docker tags
id: tags
run: |
# Base tags with version
DOCKERHUB_TAGS="easytier/easytier:${{ inputs.image_tag }}"
GHCR_TAGS="ghcr.io/easytier/easytier:${{ inputs.image_tag }}"
# Add latest tags if requested
if [[ "${{ inputs.mark_latest }}" == "true" ]]; then
DOCKERHUB_TAGS="${DOCKERHUB_TAGS},easytier/easytier:latest"
GHCR_TAGS="${GHCR_TAGS},ghcr.io/easytier/easytier:latest"
fi
# Add unstable tags if requested
if [[ "${{ inputs.mark_unstable }}" == "true" ]]; then
DOCKERHUB_TAGS="${DOCKERHUB_TAGS},easytier/easytier:unstable"
GHCR_TAGS="${GHCR_TAGS},ghcr.io/easytier/easytier:unstable"
fi
# Combine all tags
ALL_TAGS="${DOCKERHUB_TAGS},${GHCR_TAGS}"
echo "tags=${ALL_TAGS}" >> $GITHUB_OUTPUT
echo "Generated tags: ${ALL_TAGS}"
-
name: Build and push
uses: docker/build-push-action@v6
with:
context: ./docker_context
platforms: linux/amd64,linux/arm64
platforms: linux/amd64,linux/arm64,linux/riscv64,linux/mips,linux/mipsel,linux/arm/v7,linux/arm/v6,linux/arm/v5,linux/arm,linux/loong64
push: true
file: .github/workflows/Dockerfile
tags: |
easytier/easytier:${{ inputs.image_tag }}${{ inputs.mark_latest && ',easytier/easytier:latest' || '' }},
ghcr.io/easytier/easytier:${{ inputs.image_tag }}${{ inputs.mark_latest && ',easytier/easytier:latest' || '' }},
tags: ${{ steps.tags.outputs.tags }}

View File

@@ -29,7 +29,7 @@ jobs:
concurrent_skipping: 'same_content_newer'
skip_after_successful_duplicate: 'true'
cancel_others: 'true'
paths: '["Cargo.toml", "Cargo.lock", "easytier/**", "easytier-gui/**", ".github/workflows/gui.yml", ".github/workflows/install_rust.sh"]'
paths: '["Cargo.toml", "Cargo.lock", "easytier/**", "easytier-gui/**", ".github/workflows/gui.yml", ".github/workflows/install_rust.sh", ".github/workflows/install_gui_dep.sh"]'
build-gui:
strategy:
fail-fast: false
@@ -78,20 +78,11 @@ jobs:
needs: pre_job
if: needs.pre_job.outputs.should_skip != 'true'
steps:
- uses: actions/checkout@v3
- name: Install GUI dependencies (x86 only)
if: ${{ matrix.TARGET == 'x86_64-unknown-linux-musl' }}
run: |
sudo apt update
sudo apt install -qq libwebkit2gtk-4.1-dev \
build-essential \
curl \
wget \
file \
libgtk-3-dev \
librsvg2-dev \
libxdo-dev \
libssl-dev \
patchelf
run: bash ./.github/workflows/install_gui_dep.sh
- name: Install GUI cross compile (aarch64 only)
if: ${{ matrix.TARGET == 'aarch64-unknown-linux-musl' }}
@@ -128,8 +119,6 @@ jobs:
echo "PKG_CONFIG_SYSROOT_DIR=/usr/aarch64-linux-gnu/" >> "$GITHUB_ENV"
echo "PKG_CONFIG_PATH=/usr/lib/aarch64-linux-gnu/pkgconfig/" >> "$GITHUB_ENV"
- uses: actions/checkout@v3
- name: Set current ref as env variable
run: |
echo "GIT_DESC=$(git log -1 --format=%cd.%h --date=format:%Y-%m-%d_%H:%M:%S)" >> $GITHUB_ENV

11
.github/workflows/install_gui_dep.sh vendored Normal file
View File

@@ -0,0 +1,11 @@
sudo apt update
sudo apt install -qq libwebkit2gtk-4.1-dev \
build-essential \
curl \
wget \
file \
libgtk-3-dev \
librsvg2-dev \
libxdo-dev \
libssl-dev \
patchelf

View File

@@ -31,8 +31,8 @@ fi
# see https://github.com/rust-lang/rustup/issues/3709
rustup set auto-self-update disable
rustup install 1.87
rustup default 1.87
rustup install 1.89
rustup default 1.89
# mips/mipsel cannot add target from rustup, need compile by ourselves
if [[ $OS =~ ^ubuntu.*$ && $TARGET =~ ^mips.*$ ]]; then

View File

@@ -21,7 +21,7 @@ on:
version:
description: 'Version for this release'
type: string
default: 'v2.4.1'
default: 'v2.4.2'
required: true
make_latest:
description: 'Mark this release as latest'

View File

@@ -28,7 +28,7 @@ jobs:
# All of these options are optional, so you can remove them if you are happy with the defaults
concurrent_skipping: 'never'
skip_after_successful_duplicate: 'true'
paths: '["Cargo.toml", "Cargo.lock", "easytier/**", ".github/workflows/test.yml"]'
paths: '["Cargo.toml", "Cargo.lock", "easytier/**", ".github/workflows/test.yml", ".github/workflows/install_gui_dep.sh", ".github/workflows/install_rust.sh"]'
test:
runs-on: ubuntu-22.04
needs: pre_job
@@ -89,6 +89,24 @@ jobs:
./target
key: ${{ runner.os }}-cargo-test-${{ hashFiles('**/Cargo.lock') }}
- name: Install GUI dependencies (Used by clippy)
run: |
bash ./.github/workflows/install_gui_dep.sh
bash ./.github/workflows/install_rust.sh
rustup component add rustfmt
rustup component add clippy
- name: Check formatting
if: ${{ !cancelled() }}
run: cargo fmt --all -- --check
- name: Check Clippy
if: ${{ !cancelled() }}
# NOTE: tauri need `dist` dir in build.rs
run: |
mkdir -p easytier-gui/dist
cargo clippy --all-targets --all-features --all -- -D warnings
- name: Run tests
run: |
sudo prlimit --pid $$ --nofile=1048576:1048576

View File

@@ -26,7 +26,7 @@ Thank you for your interest in contributing to EasyTier! This document provides
#### Required Tools
- Node.js v21 or higher
- pnpm v9 or higher
- Rust toolchain (version 1.87)
- Rust toolchain (version 1.89)
- LLVM and Clang
- Protoc (Protocol Buffers compiler)
@@ -79,8 +79,8 @@ sudo apt install -y bridge-utils
2. Install dependencies:
```bash
# Install Rust toolchain
rustup install 1.87
rustup default 1.87
rustup install 1.89
rustup default 1.89
# Install project dependencies
pnpm -r install

View File

@@ -34,7 +34,7 @@
#### 必需工具
- Node.js v21 或更高版本
- pnpm v9 或更高版本
- Rust 工具链(版本 1.87
- Rust 工具链(版本 1.89
- LLVM 和 Clang
- ProtocProtocol Buffers 编译器)
@@ -87,8 +87,8 @@ sudo apt install -y bridge-utils
2. 安装依赖:
```bash
# 安装 Rust 工具链
rustup install 1.87
rustup default 1.87
rustup install 1.89
rustup default 1.89
# 安装项目依赖
pnpm -r install

17
Cargo.lock generated
View File

@@ -1979,7 +1979,7 @@ checksum = "0d6ef0072f8a535281e4876be788938b528e9a1d43900b82c2569af7da799125"
[[package]]
name = "easytier"
version = "2.4.0"
version = "2.4.2"
dependencies = [
"aes-gcm",
"anyhow",
@@ -2033,6 +2033,7 @@ dependencies = [
"network-interface",
"nix 0.29.0",
"once_cell",
"openssl",
"parking_lot",
"percent-encoding",
"petgraph 0.8.1",
@@ -2112,7 +2113,7 @@ dependencies = [
[[package]]
name = "easytier-gui"
version = "2.4.0"
version = "2.4.2"
dependencies = [
"anyhow",
"chrono",
@@ -2162,7 +2163,7 @@ dependencies = [
[[package]]
name = "easytier-web"
version = "2.4.0"
version = "2.4.2"
dependencies = [
"anyhow",
"async-trait",
@@ -5234,6 +5235,15 @@ version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf"
[[package]]
name = "openssl-src"
version = "300.5.2+3.5.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d270b79e2926f5150189d475bc7e9d2c69f9c4697b185fa917d5a32b792d21b4"
dependencies = [
"cc",
]
[[package]]
name = "openssl-sys"
version = "0.9.103"
@@ -5242,6 +5252,7 @@ checksum = "7f9e8deee91df40a943c71b917e5874b951d32a802526c85721ce3b776c929d6"
dependencies = [
"cc",
"libc",
"openssl-src",
"pkg-config",
"vcpkg",
]

View File

@@ -44,5 +44,7 @@
"prettier.enable": false,
"editor.formatOnSave": true,
"editor.formatOnSaveMode": "modifications",
"editor.formatOnPaste": false,
"editor.formatOnType": true,
}
}

View File

@@ -105,9 +105,9 @@ After successful execution, you can check the network status using `easytier-cli
```text
| ipv4 | hostname | cost | lat_ms | loss_rate | rx_bytes | tx_bytes | tunnel_proto | nat_type | id | version |
| ------------ | -------------- | ----- | ------ | --------- | -------- | -------- | ------------ | -------- | ---------- | --------------- |
| 10.126.126.1 | abc-1 | Local | * | * | * | * | udp | FullCone | 439804259 | 2.4.1-70e69a38~ |
| 10.126.126.2 | abc-2 | p2p | 3.452 | 0 | 17.33 kB | 20.42 kB | udp | FullCone | 390879727 | 2.4.1-70e69a38~ |
| | PublicServer_a | p2p | 27.796 | 0.000 | 50.01 kB | 67.46 kB | tcp | Unknown | 3771642457 | 2.4.1-70e69a38~ |
| 10.126.126.1 | abc-1 | Local | * | * | * | * | udp | FullCone | 439804259 | 2.4.2-70e69a38~ |
| 10.126.126.2 | abc-2 | p2p | 3.452 | 0 | 17.33 kB | 20.42 kB | udp | FullCone | 390879727 | 2.4.2-70e69a38~ |
| | PublicServer_a | p2p | 27.796 | 0.000 | 50.01 kB | 67.46 kB | tcp | Unknown | 3771642457 | 2.4.2-70e69a38~ |
```
You can test connectivity between nodes:

View File

@@ -106,9 +106,9 @@ sudo easytier-core -d --network-name abc --network-secret abc -p tcp://public.ea
```text
| ipv4 | hostname | cost | lat_ms | loss_rate | rx_bytes | tx_bytes | tunnel_proto | nat_type | id | version |
| ------------ | -------------- | ----- | ------ | --------- | -------- | -------- | ------------ | -------- | ---------- | --------------- |
| 10.126.126.1 | abc-1 | Local | * | * | * | * | udp | FullCone | 439804259 | 2.4.1-70e69a38~ |
| 10.126.126.2 | abc-2 | p2p | 3.452 | 0 | 17.33 kB | 20.42 kB | udp | FullCone | 390879727 | 2.4.1-70e69a38~ |
| | PublicServer_a | p2p | 27.796 | 0.000 | 50.01 kB | 67.46 kB | tcp | Unknown | 3771642457 | 2.4.1-70e69a38~ |
| 10.126.126.1 | abc-1 | Local | * | * | * | * | udp | FullCone | 439804259 | 2.4.2-70e69a38~ |
| 10.126.126.2 | abc-2 | p2p | 3.452 | 0 | 17.33 kB | 20.42 kB | udp | FullCone | 390879727 | 2.4.2-70e69a38~ |
| | PublicServer_a | p2p | 27.796 | 0.000 | 50.01 kB | 67.46 kB | tcp | Unknown | 3771642457 | 2.4.2-70e69a38~ |
```
您可以测试节点之间的连通性:

View File

@@ -29,8 +29,10 @@ fn set_error_msg(msg: &str) {
msg_buf[..len].copy_from_slice(bytes);
}
/// # Safety
/// Set the tun fd
#[no_mangle]
pub extern "C" fn set_tun_fd(
pub unsafe extern "C" fn set_tun_fd(
inst_name: *const std::ffi::c_char,
fd: std::ffi::c_int,
) -> std::ffi::c_int {
@@ -43,18 +45,23 @@ pub extern "C" fn set_tun_fd(
if !INSTANCE_NAME_ID_MAP.contains_key(&inst_name) {
return -1;
}
match INSTANCE_MANAGER.set_tun_fd(&INSTANCE_NAME_ID_MAP.get(&inst_name).unwrap().value(), fd) {
Ok(_) => {
0
}
Err(_) => {
-1
}
let inst_id = *INSTANCE_NAME_ID_MAP
.get(&inst_name)
.as_ref()
.unwrap()
.value();
match INSTANCE_MANAGER.set_tun_fd(&inst_id, fd) {
Ok(_) => 0,
Err(_) => -1,
}
}
/// # Safety
/// Get the last error message
#[no_mangle]
pub extern "C" fn get_error_msg(out: *mut *const std::ffi::c_char) {
pub unsafe extern "C" fn get_error_msg(out: *mut *const std::ffi::c_char) {
let msg_buf = ERROR_MSG.lock().unwrap();
if msg_buf.is_empty() {
unsafe {
@@ -78,8 +85,10 @@ pub extern "C" fn free_string(s: *const std::ffi::c_char) {
}
}
/// # Safety
/// Parse the config
#[no_mangle]
pub extern "C" fn parse_config(cfg_str: *const std::ffi::c_char) -> std::ffi::c_int {
pub unsafe extern "C" fn parse_config(cfg_str: *const std::ffi::c_char) -> std::ffi::c_int {
let cfg_str = unsafe {
assert!(!cfg_str.is_null());
std::ffi::CStr::from_ptr(cfg_str)
@@ -95,8 +104,10 @@ pub extern "C" fn parse_config(cfg_str: *const std::ffi::c_char) -> std::ffi::c_
0
}
/// # Safety
/// Run the network instance
#[no_mangle]
pub extern "C" fn run_network_instance(cfg_str: *const std::ffi::c_char) -> std::ffi::c_int {
pub unsafe extern "C" fn run_network_instance(cfg_str: *const std::ffi::c_char) -> std::ffi::c_int {
let cfg_str = unsafe {
assert!(!cfg_str.is_null());
std::ffi::CStr::from_ptr(cfg_str)
@@ -131,8 +142,10 @@ pub extern "C" fn run_network_instance(cfg_str: *const std::ffi::c_char) -> std:
0
}
/// # Safety
/// Retain the network instance
#[no_mangle]
pub extern "C" fn retain_network_instance(
pub unsafe extern "C" fn retain_network_instance(
inst_names: *const *const std::ffi::c_char,
length: usize,
) -> std::ffi::c_int {
@@ -168,13 +181,15 @@ pub extern "C" fn retain_network_instance(
return -1;
}
let _ = INSTANCE_NAME_ID_MAP.retain(|k, _| inst_names.contains(k));
INSTANCE_NAME_ID_MAP.retain(|k, _| inst_names.contains(k));
0
}
/// # Safety
/// Collect the network infos
#[no_mangle]
pub extern "C" fn collect_network_infos(
pub unsafe extern "C" fn collect_network_infos(
infos: *mut KeyValuePair,
max_length: usize,
) -> std::ffi::c_int {
@@ -233,7 +248,9 @@ mod tests {
network = "test_network"
"#;
let cstr = std::ffi::CString::new(cfg_str).unwrap();
assert_eq!(parse_config(cstr.as_ptr()), 0);
unsafe {
assert_eq!(parse_config(cstr.as_ptr()), 0);
}
}
#[test]
@@ -243,6 +260,8 @@ mod tests {
network = "test_network"
"#;
let cstr = std::ffi::CString::new(cfg_str).unwrap();
assert_eq!(run_network_instance(cstr.as_ptr()), 0);
unsafe {
assert_eq!(run_network_instance(cstr.as_ptr()), 0);
}
}
}

View File

@@ -1,6 +1,6 @@
id=easytier_magisk
name=EasyTier_Magisk
version=v2.4.1
version=v2.4.2
versionCode=1
author=EasyTier
description=easytier magisk module @EasyTier(https://github.com/EasyTier/EasyTier)

View File

@@ -1010,7 +1010,7 @@ dependencies = [
[[package]]
name = "easytier"
version = "2.4.1"
version = "2.4.2"
source = "git+https://github.com/EasyTier/EasyTier.git#a4bb555fac1046d0099c44676fa9d0d8cca55c99"
dependencies = [
"anyhow",

View File

@@ -1,7 +1,7 @@
{
"name": "easytier-gui",
"type": "module",
"version": "2.4.1",
"version": "2.4.2",
"private": true,
"packageManager": "pnpm@9.12.1+sha512.e5a7e52a4183a02d5931057f7a0dbff9d5e9ce3161e33fa68ae392125b79282a8a8a470a51dfc8a0ed86221442eb2fb57019b0990ed24fab519bf0e1bc5ccfc4",
"scripts": {

View File

@@ -1,6 +1,6 @@
[package]
name = "easytier-gui"
version = "2.4.1"
version = "2.4.2"
description = "EasyTier GUI"
authors = ["you"]
edition = "2021"

View File

@@ -1,6 +1,10 @@
<?xml version="1.0" encoding="utf-8"?>
<manifest xmlns:android="http://schemas.android.com/apk/res/android">
<uses-permission android:name="android.permission.INTERNET" />
<uses-permission android:name="android.permission.FOREGROUND_SERVICE" />
<uses-permission android:name="android.permission.POST_NOTIFICATIONS" />
<uses-permission android:name="android.permission.FOREGROUND_SERVICE_DATA_SYNC" />
<application
android:icon="@mipmap/ic_launcher"
android:label="@string/app_name"
@@ -18,6 +22,12 @@
<category android:name="android.intent.category.LAUNCHER" />
</intent-filter>
</activity>
<service
android:name=".MainForegroundService"
android:foregroundServiceType="dataSync"
android:enabled="true"
android:exported="false">
</service>
<provider
android:name="androidx.core.content.FileProvider"

View File

@@ -1,3 +1,20 @@
package com.kkrainbow.easytier
class MainActivity : TauriActivity()
import android.content.Intent
import android.os.Build
import android.os.Bundle
class MainActivity : TauriActivity() {
override fun onCreate(savedInstanceState: Bundle?) {
super.onCreate(savedInstanceState)
initService()
}
private fun initService() {
val serviceIntent = Intent(this, MainForegroundService::class.java)
if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.O) {
startForegroundService(serviceIntent)
} else {
startService(serviceIntent)
}
}
}

View File

@@ -0,0 +1,64 @@
package com.kkrainbow.easytier
import android.app.Notification
import android.app.NotificationChannel
import android.app.NotificationManager
import android.app.Service
import android.content.Intent
import android.content.pm.ServiceInfo
import android.os.Build
import android.os.IBinder
import androidx.core.app.NotificationCompat
import android.util.Log
class MainForegroundService : Service() {
companion object {
const val CHANNEL_ID = "easytier_channel"
const val NOTIFICATION_ID = 1355
// You can add more constants if needed
}
override fun onCreate() {
super.onCreate()
}
override fun onStartCommand(intent: Intent?, flags: Int, startId: Int): Int {
createNotificationChannel()
val notification = NotificationCompat.Builder(this, CHANNEL_ID)
.setContentTitle("easytier Running")
.setContentText("easytier is available on localhost")
.setSmallIcon(android.R.drawable.ic_menu_manage)
.build()
if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.Q) {
startForeground(
NOTIFICATION_ID,
notification,
ServiceInfo.FOREGROUND_SERVICE_TYPE_DATA_SYNC
)
} else {
startForeground(NOTIFICATION_ID, notification)
}
return START_STICKY
}
override fun onDestroy() {
super.onDestroy()
}
override fun onBind(intent: Intent?): IBinder? = null
private fun createNotificationChannel() {
if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.O) {
try {
val channel = NotificationChannel(
CHANNEL_ID,
"easytier notice",
NotificationManager.IMPORTANCE_DEFAULT
)
val manager = getSystemService(NotificationManager::class.java)
manager?.createNotificationChannel(channel)
} catch (e: Exception) {
Log.e("MainForegroundService", "Failed to create notification channel", e)
}
}
}
}

View File

@@ -16,41 +16,13 @@ impl Command {
/// Check the state the current program running
///
/// Return `true` if the program is running as root, otherwise false
///
/// # Examples
///
/// ```no_run
/// use elevated_command::Command;
///
/// fn main() {
/// let is_elevated = Command::is_elevated();
///
/// }
/// ```
pub fn is_elevated() -> bool {
let uid = unsafe { libc::getuid() };
if uid == 0 {
true
} else {
false
}
uid == 0
}
/// Prompting the user with a graphical OS dialog for the root password,
/// excuting the command with escalated privileges, and return the output
///
/// # Examples
///
/// ```no_run
/// use elevated_command::Command;
/// use std::process::Command as StdCommand;
///
/// fn main() {
/// let mut cmd = StdCommand::new("path to the application");
/// let elevated_cmd = Command::new(cmd);
/// let output = elevated_cmd.output().unwrap();
/// }
/// ```
pub fn output(&self) -> Result<Output> {
let pkexec = PathBuf::from_str("/bin/pkexec")?;
let mut command = StdCommand::new(pkexec);
@@ -70,10 +42,8 @@ impl Command {
if let Ok(home) = home {
command.arg(format!("HOME={}", home));
}
} else {
if self.cmd.get_envs().any(|(_, v)| v.is_some()) {
command.arg("env");
}
} else if self.cmd.get_envs().any(|(_, v)| v.is_some()) {
command.arg("env");
}
for (k, v) in self.cmd.get_envs() {
if let Some(value) = v {

View File

@@ -40,22 +40,6 @@ impl Command {
/// To pass environment variables on Windows,
/// to inherit environment variables from the parent process and
/// to change the working directory will be supported in later versions
///
/// # Examples
///
/// ```no_run
/// use elevated_command::Command;
/// use std::process::Command as StdCommand;
///
/// fn main() {
/// let mut cmd = StdCommand::new("path to the application");
///
/// cmd.arg("some arg");
/// cmd.env("some key", "some value");
///
/// let elevated_cmd = Command::new(cmd);
/// }
/// ```
pub fn new(cmd: StdCommand) -> Self {
Self {
cmd,
@@ -67,73 +51,21 @@ impl Command {
/// Consumes the `Take`, returning the wrapped std::process::Command
///
/// # Examples
///
/// ```no_run
/// use elevated_command::Command;
/// use std::process::Command as StdCommand;
///
/// fn main() {
/// let mut cmd = StdCommand::new("path to the application");
/// let elevated_cmd = Command::new(cmd);
/// let cmd = elevated_cmd.into_inner();
/// }
/// ```
pub fn into_inner(self) -> StdCommand {
self.cmd
}
/// Gets a mutable reference to the underlying std::process::Command
///
/// # Examples
///
/// ```no_run
/// use elevated_command::Command;
/// use std::process::Command as StdCommand;
///
/// fn main() {
/// let mut cmd = StdCommand::new("path to the application");
/// let elevated_cmd = Command::new(cmd);
/// let cmd = elevated_cmd.get_ref();
/// }
/// ```
pub fn get_ref(&self) -> &StdCommand {
&self.cmd
}
/// Gets a reference to the underlying std::process::Command
///
/// # Examples
///
/// ```no_run
/// use elevated_command::Command;
/// use std::process::Command as StdCommand;
///
/// fn main() {
/// let mut cmd = StdCommand::new("path to the application");
/// let elevated_cmd = Command::new(cmd);
/// let cmd = elevated_cmd.get_mut();
/// }
/// ```
pub fn get_mut(&mut self) -> &mut StdCommand {
&mut self.cmd
}
/// Set the `icon` for the pop-up graphical OS dialog
///
/// This method is only applicable on `MacOS`
///
/// # Examples
///
/// ```no_run
/// use elevated_command::Command;
/// use std::process::Command as StdCommand;
///
/// fn main() {
/// let mut cmd = StdCommand::new("path to the application");
/// let elevated_cmd = Command::new(cmd);
/// elevated_cmd.icon(include_bytes!("path to the icon").to_vec());
/// }
/// ```
pub fn icon(&mut self, icon: Vec<u8>) -> &mut Self {
self.icon = Some(icon);
self
@@ -142,19 +74,6 @@ impl Command {
/// Set the name for the pop-up graphical OS dialog
///
/// This method is only applicable on `MacOS`
///
/// # Examples
///
/// ```no_run
/// use elevated_command::Command;
/// use std::process::Command as StdCommand;
///
/// fn main() {
/// let mut cmd = StdCommand::new("path to the application");
/// let elevated_cmd = Command::new(cmd);
/// elevated_cmd.name("some name".to_string());
/// }
/// ```
pub fn name(&mut self, name: String) -> &mut Self {
self.name = Some(name);
self

View File

@@ -17,7 +17,7 @@
"createUpdaterArtifacts": false
},
"productName": "easytier-gui",
"version": "2.4.1",
"version": "2.4.2",
"identifier": "com.kkrainbow.easytier",
"plugins": {},
"app": {

View File

@@ -115,6 +115,11 @@ function getRoutesForVpn(routes: Route[]): string[] {
async function onNetworkInstanceChange() {
console.error('vpn service watch network instance change ids', JSON.stringify(networkStore.networkInstanceIds))
const insts = networkStore.networkInstanceIds
const no_tun = networkStore.isNoTunEnabled(insts[0])
if (no_tun) {
await doStopVpn()
return
}
if (!insts) {
await doStopVpn()
return
@@ -132,14 +137,6 @@ async function onNetworkInstanceChange() {
return
}
// if use no tun mode, stop the vpn service
const no_tun = networkStore.isNoTunEnabled(insts[0])
if (no_tun) {
console.error('no tun mode, stop vpn service')
await doStopVpn()
return
}
let network_length = curNetworkInfo?.my_node_info?.virtual_ipv4.network_length
if (!network_length) {
network_length = 24
@@ -187,12 +184,26 @@ async function watchNetworkInstance() {
console.error('vpn service watch network instance')
}
function isNoTunEnabled(instanceId: string | undefined) {
if (!instanceId) {
return false
}
const no_tun = networkStore.isNoTunEnabled(instanceId)
if (no_tun) {
return true
}
return false
}
export async function initMobileVpnService() {
await registerVpnServiceListener()
await watchNetworkInstance()
}
export async function prepareVpnService() {
export async function prepareVpnService(instanceId: string) {
if (isNoTunEnabled(instanceId)) {
return
}
console.log('prepare vpn')
const prepare_ret = await prepare_vpn()
console.log('prepare vpn', JSON.stringify((prepare_ret)))

View File

@@ -102,7 +102,7 @@ networkStore.$subscribe(async () => {
async function runNetworkCb(cfg: NetworkTypes.NetworkConfig, cb: () => void) {
if (type() === 'android') {
await prepareVpnService()
await prepareVpnService(cfg.instance_id)
networkStore.clearNetworkInstances()
}
else {

View File

@@ -8,7 +8,7 @@ repository = "https://github.com/EasyTier/EasyTier"
authors = ["kkrainbow"]
keywords = ["vpn", "p2p", "network", "easytier"]
categories = ["network-programming", "command-line-utilities"]
rust-version = "1.87.0"
rust-version = "1.89.0"
license-file = "LICENSE"
readme = "README.md"

View File

@@ -14,18 +14,11 @@ const NAMESPACE: &str = "easytier::proto::rpc_types";
///
/// See the crate-level documentation for more info.
#[allow(missing_copy_implementations)]
#[derive(Clone, Debug)]
#[derive(Clone, Debug, Default)]
pub struct ServiceGenerator {
_private: (),
}
impl ServiceGenerator {
/// Create a new `ServiceGenerator` instance with the default options set.
pub fn new() -> ServiceGenerator {
ServiceGenerator { _private: () }
}
}
impl prost_build::ServiceGenerator for ServiceGenerator {
fn generate(&mut self, service: prost_build::Service, mut buf: &mut String) {
use std::fmt::Write;
@@ -78,7 +71,7 @@ impl prost_build::ServiceGenerator for ServiceGenerator {
enum_methods,
" {name} = {index},",
name = method.proto_name,
index = format!("{}", idx + 1)
index = idx + 1
)
.unwrap();
@@ -87,7 +80,7 @@ impl prost_build::ServiceGenerator for ServiceGenerator {
" {index} => Ok({service_name}MethodDescriptor::{name}),",
service_name = service.name,
name = method.proto_name,
index = format!("{}", idx + 1),
index = idx + 1,
)
.unwrap();
@@ -102,12 +95,12 @@ impl prost_build::ServiceGenerator for ServiceGenerator {
writeln!(
client_methods,
r#" async fn {name}(&self, ctrl: H::Controller, input: {input_type}) -> {namespace}::error::Result<{output_type}> {{
{client_name}::{name}_inner(self.0.clone(), ctrl, input).await
{client_name}Client::{name}_inner(self.0.clone(), ctrl, input).await
}}"#,
name = method.name,
input_type = method.input_type,
output_type = method.output_type,
client_name = format!("{}Client", service.name),
client_name = service.name,
namespace = NAMESPACE,
)
.unwrap();

View File

@@ -1,6 +1,6 @@
[package]
name = "easytier-web"
version = "2.4.1"
version = "2.4.2"
edition = "2021"
description = "Config server for easytier. easytier-core gets config from this and web frontend use it as restful api server."

View File

@@ -1,7 +1,10 @@
fn main() {
// enable thunk-rs when target os is windows and arch is x86_64 or i686
#[cfg(target_os = "windows")]
if !std::env::var("TARGET").unwrap_or_default().contains("aarch64"){
thunk::thunk();
}
}
fn main() {
// enable thunk-rs when target os is windows and arch is x86_64 or i686
#[cfg(target_os = "windows")]
if !std::env::var("TARGET")
.unwrap_or_default()
.contains("aarch64")
{
thunk::thunk();
}
}

View File

@@ -2,7 +2,13 @@
import InputGroup from 'primevue/inputgroup'
import InputGroupAddon from 'primevue/inputgroupaddon'
import { SelectButton, Checkbox, InputText, InputNumber, AutoComplete, Panel, Divider, ToggleButton, Button, Password } from 'primevue'
import { DEFAULT_NETWORK_CONFIG, NetworkConfig, NetworkingMethod } from '../types/network'
import {
addRow,
DEFAULT_NETWORK_CONFIG,
NetworkConfig,
NetworkingMethod,
removeRow
} from '../types/network'
import { defineProps, defineEmits, ref, } from 'vue'
import { useI18n } from 'vue-i18n'
@@ -163,6 +169,8 @@ const bool_flags: BoolFlag[] = [
{ field: 'enable_private_mode', help: 'enable_private_mode_help' },
]
const portForwardProtocolOptions = ref(["tcp","udp"]);
</script>
<template>
@@ -416,6 +424,73 @@ const bool_flags: BoolFlag[] = [
</div>
</Panel>
<Divider />
<Panel :header="t('port_forwards')" toggleable collapsed>
<div class="flex flex-col gap-y-2">
<div class="flex flex-row gap-x-9 flex-wrap w-full">
<div class="flex flex-col gap-2 grow p-fluid">
<div class="flex">
<label for="port_forwards">{{ t('port_forwards_help') }}</label>
</div>
<div v-for="(row, index) in curNetwork.port_forwards" class="form-row">
<div style="display: flex; gap: 0.5rem; align-items: flex-end;">
<SelectButton v-model="row.proto" :options="portForwardProtocolOptions" :allow-empty="false"/>
<div style="flex-grow: 4;">
<InputGroup>
<InputText
v-model="row.bind_ip"
:placeholder="t('port_forwards_bind_addr')"
/>
<InputGroupAddon>
<span style="font-weight: bold">:</span>
</InputGroupAddon>
<InputNumber v-model="row.bind_port" :format="false"
inputId="horizontal-buttons" :step="1" mode="decimal" :min="1"
:max="65535" fluid
class="max-w-20"/>
</InputGroup>
</div>
<div style="flex-grow: 4;">
<InputGroup>
<InputText
v-model="row.dst_ip"
:placeholder="t('port_forwards_dst_addr')"
/>
<InputGroupAddon>
<span style="font-weight: bold">:</span>
</InputGroupAddon>
<InputNumber v-model="row.dst_port" :format="false"
inputId="horizontal-buttons" :step="1" mode="decimal" :min="1"
:max="65535" fluid
class="max-w-20"/>
</InputGroup>
</div>
<div style="flex-grow: 1;">
<Button
v-if="curNetwork.port_forwards.length > 0"
icon="pi pi-trash"
severity="danger"
text
rounded
@click="removeRow(index,curNetwork.port_forwards)"
/>
</div>
</div>
</div>
<div class="flex justify-content-end mt-4">
<Button
icon="pi pi-plus"
:label="t('port_forwards_add_btn')"
severity="success"
@click="addRow(curNetwork.port_forwards)"
/>
</div>
</div>
</div>
</div>
</Panel>
<div class="flex pt-6 justify-center">
<Button :label="t('run_network')" icon="pi pi-arrow-right" icon-pos="right" :disabled="configInvalid"
@click="$emit('runNetwork', curNetwork)" />

View File

@@ -150,6 +150,12 @@ socks5_help: |
exit_nodes: 出口节点列表
exit_nodes_help: 转发所有流量的出口节点虚拟IPv4地址优先级由列表顺序决定
port_forwards: 端口转发
port_forwards_help: "将本地端口转发到虚拟网络中的远程端口。例如udp://0.0.0.0:12345/10.126.126.1:23456表示将本地UDP端口12345转发到虚拟网络中的10.126.126.1:23456。可以指定多个。"
port_forwards_bind_addr: "绑定地址0.0.0.0"
port_forwards_dst_addr: "目标地址10.126.126.1"
port_forwards_add_btn: "添加"
mtu: MTU
mtu_help: |

View File

@@ -151,6 +151,12 @@ socks5_help: |
exit_nodes: Exit Nodes
exit_nodes_help: Exit nodes to forward all traffic to, a virtual ipv4 address, priority is determined by the order of the list
port_forwards: Port Forward
port_forwards_help: "forward local port to remote port in virtual network. e.g.: udp://0.0.0.0:12345/10.126.126.1:23456, means forward local udp port 12345 to 10.126.126.1:23456 in the virtual network. can specify multiple."
port_forwards_bind_addr: "Bind address, e.g.: 0.0.0.0"
port_forwards_dst_addr: "Destination address, e.g.: 10.126.126.1"
port_forwards_add_btn: "Add"
mtu: MTU
mtu_help: |
MTU of the TUN device, default is 1380 for non-encryption, 1360 for encryption. Range:400-1380

View File

@@ -70,6 +70,8 @@ export interface NetworkConfig {
enable_private_mode?: boolean
rpc_portal_whitelists: string[]
port_forwards: PortForwardConfig[]
}
export function DEFAULT_NETWORK_CONFIG(): NetworkConfig {
@@ -132,6 +134,7 @@ export function DEFAULT_NETWORK_CONFIG(): NetworkConfig {
enable_magic_dns: false,
enable_private_mode: false,
rpc_portal_whitelists: [],
port_forwards: [],
}
}
@@ -255,6 +258,30 @@ export interface PeerConnStats {
latency_us: number
}
export interface PortForwardConfig {
bind_ip: string,
bind_port: number,
dst_ip: string,
dst_port: number,
proto: string
}
// 添加新行
export const addRow = (rows: PortForwardConfig[]) => {
rows.push({
proto: 'tcp',
bind_ip: '',
bind_port: 65535,
dst_ip: '',
dst_port: 65535,
});
};
// 删除行
export const removeRow = (index: number, rows: PortForwardConfig[]) => {
rows.splice(index, 1);
};
export enum EventType {
TunDeviceReady = 'TunDeviceReady', // string
TunDeviceError = 'TunDeviceError', // string

View File

@@ -25,7 +25,7 @@ fn load_geoip_db(geoip_db: Option<String>) -> Option<maxminddb::Reader<Vec<u8>>>
match maxminddb::Reader::open_readfile(&path) {
Ok(reader) => {
tracing::info!("Successfully loaded GeoIP2 database from {}", path);
return Some(reader);
Some(reader)
}
Err(err) => {
tracing::debug!("Failed to load GeoIP2 database from {}: {}", path, err);
@@ -207,10 +207,8 @@ impl ClientManager {
let region = city.subdivisions.map(|r| {
r.iter()
.map(|x| x.names.as_ref())
.flatten()
.map(|x| x.get("zh-CN").or_else(|| x.get("en")))
.flatten()
.filter_map(|x| x.names.as_ref())
.filter_map(|x| x.get("zh-CN").or_else(|| x.get("en")))
.map(|x| x.to_string())
.collect::<Vec<_>>()
.join(",")

View File

@@ -94,14 +94,10 @@ impl SessionRpcService {
return Ok(HeartbeatResponse {});
};
let machine_id: uuid::Uuid =
let machine_id: uuid::Uuid = req.machine_id.map(Into::into).ok_or(anyhow::anyhow!(
"Machine id is not set correctly, expect uuid but got: {:?}",
req.machine_id
.clone()
.map(Into::into)
.ok_or(anyhow::anyhow!(
"Machine id is not set correctly, expect uuid but got: {:?}",
req.machine_id
))?;
))?;
let user_id = storage
.db()
@@ -121,7 +117,7 @@ impl SessionRpcService {
if data.req.replace(req.clone()).is_none() {
assert!(data.storage_token.is_none());
data.storage_token = Some(StorageToken {
token: req.user_token.clone().into(),
token: req.user_token.clone(),
client_url: data.client_url.clone(),
machine_id,
user_id,

View File

@@ -34,7 +34,7 @@ impl TryFrom<WeakRefStorage> for Storage {
type Error = ();
fn try_from(weak: Weak<StorageInner>) -> Result<Self, Self::Error> {
weak.upgrade().map(|inner| Storage(inner)).ok_or(())
weak.upgrade().map(Storage).ok_or(())
}
}
@@ -51,9 +51,7 @@ impl Storage {
machine_id: &uuid::Uuid,
client_url: &url::Url,
) {
map.remove_if(&machine_id, |_, v| {
v.storage_token.client_url == *client_url
});
map.remove_if(machine_id, |_, v| v.storage_token.client_url == *client_url);
}
fn update_mid_to_client_info_map(
@@ -74,11 +72,7 @@ impl Storage {
}
pub fn update_client(&self, stoken: StorageToken, report_time: i64) {
let inner = self
.0
.user_clients_map
.entry(stoken.user_id)
.or_insert_with(DashMap::new);
let inner = self.0.user_clients_map.entry(stoken.user_id).or_default();
let client_info = ClientInfo {
storage_token: stoken.clone(),

View File

@@ -151,7 +151,7 @@ async fn get_dual_stack_listener(
} else {
None
};
let v4_listener = if let Ok(_) = local_ipv4().await {
let v4_listener = if local_ipv4().await.is_ok() {
get_listener_by_url(&format!("{}://0.0.0.0:{}", protocol, port).parse().unwrap()).ok()
} else {
None

View File

@@ -137,7 +137,7 @@ mod post {
mod get {
use crate::restful::{
captcha::{
captcha::spec::SpecCaptcha,
builder::spec::SpecCaptcha,
extension::{axum_tower_sessions::CaptchaAxumTowerSessionExt as _, CaptchaUtil},
NewCaptcha as _,
},

View File

@@ -46,22 +46,22 @@ pub(crate) struct Captcha {
/// 验证码文本类型 The character type of the captcha
pub enum CaptchaType {
/// 字母数字混合
TypeDefault = 1,
Default = 1,
/// 纯数字
TypeOnlyNumber,
OnlyNumber,
/// 纯字母
TypeOnlyChar,
OnlyChar,
/// 纯大写字母
TypeOnlyUpper,
OnlyUpper,
/// 纯小写字母
TypeOnlyLower,
OnlyLower,
/// 数字大写字母
TypeNumAndUpper,
NumAndUpper,
}
/// 内置字体 Fonts shipped with the library
@@ -92,29 +92,29 @@ impl Captcha {
/// 生成随机验证码
pub fn alphas(&mut self) -> Vec<char> {
let mut cs = vec!['\0'; self.len];
for i in 0..self.len {
for cs_i in cs.iter_mut() {
match self.char_type {
CaptchaType::TypeDefault => cs[i] = self.randoms.alpha(),
CaptchaType::TypeOnlyNumber => {
cs[i] = self.randoms.alpha_under(self.randoms.num_max_index)
CaptchaType::Default => *cs_i = self.randoms.alpha(),
CaptchaType::OnlyNumber => {
*cs_i = self.randoms.alpha_under(self.randoms.num_max_index)
}
CaptchaType::TypeOnlyChar => {
cs[i] = self
CaptchaType::OnlyChar => {
*cs_i = self
.randoms
.alpha_between(self.randoms.char_min_index, self.randoms.char_max_index)
}
CaptchaType::TypeOnlyUpper => {
cs[i] = self
CaptchaType::OnlyUpper => {
*cs_i = self
.randoms
.alpha_between(self.randoms.upper_min_index, self.randoms.upper_max_index)
}
CaptchaType::TypeOnlyLower => {
cs[i] = self
CaptchaType::OnlyLower => {
*cs_i = self
.randoms
.alpha_between(self.randoms.lower_min_index, self.randoms.lower_max_index)
}
CaptchaType::TypeNumAndUpper => {
cs[i] = self.randoms.alpha_under(self.randoms.upper_max_index)
CaptchaType::NumAndUpper => {
*cs_i = self.randoms.alpha_under(self.randoms.upper_max_index)
}
}
}
@@ -142,7 +142,7 @@ impl Captcha {
}
}
pub fn get_font(&mut self) -> Arc<Font> {
pub fn get_font(&'_ mut self) -> Arc<Font<'_>> {
if let Some(font) = font::get_font(&self.font_name) {
font
} else {
@@ -185,6 +185,7 @@ where
/// 特别地/In particular:
///
/// - 对算术验证码[ArithmeticCaptcha](crate::captcha::arithmetic::ArithmeticCaptcha)而言,这里的`len`是验证码中数字的数量。
///
/// For [ArithmeticCaptcha](crate::captcha::arithmetic::ArithmeticCaptcha), the `len` presents the count of the digits
/// in the Captcha.
fn with_size_and_len(width: i32, height: i32, len: usize) -> Self;
@@ -226,7 +227,7 @@ impl NewCaptcha for Captcha {
let len = 5;
let width = 130;
let height = 48;
let char_type = CaptchaType::TypeDefault;
let char_type = CaptchaType::Default;
let chars = None;
Self {

View File

@@ -1,6 +1,4 @@
use rand::{random};
use rand::random;
/// 随机数工具类
pub(crate) struct Randoms {

View File

@@ -10,7 +10,7 @@ use axum::response::Response;
use std::fmt::Debug;
use tower_sessions::Session;
const CAPTCHA_KEY: &'static str = "ez-captcha";
const CAPTCHA_KEY: &str = "ez-captcha";
/// Axum & Tower_Sessions
#[async_trait]
@@ -32,7 +32,7 @@ pub trait CaptchaAxumTowerSessionStaticExt {
/// Verify the Captcha code, and return whether user's code is correct.
async fn ver(code: &str, session: &Session) -> bool {
match session.get::<String>(CAPTCHA_KEY).await {
Ok(Some(ans)) => ans.to_ascii_lowercase() == code.to_ascii_lowercase(),
Ok(Some(ans)) => ans.eq_ignore_ascii_case(code),
_ => false,
}
}

View File

@@ -1,7 +1,7 @@
pub mod axum_tower_sessions;
use super::base::captcha::AbstractCaptcha;
use super::captcha::spec::SpecCaptcha;
use super::builder::spec::SpecCaptcha;
use super::{CaptchaFont, NewCaptcha};
/// 验证码工具类 - Captcha Utils

View File

@@ -117,7 +117,7 @@
#![allow(dead_code)]
pub(crate) mod base;
pub mod captcha;
pub mod builder;
pub mod extension;
mod utils;

View File

@@ -32,21 +32,24 @@ impl From<(u8, u8, u8)> for Color {
}
}
impl Into<(u8, u8, u8, u8)> for Color {
fn into(self) -> (u8, u8, u8, u8) {
impl From<Color> for (u8, u8, u8, u8) {
fn from(val: Color) -> Self {
(
(self.0 * 255.0) as u8,
(self.1 * 255.0) as u8,
(self.2 * 255.0) as u8,
(self.3 * 255.0) as u8,
(val.0 * 255.0) as u8,
(val.1 * 255.0) as u8,
(val.2 * 255.0) as u8,
(val.3 * 255.0) as u8,
)
}
}
impl Into<u32> for Color {
fn into(self) -> u32 {
let color: (u8, u8, u8, u8) = self.into();
(color.0 as u32) << 24 + (color.1 as u32) << 16 + (color.2 as u32) << 8 + (color.3 as u32)
impl From<Color> for u32 {
fn from(val: Color) -> Self {
let color: (u8, u8, u8, u8) = val.into();
(color.0 as u32)
<< (24 + (color.1 as u32))
<< (16 + (color.2 as u32))
<< (8 + (color.3 as u32))
}
}

View File

@@ -11,7 +11,7 @@ struct FontAssets;
// pub(crate) static ref FONTS: RwLock<HashMap<String, Arc<Font>>> = Default::default();
// }
pub fn get_font(font_name: &str) -> Option<Arc<Font>> {
pub fn get_font(font_name: &'_ str) -> Option<Arc<Font<'_>>> {
// let fonts_cell = FONTS.get_or_init(|| Default::default());
// let guard = fonts_cell.read();
//
@@ -31,7 +31,7 @@ pub fn get_font(font_name: &str) -> Option<Arc<Font>> {
// }
}
pub fn load_font(font_name: &str) -> Result<Option<Font>, Box<dyn Error>> {
pub fn load_font(font_name: &'_ str) -> Result<Option<Font<'_>>, Box<dyn Error>> {
match FontAssets::get(font_name) {
Some(assets) => {
let font = Font::try_from_vec(Vec::from(assets.data)).unwrap();

View File

@@ -143,7 +143,7 @@ impl RestfulServer {
return Err((StatusCode::UNAUTHORIZED, other_error("No such user").into()));
};
let machines = client_mgr.list_machine_by_user_id(user.id().clone()).await;
let machines = client_mgr.list_machine_by_user_id(user.id()).await;
Ok(GetSummaryJsonResp {
device_count: machines.len() as u32,

View File

@@ -8,7 +8,7 @@ use axum_login::AuthUser;
use easytier::launcher::NetworkConfig;
use easytier::proto::common::Void;
use easytier::proto::rpc_types::controller::BaseController;
use easytier::proto::web::*;
use easytier::proto::{self, web::*};
use crate::client_manager::session::{Location, Session};
use crate::client_manager::ClientManager;
@@ -85,7 +85,7 @@ impl NetworkApi {
let Some(user_id) = auth_session.user.as_ref().map(|x| x.id()) else {
return Err((
StatusCode::UNAUTHORIZED,
other_error(format!("No user id found")).into(),
other_error("No user id found".to_string()).into(),
));
};
Ok(user_id)
@@ -108,7 +108,7 @@ impl NetworkApi {
let Some(token) = result.get_token().await else {
return Err((
StatusCode::UNAUTHORIZED,
other_error(format!("No token reported")).into(),
other_error("No token reported".to_string()).into(),
));
};
@@ -120,7 +120,7 @@ impl NetworkApi {
{
return Err((
StatusCode::FORBIDDEN,
other_error(format!("Token mismatch")).into(),
other_error("Token mismatch".to_string()).into(),
));
}
@@ -177,7 +177,7 @@ impl NetworkApi {
.insert_or_update_user_network_config(
auth_session.user.as_ref().unwrap().id(),
machine_id,
resp.inst_id.clone().unwrap_or_default().into(),
resp.inst_id.unwrap_or_default().into(),
serde_json::to_string(&config).unwrap(),
)
.await
@@ -248,7 +248,7 @@ impl NetworkApi {
.await
.map_err(convert_rpc_error)?;
let running_inst_ids = ret.inst_ids.clone().into_iter().map(Into::into).collect();
let running_inst_ids = ret.inst_ids.clone().into_iter().collect();
// collect networks that are disabled
let disabled_inst_ids = client_mgr
@@ -261,7 +261,7 @@ impl NetworkApi {
.await
.map_err(convert_db_error)?
.iter()
.filter_map(|x| x.network_instance_id.clone().try_into().ok())
.map(|x| Into::<proto::common::Uuid>::into(x.network_instance_id.clone()))
.collect::<Vec<_>>();
Ok(ListNetworkInstanceIdsJsonResp {
@@ -330,9 +330,8 @@ impl NetworkApi {
// not implement disable all
return Err((
StatusCode::NOT_IMPLEMENTED,
other_error(format!("Not implemented")).into(),
))
.into();
other_error("Not implemented".to_string()).into(),
));
};
let sess = Self::get_session_by_machine_id(&auth_session, &client_mgr, &machine_id).await?;

View File

@@ -76,32 +76,32 @@ impl Backend {
pub async fn register_new_user(&self, new_user: &RegisterNewUser) -> anyhow::Result<()> {
let hashed_password = password_auth::generate_hash(new_user.credentials.password.as_str());
let mut txn = self.db.orm_db().begin().await?;
let txn = self.db.orm_db().begin().await?;
entity::users::ActiveModel {
username: Set(new_user.credentials.username.clone()),
password: Set(hashed_password.clone()),
..Default::default()
}
.save(&mut txn)
.save(&txn)
.await?;
entity::users_groups::ActiveModel {
user_id: Set(entity::users::Entity::find()
.filter(entity::users::Column::Username.eq(new_user.credentials.username.as_str()))
.one(&mut txn)
.one(&txn)
.await?
.unwrap()
.id),
group_id: Set(entity::groups::Entity::find()
.filter(entity::groups::Column::Name.eq("users"))
.one(&mut txn)
.one(&txn)
.await?
.unwrap()
.id),
..Default::default()
}
.save(&mut txn)
.save(&txn)
.await?;
txn.commit().await?;

View File

@@ -52,9 +52,7 @@ pub fn build_router(api_host: Option<url::Url>) -> Router {
router
};
let router = router.fallback_service(service);
router
router.fallback_service(service)
}
pub struct WebServer {

View File

@@ -3,12 +3,12 @@ name = "easytier"
description = "A full meshed p2p VPN, connecting all your devices in one network with one command."
homepage = "https://github.com/EasyTier/EasyTier"
repository = "https://github.com/EasyTier/EasyTier"
version = "2.4.1"
version = "2.4.2"
edition = "2021"
authors = ["kkrainbow"]
keywords = ["vpn", "p2p", "network", "easytier"]
categories = ["network-programming", "command-line-utilities"]
rust-version = "1.87.0"
rust-version = "1.89.0"
license-file = "LICENSE"
readme = "README.md"
@@ -150,6 +150,7 @@ boringtun = { package = "boringtun-easytier", version = "0.6.1", optional = true
ring = { version = "0.17", optional = true }
bitflags = "2.5"
aes-gcm = { version = "0.10.3", optional = true }
openssl = { version = "0.10", optional = true, features = ["vendored"] }
# for cli
tabled = "0.16"
@@ -249,7 +250,9 @@ windows-sys = { version = "0.52", features = [
winapi = { version = "0.3.9", features = ["impl-default"] }
[target.'cfg(not(windows))'.dependencies]
jemallocator = { package = "tikv-jemallocator", version = "0.6.0", optional = true }
jemallocator = { package = "tikv-jemallocator", version = "0.6.0", optional = true, features = [
"unprefixed_malloc_on_supported_platforms"
] }
jemalloc-ctl = { package = "tikv-jemalloc-ctl", version = "0.6.0", optional = true, features = [
] }
jemalloc-sys = { package = "tikv-jemalloc-sys", version = "0.6.0", features = [
@@ -296,6 +299,7 @@ full = [
"websocket",
"wireguard",
"aes-gcm",
"openssl-crypto", # need openssl-dev libs
"smoltcp",
"tun",
"socks5",
@@ -304,6 +308,7 @@ wireguard = ["dep:boringtun", "dep:ring"]
quic = ["dep:quinn", "dep:rustls", "dep:rcgen"]
mimalloc = ["dep:mimalloc"]
aes-gcm = ["dep:aes-gcm"]
openssl-crypto = ["dep:openssl"]
tun = ["dep:tun"]
websocket = [
"dep:tokio-websockets",

View File

@@ -116,7 +116,7 @@ fn check_locale() {
if let Ok(globs) = globwalk::glob(locale_path) {
for entry in globs {
if let Err(e) = entry {
println!("cargo:i18n-error={}", e);
println!("cargo:i18n-error={e}");
continue;
}
@@ -151,7 +151,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
];
for proto_file in proto_files.iter().chain(proto_files_reflect.iter()) {
println!("cargo:rerun-if-changed={}", proto_file);
println!("cargo:rerun-if-changed={proto_file}");
}
let mut config = prost_build::Config::new();
@@ -173,7 +173,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
.field_attribute(".web.NetworkConfig", "#[serde(default)]")
.service_generator(Box::new(rpc_build::ServiceGenerator::new()))
.btree_map(["."])
.skip_debug(&[".common.Ipv4Addr", ".common.Ipv6Addr", ".common.UUID"]);
.skip_debug([".common.Ipv4Addr", ".common.Ipv6Addr", ".common.UUID"]);
config.compile_protos(&proto_files, &["src/proto/"])?;

View File

@@ -95,6 +95,9 @@ core_clap:
disable_encryption:
en: "disable encryption for peers communication, default is false, must be same with peers"
zh-CN: "禁用对等节点通信的加密默认为false必须与对等节点相同"
encryption_algorithm:
en: "encryption algorithm to use, supported: '', 'xor', 'chacha20', 'aes-gcm', 'aes-gcm-256', 'openssl-aes128-gcm', 'openssl-aes256-gcm', 'openssl-chacha20'. Empty string means default (aes-gcm)"
zh-CN: "要使用的加密算法,支持:''默认aes-gcm、'xor'、'chacha20'、'aes-gcm'、'aes-gcm-256'、'openssl-aes128-gcm'、'openssl-aes256-gcm'、'openssl-chacha20'"
multi_thread:
en: "use multi-thread runtime, default is single-thread"
zh-CN: "使用多线程运行时,默认为单线程"
@@ -190,6 +193,18 @@ core_clap:
foreign_relay_bps_limit:
en: "the maximum bps limit for foreign network relay, default is no limit. unit: BPS (bytes per second)"
zh-CN: "作为共享节点时,限制非本地网络的流量转发速率,默认无限制,单位 BPS (字节每秒)"
tcp_whitelist:
en: "tcp port whitelist. Supports single ports (80) and ranges (8000-9000)"
zh-CN: "TCP 端口白名单。支持单个端口80和范围8000-9000"
udp_whitelist:
en: "udp port whitelist. Supports single ports (53) and ranges (5000-6000)"
zh-CN: "UDP 端口白名单。支持单个端口53和范围5000-6000"
disable_relay_kcp:
en: "if true, disable relay kcp packets. avoid consuming too many bandwidth. default is false"
zh-CN: "如果为true则禁止节点转发 KCP 数据包防止过度消耗流量。默认值为false"
enable_relay_foreign_network_kcp:
en: "if true, allow relay kcp packets from foreign network. default is false (not forward foreign network kcp packets)"
zh-CN: "如果为true则作为共享节点时也可以转发其他网络的 KCP 数据包。默认值为false不转发"
core_app:
panic_backtrace_save:

View File

@@ -178,6 +178,12 @@ impl AclLogContext {
}
}
pub type SharedState = (
Arc<DashMap<String, ConnTrackEntry>>,
Arc<DashMap<RateLimitKey, Arc<TokenBucket>>>,
Arc<DashMap<AclStatKey, u64>>,
);
// High-performance ACL processor - No more internal locks!
pub struct AclProcessor {
// Immutable rule vectors - no locks needed since they're never modified after creation
@@ -321,7 +327,7 @@ impl AclProcessor {
.rules
.iter()
.filter(|rule| rule.enabled)
.map(|rule| Self::convert_to_fast_lookup_rule(rule))
.map(Self::convert_to_fast_lookup_rule)
.collect::<Vec<_>>();
// Sort by priority (higher priority first)
@@ -422,7 +428,7 @@ impl AclProcessor {
self.inc_cache_entry_stats(cache_entry, packet_info);
return cache_entry.acl_result.clone().unwrap();
cache_entry.acl_result.clone().unwrap()
}
fn inc_cache_entry_stats(&self, cache_entry: &AclCacheEntry, packet_info: &PacketInfo) {
@@ -539,7 +545,7 @@ impl AclProcessor {
cache_entry.rule_stats_vec.push(rule.rule_stats.clone());
cache_entry.matched_rule = RuleId::Priority(rule.priority);
cache_entry.acl_result = Some(AclResult {
action: rule.action.clone(),
action: rule.action,
matched_rule: Some(RuleId::Priority(rule.priority)),
should_log: false,
log_context: Some(AclLogContext::RuleMatch {
@@ -595,13 +601,7 @@ impl AclProcessor {
}
/// Get shared state for preserving across hot reloads
pub fn get_shared_state(
&self,
) -> (
Arc<DashMap<String, ConnTrackEntry>>,
Arc<DashMap<RateLimitKey, Arc<TokenBucket>>>,
Arc<DashMap<AclStatKey, u64>>,
) {
pub fn get_shared_state(&self) -> SharedState {
(
self.conn_track.clone(),
self.rate_limiters.clone(),
@@ -698,9 +698,9 @@ impl AclProcessor {
}
/// Check connection state for stateful rules
fn check_connection_state(&self, conn_track_key: &String, packet_info: &PacketInfo) {
fn check_connection_state(&self, conn_track_key: &str, packet_info: &PacketInfo) {
self.conn_track
.entry(conn_track_key.clone())
.entry(conn_track_key.to_string())
.and_modify(|x| {
x.last_seen = SystemTime::now()
.duration_since(UNIX_EPOCH)
@@ -764,13 +764,13 @@ impl AclProcessor {
let src_ip_ranges = rule
.source_ips
.iter()
.filter_map(|ip_inet| Self::convert_ip_inet_to_cidr(ip_inet))
.filter_map(|x| Self::convert_ip_inet_to_cidr(x.as_str()))
.collect();
let dst_ip_ranges = rule
.destination_ips
.iter()
.filter_map(|ip_inet| Self::convert_ip_inet_to_cidr(ip_inet))
.filter_map(|x| Self::convert_ip_inet_to_cidr(x.as_str()))
.collect();
let src_port_ranges = rule
@@ -820,8 +820,8 @@ impl AclProcessor {
}
/// Convert IpInet to CIDR for fast lookup
fn convert_ip_inet_to_cidr(input: &String) -> Option<cidr::IpCidr> {
cidr::IpCidr::from_str(input.as_str()).ok()
fn convert_ip_inet_to_cidr(input: &str) -> Option<cidr::IpCidr> {
cidr::IpCidr::from_str(input).ok()
}
/// Increment statistics counter
@@ -898,17 +898,13 @@ impl AclProcessor {
}
// 新增辅助函数
fn parse_port_start(
port_strs: &::prost::alloc::vec::Vec<::prost::alloc::string::String>,
) -> Option<u16> {
fn parse_port_start(port_strs: &[String]) -> Option<u16> {
port_strs
.iter()
.filter_map(|s| parse_port_range(s).map(|(start, _)| start))
.min()
}
fn parse_port_end(
port_strs: &::prost::alloc::vec::Vec<::prost::alloc::string::String>,
) -> Option<u16> {
fn parse_port_end(port_strs: &[String]) -> Option<u16> {
port_strs
.iter()
.filter_map(|s| parse_port_range(s).map(|(_, end)| end))
@@ -1154,18 +1150,22 @@ mod tests {
let mut acl_v1 = AclV1::default();
// Create inbound chain
let mut chain = Chain::default();
chain.name = "test_inbound".to_string();
chain.chain_type = ChainType::Inbound as i32;
chain.enabled = true;
let mut chain = Chain {
name: "test_inbound".to_string(),
chain_type: ChainType::Inbound as i32,
enabled: true,
..Default::default()
};
// Allow all rule
let mut rule = Rule::default();
rule.name = "allow_all".to_string();
rule.priority = 100;
rule.enabled = true;
rule.action = Action::Allow as i32;
rule.protocol = Protocol::Any as i32;
let rule = Rule {
name: "allow_all".to_string(),
priority: 100,
enabled: true,
action: Action::Allow as i32,
protocol: Protocol::Any as i32,
..Default::default()
};
chain.rules.push(rule);
acl_v1.chains.push(chain);
@@ -1278,12 +1278,14 @@ mod tests {
// 创建新配置(模拟热加载)
let mut new_config = create_test_acl_config();
if let Some(ref mut acl_v1) = new_config.acl_v1 {
let mut drop_rule = Rule::default();
drop_rule.name = "drop_all".to_string();
drop_rule.priority = 200;
drop_rule.enabled = true;
drop_rule.action = Action::Drop as i32;
drop_rule.protocol = Protocol::Any as i32;
let drop_rule = Rule {
name: "drop_all".to_string(),
priority: 200,
enabled: true,
action: Action::Drop as i32,
protocol: Protocol::Any as i32,
..Default::default()
};
acl_v1.chains[0].rules.push(drop_rule);
}
@@ -1321,40 +1323,48 @@ mod tests {
let mut acl_config = Acl::default();
let mut acl_v1 = AclV1::default();
let mut chain = Chain::default();
chain.name = "performance_test".to_string();
chain.chain_type = ChainType::Inbound as i32;
chain.enabled = true;
let mut chain = Chain {
name: "performance_test".to_string(),
chain_type: ChainType::Inbound as i32,
enabled: true,
..Default::default()
};
// 1. High-priority simple rule for UDP (can be cached efficiently)
let mut simple_rule = Rule::default();
simple_rule.name = "simple_udp".to_string();
simple_rule.priority = 300;
simple_rule.enabled = true;
simple_rule.action = Action::Allow as i32;
simple_rule.protocol = Protocol::Udp as i32;
let simple_rule = Rule {
name: "simple_udp".to_string(),
priority: 300,
enabled: true,
action: Action::Allow as i32,
protocol: Protocol::Udp as i32,
..Default::default()
};
// No stateful or rate limit - can benefit from full cache optimization
chain.rules.push(simple_rule);
// 2. Medium-priority stateful + rate-limited rule for TCP (security critical)
let mut security_rule = Rule::default();
security_rule.name = "security_tcp".to_string();
security_rule.priority = 200;
security_rule.enabled = true;
security_rule.action = Action::Allow as i32;
security_rule.protocol = Protocol::Tcp as i32;
security_rule.stateful = true;
security_rule.rate_limit = 100; // 100 packets/sec
security_rule.burst_limit = 200;
let security_rule = Rule {
name: "security_tcp".to_string(),
priority: 200,
enabled: true,
action: Action::Allow as i32,
protocol: Protocol::Tcp as i32,
stateful: true,
rate_limit: 100,
burst_limit: 200,
..Default::default()
};
chain.rules.push(security_rule);
// 3. Low-priority default allow rule for Any
let mut default_rule = Rule::default();
default_rule.name = "default_allow".to_string();
default_rule.priority = 100;
default_rule.enabled = true;
default_rule.action = Action::Allow as i32;
default_rule.protocol = Protocol::Any as i32;
let default_rule = Rule {
name: "default_allow".to_string(),
priority: 100,
enabled: true,
action: Action::Allow as i32,
protocol: Protocol::Any as i32,
..Default::default()
};
chain.rules.push(default_rule);
acl_v1.chains.push(chain);
@@ -1441,15 +1451,16 @@ mod tests {
// Create a very restrictive rate-limited rule
if let Some(ref mut acl_v1) = acl_config.acl_v1 {
let mut rule = Rule::default();
rule.name = "strict_rate_limit".to_string();
rule.priority = 200;
rule.enabled = true;
rule.action = Action::Allow as i32;
rule.protocol = Protocol::Any as i32;
rule.rate_limit = 1; // Allow only 1 packet per second
rule.burst_limit = 1; // Burst of 1 packet
let rule = Rule {
name: "strict_rate_limit".to_string(),
priority: 200,
enabled: true,
action: Action::Allow as i32,
protocol: Protocol::Any as i32,
rate_limit: 1, // Allow only 1 packet per second
burst_limit: 1, // Burst of 1 packet
..Default::default()
};
acl_v1.chains[0].rules.push(rule);
}

View File

@@ -21,6 +21,12 @@ pub trait Compressor {
pub struct DefaultCompressor {}
impl Default for DefaultCompressor {
fn default() -> Self {
Self::new()
}
}
impl DefaultCompressor {
pub fn new() -> Self {
DefaultCompressor {}
@@ -195,11 +201,11 @@ pub mod tests {
packet,
packet.payload_len()
);
assert_eq!(packet.peer_manager_header().unwrap().is_compressed(), true);
assert!(packet.peer_manager_header().unwrap().is_compressed());
compressor.decompress(&mut packet).await.unwrap();
assert_eq!(packet.payload(), text);
assert_eq!(packet.peer_manager_header().unwrap().is_compressed(), false);
assert!(!packet.peer_manager_header().unwrap().is_compressed());
}
#[tokio::test]
@@ -215,10 +221,10 @@ pub mod tests {
.compress(&mut packet, CompressorAlgo::ZstdDefault)
.await
.unwrap();
assert_eq!(packet.peer_manager_header().unwrap().is_compressed(), false);
assert!(!packet.peer_manager_header().unwrap().is_compressed());
compressor.decompress(&mut packet).await.unwrap();
assert_eq!(packet.payload(), text);
assert_eq!(packet.peer_manager_header().unwrap().is_compressed(), false);
assert!(!packet.peer_manager_header().unwrap().is_compressed());
}
}

View File

@@ -1,8 +1,8 @@
use std::{
net::{Ipv4Addr, SocketAddr},
hash::Hasher,
net::{IpAddr, SocketAddr},
path::PathBuf,
sync::{Arc, Mutex},
u64,
};
use anyhow::Context;
@@ -40,16 +40,87 @@ pub fn gen_default_flags() -> Flags {
bind_device: true,
enable_kcp_proxy: false,
disable_kcp_input: false,
disable_relay_kcp: true,
disable_relay_kcp: false,
enable_relay_foreign_network_kcp: false,
accept_dns: false,
private_mode: false,
enable_quic_proxy: false,
disable_quic_input: false,
foreign_relay_bps_limit: u64::MAX,
multi_thread_count: 2,
encryption_algorithm: "aes-gcm".to_string(),
}
}
pub enum EncryptionAlgorithm {
AesGcm,
Aes256Gcm,
Xor,
#[cfg(feature = "wireguard")]
ChaCha20,
#[cfg(feature = "openssl-crypto")]
OpensslAesGcm,
#[cfg(feature = "openssl-crypto")]
OpensslChacha20,
#[cfg(feature = "openssl-crypto")]
OpensslAes256Gcm,
}
impl std::fmt::Display for EncryptionAlgorithm {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::AesGcm => write!(f, "aes-gcm"),
Self::Aes256Gcm => write!(f, "aes-256-gcm"),
Self::Xor => write!(f, "xor"),
#[cfg(feature = "wireguard")]
Self::ChaCha20 => write!(f, "chacha20"),
#[cfg(feature = "openssl-crypto")]
Self::OpensslAesGcm => write!(f, "openssl-aes-gcm"),
#[cfg(feature = "openssl-crypto")]
Self::OpensslChacha20 => write!(f, "openssl-chacha20"),
#[cfg(feature = "openssl-crypto")]
Self::OpensslAes256Gcm => write!(f, "openssl-aes-256-gcm"),
}
}
}
impl TryFrom<&str> for EncryptionAlgorithm {
type Error = anyhow::Error;
fn try_from(value: &str) -> Result<Self, Self::Error> {
match value {
"aes-gcm" => Ok(Self::AesGcm),
"aes-256-gcm" => Ok(Self::Aes256Gcm),
"xor" => Ok(Self::Xor),
#[cfg(feature = "wireguard")]
"chacha20" => Ok(Self::ChaCha20),
#[cfg(feature = "openssl-crypto")]
"openssl-aes-gcm" => Ok(Self::OpensslAesGcm),
#[cfg(feature = "openssl-crypto")]
"openssl-chacha20" => Ok(Self::OpensslChacha20),
#[cfg(feature = "openssl-crypto")]
"openssl-aes-256-gcm" => Ok(Self::OpensslAes256Gcm),
_ => Err(anyhow::anyhow!("invalid encryption algorithm")),
}
}
}
pub fn get_avaliable_encrypt_methods() -> Vec<&'static str> {
let mut r = vec!["aes-gcm", "aes-256-gcm", "xor"];
if cfg!(feature = "wireguard") {
r.push("chacha20");
}
if cfg!(feature = "openssl-crypto") {
r.extend(vec![
"openssl-aes-gcm",
"openssl-chacha20",
"openssl-aes-256-gcm",
]);
}
r
}
#[auto_impl::auto_impl(Box, &)]
pub trait ConfigLoader: Send + Sync {
fn get_id(&self) -> uuid::Uuid;
@@ -107,8 +178,8 @@ pub trait ConfigLoader: Send + Sync {
fn get_flags(&self) -> Flags;
fn set_flags(&self, flags: Flags);
fn get_exit_nodes(&self) -> Vec<Ipv4Addr>;
fn set_exit_nodes(&self, nodes: Vec<Ipv4Addr>);
fn get_exit_nodes(&self) -> Vec<IpAddr>;
fn set_exit_nodes(&self, nodes: Vec<IpAddr>);
fn get_routes(&self) -> Option<Vec<cidr::Ipv4Cidr>>;
fn set_routes(&self, routes: Option<Vec<cidr::Ipv4Cidr>>);
@@ -139,7 +210,7 @@ pub trait LoggingConfigLoader {
pub type NetworkSecretDigest = [u8; 32];
#[derive(Debug, Clone, Deserialize, Serialize, Default, Eq, Hash)]
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct NetworkIdentity {
pub network_name: String,
pub network_secret: Option<String>,
@@ -147,27 +218,53 @@ pub struct NetworkIdentity {
pub network_secret_digest: Option<NetworkSecretDigest>,
}
#[derive(Eq, PartialEq, Hash)]
struct NetworkIdentityWithOnlyDigest {
network_name: String,
network_secret_digest: Option<NetworkSecretDigest>,
}
impl From<NetworkIdentity> for NetworkIdentityWithOnlyDigest {
fn from(identity: NetworkIdentity) -> Self {
if identity.network_secret_digest.is_some() {
Self {
network_name: identity.network_name,
network_secret_digest: identity.network_secret_digest,
}
} else if identity.network_secret.is_some() {
let mut network_secret_digest = [0u8; 32];
generate_digest_from_str(
&identity.network_name,
identity.network_secret.as_ref().unwrap(),
&mut network_secret_digest,
);
Self {
network_name: identity.network_name,
network_secret_digest: Some(network_secret_digest),
}
} else {
Self {
network_name: identity.network_name,
network_secret_digest: None,
}
}
}
}
impl PartialEq for NetworkIdentity {
fn eq(&self, other: &Self) -> bool {
if self.network_name != other.network_name {
return false;
}
let self_with_digest = NetworkIdentityWithOnlyDigest::from(self.clone());
let other_with_digest = NetworkIdentityWithOnlyDigest::from(other.clone());
self_with_digest == other_with_digest
}
}
if self.network_secret.is_some()
&& other.network_secret.is_some()
&& self.network_secret != other.network_secret
{
return false;
}
impl Eq for NetworkIdentity {}
if self.network_secret_digest.is_some()
&& other.network_secret_digest.is_some()
&& self.network_secret_digest != other.network_secret_digest
{
return false;
}
return true;
impl std::hash::Hash for NetworkIdentity {
fn hash<H: Hasher>(&self, state: &mut H) {
let self_with_digest = NetworkIdentityWithOnlyDigest::from(self.clone());
self_with_digest.hash(state);
}
}
@@ -182,8 +279,10 @@ impl NetworkIdentity {
network_secret_digest: Some(network_secret_digest),
}
}
}
pub fn default() -> Self {
impl Default for NetworkIdentity {
fn default() -> Self {
Self::new("default".to_string(), "".to_string())
}
}
@@ -257,12 +356,12 @@ impl From<PortForwardConfigPb> for PortForwardConfig {
}
}
impl Into<PortForwardConfigPb> for PortForwardConfig {
fn into(self) -> PortForwardConfigPb {
impl From<PortForwardConfig> for PortForwardConfigPb {
fn from(val: PortForwardConfig) -> Self {
PortForwardConfigPb {
bind_addr: Some(self.bind_addr.into()),
dst_addr: Some(self.dst_addr.into()),
socket_type: match self.proto.to_lowercase().as_str() {
bind_addr: Some(val.bind_addr.into()),
dst_addr: Some(val.dst_addr.into()),
socket_type: match val.proto.to_lowercase().as_str() {
"tcp" => SocketType::Tcp as i32,
"udp" => SocketType::Udp as i32,
_ => SocketType::Tcp as i32,
@@ -283,7 +382,7 @@ struct Config {
network_identity: Option<NetworkIdentity>,
listeners: Option<Vec<url::Url>>,
mapped_listeners: Option<Vec<url::Url>>,
exit_nodes: Option<Vec<Ipv4Addr>>,
exit_nodes: Option<Vec<IpAddr>>,
peer: Option<Vec<PeerConfig>>,
proxy_network: Option<Vec<ProxyNetworkConfig>>,
@@ -422,8 +521,7 @@ impl ConfigLoader for TomlConfigLoader {
locked_config
.ipv4
.as_ref()
.map(|s| s.parse().ok())
.flatten()
.and_then(|s| s.parse().ok())
.map(|c: cidr::Ipv4Inet| {
if c.network_length() == 32 {
cidr::Ipv4Inet::new(c.address(), 24).unwrap()
@@ -434,28 +532,16 @@ impl ConfigLoader for TomlConfigLoader {
}
fn set_ipv4(&self, addr: Option<cidr::Ipv4Inet>) {
self.config.lock().unwrap().ipv4 = if let Some(addr) = addr {
Some(addr.to_string())
} else {
None
};
self.config.lock().unwrap().ipv4 = addr.map(|addr| addr.to_string());
}
fn get_ipv6(&self) -> Option<cidr::Ipv6Inet> {
let locked_config = self.config.lock().unwrap();
locked_config
.ipv6
.as_ref()
.map(|s| s.parse().ok())
.flatten()
locked_config.ipv6.as_ref().and_then(|s| s.parse().ok())
}
fn set_ipv6(&self, addr: Option<cidr::Ipv6Inet>) {
self.config.lock().unwrap().ipv6 = if let Some(addr) = addr {
Some(addr.to_string())
} else {
None
};
self.config.lock().unwrap().ipv6 = addr.map(|addr| addr.to_string());
}
fn get_dhcp(&self) -> bool {
@@ -529,7 +615,7 @@ impl ConfigLoader for TomlConfigLoader {
locked_config.instance_id = Some(id);
id
} else {
locked_config.instance_id.as_ref().unwrap().clone()
*locked_config.instance_id.as_ref().unwrap()
}
}
@@ -543,7 +629,7 @@ impl ConfigLoader for TomlConfigLoader {
.unwrap()
.network_identity
.clone()
.unwrap_or_else(NetworkIdentity::default)
.unwrap_or_default()
}
fn set_network_identity(&self, identity: NetworkIdentity) {
@@ -624,7 +710,7 @@ impl ConfigLoader for TomlConfigLoader {
self.config.lock().unwrap().flags_struct = Some(flags);
}
fn get_exit_nodes(&self) -> Vec<Ipv4Addr> {
fn get_exit_nodes(&self) -> Vec<IpAddr> {
self.config
.lock()
.unwrap()
@@ -633,7 +719,7 @@ impl ConfigLoader for TomlConfigLoader {
.unwrap_or_default()
}
fn set_exit_nodes(&self, nodes: Vec<Ipv4Addr>) {
fn set_exit_nodes(&self, nodes: Vec<IpAddr>) {
self.config.lock().unwrap().exit_nodes = Some(nodes);
}

View File

@@ -8,14 +8,14 @@ macro_rules! define_global_var {
#[macro_export]
macro_rules! use_global_var {
($name:ident) => {
crate::common::constants::$name.lock().unwrap().to_owned()
$crate::common::constants::$name.lock().unwrap().to_owned()
};
}
#[macro_export]
macro_rules! set_global_var {
($name:ident, $val:expr) => {
*crate::common::constants::$name.lock().unwrap() = $val
*$crate::common::constants::$name.lock().unwrap() = $val
};
}

View File

@@ -12,7 +12,9 @@ impl<F: FnOnce()> Defer<F> {
impl<F: FnOnce()> Drop for Defer<F> {
fn drop(&mut self) {
self.func.take().map(|f| f());
if let Some(f) = self.func.take() {
f()
}
}
}

View File

@@ -48,19 +48,15 @@ pub static RESOLVER: Lazy<Arc<Resolver<GenericConnector<TokioRuntimeProvider>>>>
pub async fn resolve_txt_record(domain_name: &str) -> Result<String, Error> {
let r = RESOLVER.clone();
let response = r.txt_lookup(domain_name).await.with_context(|| {
format!(
"txt_lookup failed, domain_name: {}",
domain_name.to_string()
)
})?;
let response = r
.txt_lookup(domain_name)
.await
.with_context(|| format!("txt_lookup failed, domain_name: {}", domain_name))?;
let txt_record = response.iter().next().with_context(|| {
format!(
"no txt record found, domain_name: {}",
domain_name.to_string()
)
})?;
let txt_record = response
.iter()
.next()
.with_context(|| format!("no txt record found, domain_name: {}", domain_name))?;
let txt_data = String::from_utf8_lossy(&txt_record.txt_data()[0]);
tracing::info!(?txt_data, ?domain_name, "get txt record");

View File

@@ -5,6 +5,7 @@ use std::{
};
use crate::common::config::ProxyNetworkConfig;
use crate::common::stats_manager::StatsManager;
use crate::common::token_bucket::TokenBucketManager;
use crate::peers::acl_filter::AclFilter;
use crate::proto::cli::PeerConnInfo;
@@ -83,6 +84,8 @@ pub struct GlobalCtx {
token_bucket_manager: TokenBucketManager,
stats_manager: Arc<StatsManager>,
acl_filter: Arc<AclFilter>,
}
@@ -101,7 +104,7 @@ impl std::fmt::Debug for GlobalCtx {
pub type ArcGlobalCtx = std::sync::Arc<GlobalCtx>;
impl GlobalCtx {
pub fn new(config_fs: impl ConfigLoader + 'static + Send + Sync) -> Self {
pub fn new(config_fs: impl ConfigLoader + 'static) -> Self {
let id = config_fs.get_id();
let network = config_fs.get_network_identity();
let net_ns = NetNS::new(config_fs.get_netns());
@@ -115,9 +118,11 @@ impl GlobalCtx {
let proxy_forward_by_system = config_fs.get_flags().proxy_forward_by_system;
let no_tun = config_fs.get_flags().no_tun;
let mut feature_flags = PeerFeatureFlag::default();
feature_flags.kcp_input = !config_fs.get_flags().disable_kcp_input;
feature_flags.no_relay_kcp = config_fs.get_flags().disable_relay_kcp;
let feature_flags = PeerFeatureFlag {
kcp_input: !config_fs.get_flags().disable_kcp_input,
no_relay_kcp: config_fs.get_flags().disable_relay_kcp,
..Default::default()
};
GlobalCtx {
inst_name: config_fs.get_inst_name(),
@@ -151,6 +156,8 @@ impl GlobalCtx {
token_bucket_manager: TokenBucketManager::new(),
stats_manager: Arc::new(StatsManager::new()),
acl_filter: Arc::new(AclFilter::new()),
}
}
@@ -180,7 +187,7 @@ impl GlobalCtx {
{
Ok(())
} else {
Err(anyhow::anyhow!("network {} not in whitelist", network_name).into())
Err(anyhow::anyhow!("network {} not in whitelist", network_name))
}
}
@@ -189,8 +196,8 @@ impl GlobalCtx {
return Some(ret);
}
let addr = self.config.get_ipv4();
self.cached_ipv4.store(addr.clone());
return addr;
self.cached_ipv4.store(addr);
addr
}
pub fn set_ipv4(&self, addr: Option<cidr::Ipv4Inet>) {
@@ -203,8 +210,8 @@ impl GlobalCtx {
return Some(ret);
}
let addr = self.config.get_ipv6();
self.cached_ipv6.store(addr.clone());
return addr;
self.cached_ipv6.store(addr);
addr
}
pub fn set_ipv6(&self, addr: Option<cidr::Ipv6Inet>) {
@@ -291,6 +298,29 @@ impl GlobalCtx {
key
}
pub fn get_256_key(&self) -> [u8; 32] {
let mut key = [0u8; 32];
let secret = self
.config
.get_network_identity()
.network_secret
.unwrap_or_default();
// fill key according to network secret
let mut hasher = DefaultHasher::new();
hasher.write(secret.as_bytes());
hasher.write(b"easytier-256bit-key"); // 添加固定盐值以区分128位和256位密钥
// 生成32字节密钥
for i in 0..4 {
let chunk_start = i * 8;
let chunk_end = chunk_start + 8;
hasher.write(&key[0..chunk_start]);
hasher.write(&[i as u8]); // 添加索引以确保每个8字节块都不同
key[chunk_start..chunk_end].copy_from_slice(&hasher.finish().to_be_bytes());
}
key
}
pub fn enable_exit_node(&self) -> bool {
self.enable_exit_node
}
@@ -323,6 +353,10 @@ impl GlobalCtx {
&self.token_bucket_manager
}
pub fn stats_manager(&self) -> &Arc<StatsManager> {
&self.stats_manager
}
pub fn get_acl_filter(&self) -> &Arc<AclFilter> {
&self.acl_filter
}
@@ -344,18 +378,18 @@ pub mod tests {
let mut subscriber = global_ctx.subscribe();
let peer_id = new_peer_id();
global_ctx.issue_event(GlobalCtxEvent::PeerAdded(peer_id.clone()));
global_ctx.issue_event(GlobalCtxEvent::PeerRemoved(peer_id.clone()));
global_ctx.issue_event(GlobalCtxEvent::PeerAdded(peer_id));
global_ctx.issue_event(GlobalCtxEvent::PeerRemoved(peer_id));
global_ctx.issue_event(GlobalCtxEvent::PeerConnAdded(PeerConnInfo::default()));
global_ctx.issue_event(GlobalCtxEvent::PeerConnRemoved(PeerConnInfo::default()));
assert_eq!(
subscriber.recv().await.unwrap(),
GlobalCtxEvent::PeerAdded(peer_id.clone())
GlobalCtxEvent::PeerAdded(peer_id)
);
assert_eq!(
subscriber.recv().await.unwrap(),
GlobalCtxEvent::PeerRemoved(peer_id.clone())
GlobalCtxEvent::PeerRemoved(peer_id)
);
assert_eq!(
subscriber.recv().await.unwrap(),
@@ -372,7 +406,7 @@ pub mod tests {
) -> ArcGlobalCtx {
let config_fs = TomlConfigLoader::default();
config_fs.set_inst_name(format!("test_{}", config_fs.get_id()));
config_fs.set_network_identity(network_identy.unwrap_or(NetworkIdentity::default()));
config_fs.set_network_identity(network_identy.unwrap_or_default());
let ctx = Arc::new(GlobalCtx::new(config_fs));
ctx.replace_stun_info_collector(Box::new(MockStunInfoCollector {

View File

@@ -1,6 +1,6 @@
#[cfg(any(target_os = "macos", target_os = "freebsd"))]
mod darwin;
#[cfg(any(target_os = "linux"))]
#[cfg(target_os = "linux")]
mod netlink;
#[cfg(target_os = "windows")]
mod win;
@@ -141,7 +141,7 @@ pub struct DummyIfConfiger {}
#[async_trait]
impl IfConfiguerTrait for DummyIfConfiger {}
#[cfg(any(target_os = "linux"))]
#[cfg(target_os = "linux")]
pub type IfConfiger = netlink::NetlinkIfConfiger;
#[cfg(any(target_os = "macos", target_os = "freebsd"))]

View File

@@ -85,14 +85,14 @@ fn send_netlink_req_and_wait_one_resp<T: NetlinkDeserializable + NetlinkSerializ
match ret.payload {
NetlinkPayload::Error(e) => {
if e.code == NonZero::new(0) {
return Ok(());
Ok(())
} else {
return Err(e.to_io().into());
Err(e.to_io().into())
}
}
p => {
tracing::error!("Unexpected netlink response: {:?}", p);
return Err(anyhow::anyhow!("Unexpected netlink response").into());
Err(anyhow::anyhow!("Unexpected netlink response").into())
}
}
}
@@ -263,8 +263,8 @@ impl NetlinkIfConfiger {
let (address, netmask) = match (address.family(), netmask.family()) {
(Some(Inet), Some(Inet)) => (
IpAddr::V4(address.as_sockaddr_in().unwrap().ip().into()),
IpAddr::V4(netmask.as_sockaddr_in().unwrap().ip().into()),
IpAddr::V4(address.as_sockaddr_in().unwrap().ip()),
IpAddr::V4(netmask.as_sockaddr_in().unwrap().ip()),
),
(Some(Inet6), Some(Inet6)) => (
IpAddr::V6(address.as_sockaddr_in6().unwrap().ip()),
@@ -333,7 +333,7 @@ impl NetlinkIfConfiger {
let mut resp = Vec::<u8>::new();
loop {
if resp.len() == 0 {
if resp.is_empty() {
let (new_resp, _) = s.recv_from_full()?;
resp = new_resp;
}

View File

@@ -727,7 +727,7 @@ impl InterfaceLuid {
if family == (AF_INET6 as ADDRESS_FAMILY) {
// ipv6 mtu must be at least 1280
mtu = 1280.max(mtu);
}
}
// https://stackoverflow.com/questions/54857292/setipinterfaceentry-returns-error-invalid-parameter
row.SitePrefixLength = 0;

View File

@@ -1,3 +1,3 @@
pub mod luid;
pub mod netsh;
pub mod types;
pub mod luid;

View File

@@ -115,4 +115,4 @@ pub fn add_dns_ipv6(if_index: u32, dnses: &[Ipv6Addr]) -> Result<(), String> {
}
let dnses_str: Vec<String> = dnses.iter().map(|addr| addr.to_string()).collect();
add_dns(AF_INET6 as _, if_index, &dnses_str)
}
}

View File

@@ -100,4 +100,4 @@ pub fn u16_ptr_to_string(ptr: *const u16) -> String {
let slice = unsafe { std::slice::from_raw_parts(ptr, len) };
String::from_utf16_lossy(slice)
}
}

View File

@@ -22,6 +22,7 @@ pub mod ifcfg;
pub mod netns;
pub mod network;
pub mod scoped_task;
pub mod stats_manager;
pub mod stun;
pub mod stun_codec_ext;
pub mod token_bucket;
@@ -139,8 +140,8 @@ pub fn get_machine_id() -> uuid::Uuid {
)))]
let gen_mid = None;
if gen_mid.is_some() {
return gen_mid.unwrap();
if let Some(mid) = gen_mid {
return mid;
}
let gen_mid = uuid::Uuid::new_v4();

View File

@@ -34,13 +34,12 @@ impl NetNSGuard {
return;
}
let ns_path: String;
let name = name.unwrap();
if name == ROOT_NETNS_NAME {
ns_path = "/proc/1/ns/net".to_string();
let ns_path: String = if name == ROOT_NETNS_NAME {
"/proc/1/ns/net".to_string()
} else {
ns_path = format!("/var/run/netns/{}", name);
}
format!("/var/run/netns/{}", name)
};
let ns = std::fs::File::open(ns_path).unwrap();
tracing::info!(

View File

@@ -211,7 +211,7 @@ impl IPCollector {
cached_ip_list.read().await.public_ipv6
);
let sleep_sec = if !cached_ip_list.read().await.public_ipv4.is_none() {
let sleep_sec = if cached_ip_list.read().await.public_ipv4.is_some() {
CACHED_IP_LIST_TIMEOUT_SEC
} else {
3
@@ -252,14 +252,11 @@ impl IPCollector {
for iface in ifaces {
for ip in iface.ips {
let ip: std::net::IpAddr = ip.ip();
match ip {
std::net::IpAddr::V4(v4) => {
if ip.is_loopback() || ip.is_multicast() {
continue;
}
ret.interface_ipv4s.push(v4.into());
if let std::net::IpAddr::V4(v4) = ip {
if ip.is_loopback() || ip.is_multicast() {
continue;
}
_ => {}
ret.interface_ipv4s.push(v4.into());
}
}
}
@@ -269,14 +266,11 @@ impl IPCollector {
for iface in ifaces {
for ip in iface.ips {
let ip: std::net::IpAddr = ip.ip();
match ip {
std::net::IpAddr::V6(v6) => {
if v6.is_multicast() || v6.is_loopback() || v6.is_unicast_link_local() {
continue;
}
ret.interface_ipv6s.push(v6.into());
if let std::net::IpAddr::V6(v6) = ip {
if v6.is_multicast() || v6.is_loopback() || v6.is_unicast_link_local() {
continue;
}
_ => {}
ret.interface_ipv6s.push(v6.into());
}
}
}

View File

@@ -0,0 +1,886 @@
use dashmap::DashMap;
use serde::{Deserialize, Serialize};
use std::cell::UnsafeCell;
use std::fmt;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::time::interval;
use crate::common::scoped_task::ScopedTask;
/// Predefined metric names for type safety
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum MetricName {
/// RPC calls sent to peers
PeerRpcClientTx,
/// RPC calls received from peers
PeerRpcClientRx,
/// RPC calls sent to peers
PeerRpcServerTx,
/// RPC calls received from peers
PeerRpcServerRx,
/// RPC call duration in milliseconds
PeerRpcDuration,
/// RPC errors
PeerRpcErrors,
/// Traffic bytes sent
TrafficBytesTx,
/// Traffic bytes received
TrafficBytesRx,
/// Traffic bytes forwarded
TrafficBytesForwarded,
/// Traffic bytes sent to self
TrafficBytesSelfTx,
/// Traffic bytes received from self
TrafficBytesSelfRx,
/// Traffic bytes forwarded for foreign network, rx to local
TrafficBytesForeignForwardRx,
/// Traffic bytes forwarded for foreign network, tx from local
TrafficBytesForeignForwardTx,
/// Traffic bytes forwarded for foreign network, forward
TrafficBytesForeignForwardForwarded,
/// Traffic packets sent
TrafficPacketsTx,
/// Traffic packets received
TrafficPacketsRx,
/// Traffic packets forwarded
TrafficPacketsForwarded,
/// Traffic packets sent to self
TrafficPacketsSelfTx,
/// Traffic packets received from self
TrafficPacketsSelfRx,
/// Traffic packets forwarded for foreign network, rx to local
TrafficPacketsForeignForwardRx,
/// Traffic packets forwarded for foreign network, tx from local
TrafficPacketsForeignForwardTx,
/// Traffic packets forwarded for foreign network, forward
TrafficPacketsForeignForwardForwarded,
/// Compression bytes before compression
CompressionBytesRxBefore,
/// Compression bytes after compression
CompressionBytesRxAfter,
/// Compression bytes before compression
CompressionBytesTxBefore,
/// Compression bytes after compression
CompressionBytesTxAfter,
}
impl fmt::Display for MetricName {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
MetricName::PeerRpcClientTx => write!(f, "peer_rpc_client_tx"),
MetricName::PeerRpcClientRx => write!(f, "peer_rpc_client_rx"),
MetricName::PeerRpcServerTx => write!(f, "peer_rpc_server_tx"),
MetricName::PeerRpcServerRx => write!(f, "peer_rpc_server_rx"),
MetricName::PeerRpcDuration => write!(f, "peer_rpc_duration_ms"),
MetricName::PeerRpcErrors => write!(f, "peer_rpc_errors"),
MetricName::TrafficBytesTx => write!(f, "traffic_bytes_tx"),
MetricName::TrafficBytesRx => write!(f, "traffic_bytes_rx"),
MetricName::TrafficBytesForwarded => write!(f, "traffic_bytes_forwarded"),
MetricName::TrafficBytesSelfTx => write!(f, "traffic_bytes_self_tx"),
MetricName::TrafficBytesSelfRx => write!(f, "traffic_bytes_self_rx"),
MetricName::TrafficBytesForeignForwardRx => {
write!(f, "traffic_bytes_foreign_forward_rx")
}
MetricName::TrafficBytesForeignForwardTx => {
write!(f, "traffic_bytes_foreign_forward_tx")
}
MetricName::TrafficBytesForeignForwardForwarded => {
write!(f, "traffic_bytes_foreign_forward_forwarded")
}
MetricName::TrafficPacketsTx => write!(f, "traffic_packets_tx"),
MetricName::TrafficPacketsRx => write!(f, "traffic_packets_rx"),
MetricName::TrafficPacketsForwarded => write!(f, "traffic_packets_forwarded"),
MetricName::TrafficPacketsSelfTx => write!(f, "traffic_packets_self_tx"),
MetricName::TrafficPacketsSelfRx => write!(f, "traffic_packets_self_rx"),
MetricName::TrafficPacketsForeignForwardRx => {
write!(f, "traffic_packets_foreign_forward_rx")
}
MetricName::TrafficPacketsForeignForwardTx => {
write!(f, "traffic_packets_foreign_forward_tx")
}
MetricName::TrafficPacketsForeignForwardForwarded => {
write!(f, "traffic_packets_foreign_forward_forwarded")
}
MetricName::CompressionBytesRxBefore => write!(f, "compression_bytes_rx_before"),
MetricName::CompressionBytesRxAfter => write!(f, "compression_bytes_rx_after"),
MetricName::CompressionBytesTxBefore => write!(f, "compression_bytes_tx_before"),
MetricName::CompressionBytesTxAfter => write!(f, "compression_bytes_tx_after"),
}
}
}
/// Predefined label types for type safety
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum LabelType {
/// Network Name
NetworkName(String),
/// Source peer ID
SrcPeerId(u32),
/// Destination peer ID
DstPeerId(u32),
/// Service name
ServiceName(String),
/// Method name
MethodName(String),
/// Protocol type
Protocol(String),
/// Direction (tx/rx)
Direction(String),
/// Compression algorithm
CompressionAlgo(String),
/// Error type
ErrorType(String),
/// Status
Status(String),
}
impl fmt::Display for LabelType {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
LabelType::NetworkName(name) => write!(f, "network_name={}", name),
LabelType::SrcPeerId(id) => write!(f, "src_peer_id={}", id),
LabelType::DstPeerId(id) => write!(f, "dst_peer_id={}", id),
LabelType::ServiceName(name) => write!(f, "service_name={}", name),
LabelType::MethodName(name) => write!(f, "method_name={}", name),
LabelType::Protocol(proto) => write!(f, "protocol={}", proto),
LabelType::Direction(dir) => write!(f, "direction={}", dir),
LabelType::CompressionAlgo(algo) => write!(f, "compression_algo={}", algo),
LabelType::ErrorType(err) => write!(f, "error_type={}", err),
LabelType::Status(status) => write!(f, "status={}", status),
}
}
}
impl LabelType {
pub fn key(&self) -> &'static str {
match self {
LabelType::NetworkName(_) => "network_name",
LabelType::SrcPeerId(_) => "src_peer_id",
LabelType::DstPeerId(_) => "dst_peer_id",
LabelType::ServiceName(_) => "service_name",
LabelType::MethodName(_) => "method_name",
LabelType::Protocol(_) => "protocol",
LabelType::Direction(_) => "direction",
LabelType::CompressionAlgo(_) => "compression_algo",
LabelType::ErrorType(_) => "error_type",
LabelType::Status(_) => "status",
}
}
pub fn value(&self) -> String {
match self {
LabelType::NetworkName(name) => name.clone(),
LabelType::SrcPeerId(id) => id.to_string(),
LabelType::DstPeerId(id) => id.to_string(),
LabelType::ServiceName(name) => name.clone(),
LabelType::MethodName(name) => name.clone(),
LabelType::Protocol(proto) => proto.clone(),
LabelType::Direction(dir) => dir.clone(),
LabelType::CompressionAlgo(algo) => algo.clone(),
LabelType::ErrorType(err) => err.clone(),
LabelType::Status(status) => status.clone(),
}
}
}
/// Label represents a key-value pair for metric identification
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct Label {
pub key: String,
pub value: String,
}
impl Label {
pub fn new(key: impl Into<String>, value: impl Into<String>) -> Self {
Self {
key: key.into(),
value: value.into(),
}
}
pub fn from_label_type(label_type: &LabelType) -> Self {
Self {
key: label_type.key().to_string(),
value: label_type.value(),
}
}
}
/// LabelSet represents a collection of labels for a metric
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct LabelSet {
labels: Vec<Label>,
}
impl LabelSet {
pub fn new() -> Self {
Self { labels: Vec::new() }
}
pub fn with_label(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
self.labels.push(Label::new(key, value));
self.labels.sort_by(|a, b| a.key.cmp(&b.key)); // Keep labels sorted for consistent hashing
self
}
/// Add a typed label to the set
pub fn with_label_type(mut self, label_type: LabelType) -> Self {
self.labels.push(Label::from_label_type(&label_type));
self.labels.sort_by(|a, b| a.key.cmp(&b.key)); // Keep labels sorted for consistent hashing
self
}
/// Create a LabelSet from multiple LabelTypes
pub fn from_label_types(label_types: &[LabelType]) -> Self {
let mut labels = Vec::new();
for label_type in label_types {
labels.push(Label::from_label_type(label_type));
}
labels.sort_by(|a, b| a.key.cmp(&b.key)); // Keep labels sorted for consistent hashing
Self { labels }
}
pub fn labels(&self) -> &[Label] {
&self.labels
}
/// Generate a string key for this label set
pub fn to_key(&self) -> String {
if self.labels.is_empty() {
return String::new();
}
let mut parts = Vec::with_capacity(self.labels.len());
for label in &self.labels {
parts.push(format!("{}={}", label.key, label.value));
}
parts.join(",")
}
}
impl Default for LabelSet {
fn default() -> Self {
Self::new()
}
}
/// UnsafeCounter provides a high-performance counter using UnsafeCell
#[derive(Debug)]
pub struct UnsafeCounter {
value: UnsafeCell<u64>,
}
impl Default for UnsafeCounter {
fn default() -> Self {
Self::new()
}
}
impl UnsafeCounter {
pub fn new() -> Self {
Self {
value: UnsafeCell::new(0),
}
}
pub fn new_with_value(initial: u64) -> Self {
Self {
value: UnsafeCell::new(initial),
}
}
/// Increment the counter by the given amount
/// # Safety
/// This method is unsafe because it uses UnsafeCell. The caller must ensure
/// that no other thread is accessing this counter simultaneously.
pub unsafe fn add(&self, delta: u64) {
let ptr = self.value.get();
*ptr = (*ptr).saturating_add(delta);
}
/// Increment the counter by 1
/// # Safety
/// This method is unsafe because it uses UnsafeCell. The caller must ensure
/// that no other thread is accessing this counter simultaneously.
pub unsafe fn inc(&self) {
self.add(1);
}
/// Get the current value of the counter
/// # Safety
/// This method is unsafe because it uses UnsafeCell. The caller must ensure
/// that no other thread is modifying this counter simultaneously.
pub unsafe fn get(&self) -> u64 {
let ptr = self.value.get();
*ptr
}
/// Reset the counter to zero
/// # Safety
/// This method is unsafe because it uses UnsafeCell. The caller must ensure
/// that no other thread is accessing this counter simultaneously.
pub unsafe fn reset(&self) {
let ptr = self.value.get();
*ptr = 0;
}
/// Set the counter to a specific value
/// # Safety
/// This method is unsafe because it uses UnsafeCell. The caller must ensure
/// that no other thread is accessing this counter simultaneously.
pub unsafe fn set(&self, value: u64) {
let ptr = self.value.get();
*ptr = value;
}
}
// UnsafeCounter is Send + Sync because the safety is guaranteed by the caller
unsafe impl Send for UnsafeCounter {}
unsafe impl Sync for UnsafeCounter {}
/// MetricData contains both the counter and last update timestamp
/// Uses UnsafeCell for lock-free access
#[derive(Debug)]
struct MetricData {
counter: UnsafeCounter,
last_updated: UnsafeCell<Instant>,
}
impl MetricData {
fn new() -> Self {
Self {
counter: UnsafeCounter::new(),
last_updated: UnsafeCell::new(Instant::now()),
}
}
fn new_with_value(initial: u64) -> Self {
Self {
counter: UnsafeCounter::new_with_value(initial),
last_updated: UnsafeCell::new(Instant::now()),
}
}
/// Update the last_updated timestamp
/// # Safety
/// This method is unsafe because it uses UnsafeCell. The caller must ensure
/// that no other thread is accessing this timestamp simultaneously.
unsafe fn touch(&self) {
let ptr = self.last_updated.get();
*ptr = Instant::now();
}
/// Get the last updated timestamp
/// # Safety
/// This method is unsafe because it uses UnsafeCell. The caller must ensure
/// that no other thread is modifying this timestamp simultaneously.
unsafe fn get_last_updated(&self) -> Instant {
let ptr = self.last_updated.get();
*ptr
}
}
// MetricData is Send + Sync because the safety is guaranteed by the caller
unsafe impl Send for MetricData {}
unsafe impl Sync for MetricData {}
/// MetricKey uniquely identifies a metric with its name and labels
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct MetricKey {
name: MetricName,
labels: LabelSet,
}
impl MetricKey {
fn new(name: MetricName, labels: LabelSet) -> Self {
Self { name, labels }
}
}
impl fmt::Display for MetricKey {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let label_str = self.labels.to_key();
if label_str.is_empty() {
f.write_str(self.name.to_string().as_str())
} else {
f.write_str(format!("{}[{}]", self.name, label_str).as_str())
}
}
}
/// CounterHandle provides a safe interface to a MetricData
/// It ensures thread-local access patterns for performance
#[derive(Clone)]
pub struct CounterHandle {
metric_data: Arc<MetricData>,
_key: MetricKey, // Keep key for debugging purposes
}
impl CounterHandle {
fn new(metric_data: Arc<MetricData>, key: MetricKey) -> Self {
Self {
metric_data,
_key: key,
}
}
/// Increment the counter by the given amount
pub fn add(&self, delta: u64) {
unsafe {
self.metric_data.counter.add(delta);
self.metric_data.touch();
}
}
/// Increment the counter by 1
pub fn inc(&self) {
unsafe {
self.metric_data.counter.inc();
self.metric_data.touch();
}
}
/// Get the current value of the counter
pub fn get(&self) -> u64 {
unsafe { self.metric_data.counter.get() }
}
/// Reset the counter to zero
pub fn reset(&self) {
unsafe {
self.metric_data.counter.reset();
self.metric_data.touch();
}
}
/// Set the counter to a specific value
pub fn set(&self, value: u64) {
unsafe {
self.metric_data.counter.set(value);
self.metric_data.touch();
}
}
}
/// MetricSnapshot represents a point-in-time view of a metric
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MetricSnapshot {
pub name: MetricName,
pub labels: LabelSet,
pub value: u64,
}
impl MetricSnapshot {
pub fn name_str(&self) -> String {
self.name.to_string()
}
}
/// StatsManager manages global statistics with high performance counters
pub struct StatsManager {
counters: Arc<DashMap<MetricKey, Arc<MetricData>>>,
cleanup_task: ScopedTask<()>,
}
impl StatsManager {
/// Create a new StatsManager
pub fn new() -> Self {
let counters = Arc::new(DashMap::new());
// Start cleanup task only if we're in a tokio runtime
let counters_clone = Arc::downgrade(&counters.clone());
let cleanup_task = tokio::spawn(async move {
let mut interval = interval(Duration::from_secs(60)); // Check every minute
loop {
interval.tick().await;
let cutoff_time = Instant::now() - Duration::from_secs(180); // 3 minutes
let Some(counters) = counters_clone.upgrade() else {
break;
};
// Remove entries that haven't been updated for 3 minutes
counters.retain(|_, metric_data: &mut Arc<MetricData>| unsafe {
metric_data.get_last_updated() > cutoff_time
});
}
});
Self {
counters,
cleanup_task: cleanup_task.into(),
}
}
/// Get or create a counter with the given name and labels
pub fn get_counter(&self, name: MetricName, labels: LabelSet) -> CounterHandle {
let key = MetricKey::new(name, labels);
let metric_data = self
.counters
.entry(key.clone())
.or_insert_with(|| Arc::new(MetricData::new()))
.clone();
CounterHandle::new(metric_data, key)
}
/// Get a counter with no labels
pub fn get_simple_counter(&self, name: MetricName) -> CounterHandle {
self.get_counter(name, LabelSet::new())
}
/// Get all metric snapshots
pub fn get_all_metrics(&self) -> Vec<MetricSnapshot> {
let mut metrics = Vec::new();
for entry in self.counters.iter() {
let key = entry.key();
let metric_data = entry.value();
let value = unsafe { metric_data.counter.get() };
metrics.push(MetricSnapshot {
name: key.name,
labels: key.labels.clone(),
value,
});
}
// Sort by metric name and then by labels for consistent output
metrics.sort_by(|a, b| {
a.name
.to_string()
.cmp(&b.name.to_string())
.then_with(|| a.labels.to_key().cmp(&b.labels.to_key()))
});
metrics
}
/// Get metrics filtered by name prefix
pub fn get_metrics_by_prefix(&self, prefix: &str) -> Vec<MetricSnapshot> {
self.get_all_metrics()
.into_iter()
.filter(|m| m.name.to_string().starts_with(prefix))
.collect()
}
/// Get a specific metric by name and labels
pub fn get_metric(&self, name: MetricName, labels: &LabelSet) -> Option<MetricSnapshot> {
let key = MetricKey::new(name, labels.clone());
if let Some(metric_data) = self.counters.get(&key) {
let value = unsafe { metric_data.counter.get() };
Some(MetricSnapshot {
name,
labels: labels.clone(),
value,
})
} else {
None
}
}
/// Clear all metrics
pub fn clear(&self) {
self.counters.clear();
}
/// Get the number of tracked metrics
pub fn metric_count(&self) -> usize {
self.counters.len()
}
/// Export metrics in Prometheus format
pub fn export_prometheus(&self) -> String {
let metrics = self.get_all_metrics();
let mut output = String::new();
let mut current_metric = String::new();
for metric in metrics {
let metric_name_str = metric.name.to_string();
if metric_name_str != current_metric {
if !current_metric.is_empty() {
output.push('\n');
}
output.push_str(&format!("# TYPE {} counter\n", metric_name_str));
current_metric = metric_name_str.clone();
}
if metric.labels.labels().is_empty() {
output.push_str(&format!("{} {}\n", metric_name_str, metric.value));
} else {
let label_str = metric
.labels
.labels()
.iter()
.map(|l| format!("{}=\"{}\"", l.key, l.value))
.collect::<Vec<_>>()
.join(",");
output.push_str(&format!(
"{}{{{}}} {}\n",
metric_name_str, label_str, metric.value
));
}
}
output
}
}
impl Default for StatsManager {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::common::stats_manager::{LabelSet, LabelType, MetricName, StatsManager};
use crate::proto::cli::{
GetPrometheusStatsRequest, GetPrometheusStatsResponse, GetStatsRequest, GetStatsResponse,
};
use std::collections::BTreeMap;
#[tokio::test]
async fn test_label_set() {
let labels = LabelSet::new()
.with_label("peer_id", "peer1")
.with_label("method", "ping");
assert_eq!(labels.to_key(), "method=ping,peer_id=peer1");
}
#[tokio::test]
async fn test_unsafe_counter() {
let counter = UnsafeCounter::new();
unsafe {
assert_eq!(counter.get(), 0);
counter.inc();
assert_eq!(counter.get(), 1);
counter.add(5);
assert_eq!(counter.get(), 6);
counter.set(10);
assert_eq!(counter.get(), 10);
counter.reset();
assert_eq!(counter.get(), 0);
}
}
#[tokio::test]
async fn test_stats_manager() {
let stats = StatsManager::new();
// Test simple counter
let counter1 = stats.get_simple_counter(MetricName::PeerRpcClientTx);
counter1.inc();
counter1.add(5);
// Test counter with labels
let labels = LabelSet::new()
.with_label("peer_id", "peer1")
.with_label("method", "ping");
let counter2 = stats.get_counter(MetricName::PeerRpcClientTx, labels.clone());
counter2.add(3);
// Check metrics
let metrics = stats.get_all_metrics();
assert_eq!(metrics.len(), 2);
// Find the simple counter
let simple_metric = metrics
.iter()
.find(|m| m.labels.labels().is_empty())
.unwrap();
assert_eq!(simple_metric.name, MetricName::PeerRpcClientTx);
assert_eq!(simple_metric.value, 6);
// Find the labeled counter
let labeled_metric = metrics
.iter()
.find(|m| !m.labels.labels().is_empty())
.unwrap();
assert_eq!(labeled_metric.name, MetricName::PeerRpcClientTx);
assert_eq!(labeled_metric.value, 3);
assert_eq!(labeled_metric.labels, labels);
}
#[tokio::test]
async fn test_prometheus_export() {
let stats = StatsManager::new();
let counter1 = stats.get_simple_counter(MetricName::TrafficBytesTx);
counter1.set(100);
let labels = LabelSet::new().with_label("status", "success");
let counter2 = stats.get_counter(MetricName::PeerRpcClientTx, labels);
counter2.set(50);
let prometheus_output = stats.export_prometheus();
assert!(prometheus_output.contains("# TYPE peer_rpc_client_tx counter"));
assert!(prometheus_output.contains("peer_rpc_client_tx{status=\"success\"} 50"));
assert!(prometheus_output.contains("# TYPE traffic_bytes_tx counter"));
assert!(prometheus_output.contains("traffic_bytes_tx 100"));
}
#[tokio::test]
async fn test_get_metric() {
let stats = StatsManager::new();
let labels = LabelSet::new().with_label("peer", "test");
let counter = stats.get_counter(MetricName::PeerRpcClientTx, labels.clone());
counter.set(42);
let metric = stats
.get_metric(MetricName::PeerRpcClientTx, &labels)
.unwrap();
assert_eq!(metric.value, 42);
let non_existent = stats.get_metric(MetricName::PeerRpcErrors, &LabelSet::new());
assert!(non_existent.is_none());
}
#[tokio::test]
async fn test_metrics_by_prefix() {
let stats = StatsManager::new();
stats
.get_simple_counter(MetricName::PeerRpcClientTx)
.set(10);
stats.get_simple_counter(MetricName::PeerRpcErrors).set(2);
stats
.get_simple_counter(MetricName::TrafficBytesTx)
.set(100);
let rpc_metrics = stats.get_metrics_by_prefix("peer_rpc");
assert_eq!(rpc_metrics.len(), 2);
let traffic_metrics = stats.get_metrics_by_prefix("traffic_");
assert_eq!(traffic_metrics.len(), 1);
}
#[tokio::test]
async fn test_cleanup_mechanism() {
let stats = StatsManager::new();
// 创建一些计数器
let counter1 = stats.get_simple_counter(MetricName::PeerRpcClientTx);
counter1.set(10);
let labels = LabelSet::new().with_label("test", "value");
let counter2 = stats.get_counter(MetricName::TrafficBytesTx, labels);
counter2.set(20);
// 验证计数器存在
assert_eq!(stats.metric_count(), 2);
// 注意实际的清理测试需要等待3分钟这在单元测试中不现实
// 这里我们只验证清理机制的基本结构是否正确
// 清理逻辑在后台线程中运行会自动删除超过3分钟未更新的条目
// 验证计数器仍然可以正常工作
counter1.inc();
assert_eq!(counter1.get(), 11);
counter2.add(5);
assert_eq!(counter2.get(), 25);
}
#[tokio::test]
async fn test_stats_rpc_data_structures() {
// Test GetStatsRequest
let request = GetStatsRequest {};
assert_eq!(request, GetStatsRequest {});
// Test GetStatsResponse
let response = GetStatsResponse { metrics: vec![] };
assert!(response.metrics.is_empty());
// Test GetPrometheusStatsRequest
let prometheus_request = GetPrometheusStatsRequest {};
assert_eq!(prometheus_request, GetPrometheusStatsRequest {});
// Test GetPrometheusStatsResponse
let prometheus_response = GetPrometheusStatsResponse {
prometheus_text: "# Test metrics\n".to_string(),
};
assert_eq!(prometheus_response.prometheus_text, "# Test metrics\n");
}
#[tokio::test]
async fn test_metric_snapshot_creation() {
let stats_manager = StatsManager::new();
// Create some test metrics
let counter1 = stats_manager.get_counter(
MetricName::PeerRpcClientTx,
LabelSet::new()
.with_label_type(LabelType::SrcPeerId(123))
.with_label_type(LabelType::ServiceName("test_service".to_string())),
);
counter1.add(100);
let counter2 = stats_manager.get_counter(
MetricName::TrafficBytesTx,
LabelSet::new().with_label_type(LabelType::Protocol("tcp".to_string())),
);
counter2.add(1024);
// Get all metrics
let metrics = stats_manager.get_all_metrics();
assert_eq!(metrics.len(), 2);
// Verify the metrics can be converted to the format expected by RPC
for metric in metrics {
let mut labels = BTreeMap::new();
for label in metric.labels.labels() {
labels.insert(label.key.clone(), label.value.clone());
}
// This simulates what the RPC service would do
let _metric_snapshot = crate::proto::cli::MetricSnapshot {
name: metric.name.to_string(),
value: metric.value,
labels,
};
}
}
#[tokio::test]
async fn test_prometheus_export_format() {
let stats_manager = StatsManager::new();
// Create test metrics
let counter = stats_manager.get_counter(
MetricName::PeerRpcClientTx,
LabelSet::new()
.with_label_type(LabelType::SrcPeerId(123))
.with_label_type(LabelType::ServiceName("test".to_string())),
);
counter.add(42);
// Export to Prometheus format
let prometheus_text = stats_manager.export_prometheus();
println!("{}", prometheus_text);
// Verify the format
assert!(prometheus_text.contains("peer_rpc_client_tx"));
assert!(prometheus_text.contains("42"));
assert!(prometheus_text.contains("src_peer_id=\"123\""));
assert!(prometheus_text.contains("service_name=\"test\""));
}
}

View File

@@ -282,9 +282,7 @@ impl StunClient {
.with_context(|| "encode stun message")?;
tids.push(tid);
tracing::trace!(?message, ?msg, tid, "send stun request");
self.socket
.send_to(msg.as_slice().into(), &stun_host)
.await?;
self.socket.send_to(msg.as_slice(), &stun_host).await?;
}
let now = Instant::now();
@@ -372,7 +370,7 @@ impl StunClientBuilder {
pub async fn stop(&mut self) {
self.task_set.abort_all();
while let Some(_) = self.task_set.join_next().await {}
while self.task_set.join_next().await.is_some() {}
}
}
@@ -417,7 +415,7 @@ impl UdpNatTypeDetectResult {
return true;
}
}
return false;
false
}
fn is_pat(&self) -> bool {
@@ -457,16 +455,16 @@ impl UdpNatTypeDetectResult {
if self.is_cone() {
if self.has_ip_changed_resp() {
if self.is_open_internet() {
return NatType::OpenInternet;
NatType::OpenInternet
} else if self.is_pat() {
return NatType::NoPat;
NatType::NoPat
} else {
return NatType::FullCone;
NatType::FullCone
}
} else if self.has_port_changed_resp() {
return NatType::Restricted;
NatType::Restricted
} else {
return NatType::PortRestricted;
NatType::PortRestricted
}
} else if !self.stun_resps.is_empty() {
if self.public_ips().len() != 1
@@ -480,7 +478,7 @@ impl UdpNatTypeDetectResult {
.mapped_socket_addr
.is_none()
{
return NatType::Symmetric;
NatType::Symmetric
} else {
let extra_bind_test = self.extra_bind_test.as_ref().unwrap();
let extra_port = extra_bind_test.mapped_socket_addr.unwrap().port();
@@ -488,15 +486,15 @@ impl UdpNatTypeDetectResult {
let max_port_diff = extra_port.saturating_sub(self.max_port());
let min_port_diff = self.min_port().saturating_sub(extra_port);
if max_port_diff != 0 && max_port_diff < 100 {
return NatType::SymmetricEasyInc;
NatType::SymmetricEasyInc
} else if min_port_diff != 0 && min_port_diff < 100 {
return NatType::SymmetricEasyDec;
NatType::SymmetricEasyDec
} else {
return NatType::Symmetric;
NatType::Symmetric
}
}
} else {
return NatType::Unknown;
NatType::Unknown
}
}
@@ -679,7 +677,7 @@ impl StunInfoCollectorTrait for StunInfoCollector {
.unwrap()
.clone()
.map(|x| x.collect_available_stun_server())
.unwrap_or(vec![]);
.unwrap_or_default();
if stun_servers.is_empty() {
let mut host_resolver =
@@ -740,7 +738,7 @@ impl StunInfoCollector {
pub fn get_default_servers() -> Vec<String> {
// NOTICE: we may need to choose stun stun server based on geo location
// stun server cross nation may return a external ip address with high latency and loss rate
vec![
[
"txt:stun.easytier.cn",
"stun.miwifi.com",
"stun.chat.bilibili.com",
@@ -752,16 +750,16 @@ impl StunInfoCollector {
}
pub fn get_default_servers_v6() -> Vec<String> {
vec!["txt:stun-v6.easytier.cn"]
["txt:stun-v6.easytier.cn"]
.iter()
.map(|x| x.to_string())
.collect()
}
async fn get_public_ipv6(servers: &Vec<String>) -> Option<Ipv6Addr> {
async fn get_public_ipv6(servers: &[String]) -> Option<Ipv6Addr> {
let mut ips = HostResolverIter::new(servers.to_vec(), 10, true);
while let Some(ip) = ips.next().await {
let Ok(udp_socket) = UdpSocket::bind(format!("[::]:0")).await else {
let Ok(udp_socket) = UdpSocket::bind("[::]:0".to_string()).await else {
break;
};
let udp = Arc::new(udp_socket);
@@ -770,11 +768,8 @@ impl StunInfoCollector {
.bind_request(false, false)
.await;
tracing::debug!(?ret, "finish ipv6 udp nat type detect");
match ret.map(|x| x.mapped_socket_addr.map(|x| x.ip())) {
Ok(Some(IpAddr::V6(v6))) => {
return Some(v6);
}
_ => {}
if let Ok(Some(IpAddr::V6(v6))) = ret.map(|x| x.mapped_socket_addr.map(|x| x.ip())) {
return Some(v6);
}
}
None
@@ -854,9 +849,9 @@ impl StunInfoCollector {
self.tasks.lock().unwrap().spawn(async move {
loop {
let servers = stun_servers.read().unwrap().clone();
Self::get_public_ipv6(&servers)
.await
.map(|x| stored_ipv6.store(Some(x)));
if let Some(x) = Self::get_public_ipv6(&servers).await {
stored_ipv6.store(Some(x))
}
let sleep_sec = if stored_ipv6.load().is_none() {
60

View File

@@ -34,7 +34,7 @@ impl From<LimiterConfig> for BucketConfig {
.unwrap_or(Duration::from_millis(10));
BucketConfig {
capacity: burst_rate * fill_rate,
fill_rate: fill_rate,
fill_rate,
refill_interval,
}
}
@@ -162,6 +162,12 @@ pub struct TokenBucketManager {
retain_task: ScopedTask<()>,
}
impl Default for TokenBucketManager {
fn default() -> Self {
Self::new()
}
}
impl TokenBucketManager {
/// Creates a new TokenBucketManager
pub fn new() -> Self {
@@ -318,7 +324,7 @@ mod tests {
// Should have accumulated about 100 tokens (10,000 tokens/s * 0.001s)
let tokens = bucket.available_tokens.load(Ordering::Relaxed);
assert!(
tokens >= 100 && tokens <= 200,
(100..=200).contains(&tokens),
"Unexpected token count: {}",
tokens
);
@@ -355,8 +361,7 @@ mod tests {
.list_foreign_networks()
.await
.foreign_networks
.len()
== 0
.is_empty()
},
Duration::from_secs(5),
)
@@ -370,8 +375,7 @@ mod tests {
.get_global_ctx()
.token_bucket_manager()
.buckets
.len()
== 0
.is_empty()
},
Duration::from_secs(10),
)

View File

@@ -16,6 +16,7 @@ use crate::{
dns::socket_addrs, error::Error, global_ctx::ArcGlobalCtx, stun::StunInfoCollectorTrait,
PeerId,
},
connector::udp_hole_punch::handle_rpc_result,
peers::{
peer_conn::PeerConnId,
peer_manager::PeerManager,
@@ -91,6 +92,7 @@ struct DirectConnectorManagerData {
global_ctx: ArcGlobalCtx,
peer_manager: Arc<PeerManager>,
dst_listener_blacklist: timedmap::TimedMap<DstListenerUrlBlackListItem, ()>,
peer_black_list: timedmap::TimedMap<PeerId, ()>,
}
impl DirectConnectorManagerData {
@@ -99,6 +101,7 @@ impl DirectConnectorManagerData {
global_ctx,
peer_manager,
dst_listener_blacklist: timedmap::TimedMap::new(),
peer_black_list: timedmap::TimedMap::new(),
}
}
@@ -177,16 +180,13 @@ impl DirectConnectorManagerData {
// ask remote to send v6 hole punch packet
// and no matter what the result is, continue to connect
let _ = self
.remote_send_v6_hole_punch_packet(dst_peer_id, &local_socket, &remote_url)
.remote_send_v6_hole_punch_packet(dst_peer_id, &local_socket, remote_url)
.await;
let udp_connector = UdpTunnelConnector::new(remote_url.clone());
let remote_addr = super::check_scheme_and_get_socket_addr::<SocketAddr>(
&remote_url,
"udp",
IpVersion::V6,
)
.await?;
let remote_addr =
super::check_scheme_and_get_socket_addr::<SocketAddr>(remote_url, "udp", IpVersion::V6)
.await?;
let ret = udp_connector
.try_connect_with_socket(local_socket, remote_addr)
.await?;
@@ -230,8 +230,8 @@ impl DirectConnectorManagerData {
dst_peer_id: PeerId,
addr: String,
) -> Result<(), Error> {
let mut rand_gen = rand::rngs::OsRng::default();
let backoff_ms = vec![1000, 2000, 4000];
let mut rand_gen = rand::rngs::OsRng;
let backoff_ms = [1000, 2000, 4000];
let mut backoff_idx = 0;
tracing::debug!(?dst_peer_id, ?addr, "try_connect_to_ip start");
@@ -240,10 +240,7 @@ impl DirectConnectorManagerData {
if self
.dst_listener_blacklist
.contains(&DstListenerUrlBlackListItem(
dst_peer_id.clone(),
addr.clone(),
))
.contains(&DstListenerUrlBlackListItem(dst_peer_id, addr.clone()))
{
return Err(Error::UrlInBlacklist);
}
@@ -278,7 +275,7 @@ impl DirectConnectorManagerData {
continue;
} else {
self.dst_listener_blacklist.insert(
DstListenerUrlBlackListItem(dst_peer_id.clone(), addr),
DstListenerUrlBlackListItem(dst_peer_id, addr),
(),
std::time::Duration::from_secs(DIRECT_CONNECTOR_BLACKLIST_TIMEOUT_SEC),
);
@@ -312,7 +309,7 @@ impl DirectConnectorManagerData {
if addr.set_host(Some(ip.to_string().as_str())).is_ok() {
tasks.spawn(Self::try_connect_to_ip(
self.clone(),
dst_peer_id.clone(),
dst_peer_id,
addr.to_string(),
));
} else {
@@ -327,7 +324,7 @@ impl DirectConnectorManagerData {
} else if !s_addr.ip().is_loopback() || TESTING.load(Ordering::Relaxed) {
tasks.spawn(Self::try_connect_to_ip(
self.clone(),
dst_peer_id.clone(),
dst_peer_id,
listener.to_string(),
));
}
@@ -352,13 +349,10 @@ impl DirectConnectorManagerData {
.iter()
.for_each(|ip| {
let mut addr = (*listener).clone();
if addr
.set_host(Some(format!("[{}]", ip.to_string()).as_str()))
.is_ok()
{
if addr.set_host(Some(format!("[{}]", ip).as_str())).is_ok() {
tasks.spawn(Self::try_connect_to_ip(
self.clone(),
dst_peer_id.clone(),
dst_peer_id,
addr.to_string(),
));
} else {
@@ -373,7 +367,7 @@ impl DirectConnectorManagerData {
} else if !s_addr.ip().is_loopback() || TESTING.load(Ordering::Relaxed) {
tasks.spawn(Self::try_connect_to_ip(
self.clone(),
dst_peer_id.clone(),
dst_peer_id,
listener.to_string(),
));
}
@@ -433,13 +427,8 @@ impl DirectConnectorManagerData {
}
tracing::debug!("try direct connect to peer with listener: {}", listener);
self.spawn_direct_connect_task(
dst_peer_id.clone(),
&ip_list,
&listener,
&mut tasks,
)
.await;
self.spawn_direct_connect_task(dst_peer_id, &ip_list, listener, &mut tasks)
.await;
listener_list.push(listener.clone().to_string());
available_listeners.pop();
@@ -473,7 +462,17 @@ impl DirectConnectorManagerData {
) -> Result<(), Error> {
let mut backoff =
udp_hole_punch::BackOff::new(vec![1000, 2000, 2000, 5000, 5000, 10000, 30000, 60000]);
let mut attempt = 0;
loop {
if self.peer_black_list.contains(&dst_peer_id) {
return Err(anyhow::anyhow!("peer {} is blacklisted", dst_peer_id).into());
}
if attempt > 0 {
tokio::time::sleep(Duration::from_millis(backoff.next_backoff())).await;
}
attempt += 1;
let peer_manager = self.peer_manager.clone();
tracing::debug!("try direct connect to peer: {}", dst_peer_id);
@@ -486,17 +485,11 @@ impl DirectConnectorManagerData {
self.global_ctx.get_network_name(),
);
let ip_list = match rpc_stub
let ip_list = rpc_stub
.get_ip_list(BaseController::default(), GetIpListRequest {})
.await
.with_context(|| format!("get ip list from peer {}", dst_peer_id))
{
Ok(ip_list) => ip_list,
Err(e) => {
tracing::error!(?e, "failed to get ip list from peer");
continue;
}
};
.await;
let ip_list = handle_rpc_result(ip_list, dst_peer_id, &self.peer_black_list)
.with_context(|| format!("get ip list from peer {}", dst_peer_id))?;
tracing::info!(ip_list = ?ip_list, dst_peer_id = ?dst_peer_id, "got ip list");
@@ -512,8 +505,6 @@ impl DirectConnectorManagerData {
);
return Ok(());
}
tokio::time::sleep(Duration::from_millis(backoff.next_backoff())).await;
}
}
}
@@ -547,13 +538,16 @@ impl PeerTaskLauncher for DirectConnectorLauncher {
}
async fn collect_peers_need_task(&self, data: &Self::Data) -> Vec<Self::CollectPeerItem> {
data.peer_black_list.cleanup();
let my_peer_id = data.peer_manager.my_peer_id();
data.peer_manager
.list_peers()
.await
.into_iter()
.filter(|peer_id| {
*peer_id != my_peer_id && !data.peer_manager.has_directly_connected_conn(*peer_id)
*peer_id != my_peer_id
&& !data.peer_manager.has_directly_connected_conn(*peer_id)
&& !data.peer_black_list.contains(peer_id)
})
.collect()
}

View File

@@ -124,11 +124,11 @@ impl DNSTunnelConnector {
let responses = responses.clone();
async move {
let response = resolver.srv_lookup(srv_domain).await.with_context(|| {
format!("srv_lookup failed, srv_domain: {}", srv_domain.to_string())
format!("srv_lookup failed, srv_domain: {}", srv_domain)
})?;
tracing::info!(?response, ?srv_domain, "srv_lookup response");
for record in response.iter() {
let parsed_record = Self::handle_one_srv_record(record, &protocol);
let parsed_record = Self::handle_one_srv_record(record, protocol);
tracing::info!(?parsed_record, ?srv_domain, "parsed_record");
if parsed_record.is_err() {
eprintln!(
@@ -153,8 +153,7 @@ impl DNSTunnelConnector {
let url = weighted_choice(srv_records.as_slice()).with_context(|| {
format!(
"failed to choose a srv record, domain_name: {}, srv_records: {:?}",
domain_name.to_string(),
srv_records
domain_name, srv_records
)
})?;

View File

@@ -93,7 +93,7 @@ impl HttpTunnelConnector {
tracing::info!("try to create connector by url: {}", query[0]);
self.redirect_type = HttpRedirectType::RedirectToQuery;
return create_connector_by_url(
&query[0].to_string(),
query[0].as_ref(),
&self.global_ctx,
self.ip_version,
)
@@ -193,7 +193,7 @@ impl HttpTunnelConnector {
.ok_or_else(|| Error::InvalidUrl("no redirect address found".to_string()))?;
let new_url = url::Url::parse(redirect_url.as_str())
.with_context(|| format!("parsing redirect url failed. url: {}", redirect_url))?;
return self.handle_302_redirect(new_url, &redirect_url).await;
return self.handle_302_redirect(new_url, redirect_url).await;
} else if res.status_code().is_success() {
return self.handle_200_success(&body).await;
} else {

View File

@@ -131,7 +131,7 @@ impl ManualConnectorManager {
.data
.connectors
.iter()
.map(|x| x.key().clone().into())
.map(|x| x.key().clone())
.collect();
let dead_urls: BTreeSet<String> = Self::collect_dead_conns(self.data.clone())
@@ -155,12 +155,8 @@ impl ManualConnectorManager {
);
}
let reconnecting_urls: BTreeSet<String> = self
.data
.reconnecting
.iter()
.map(|x| x.clone().into())
.collect();
let reconnecting_urls: BTreeSet<String> =
self.data.reconnecting.iter().map(|x| x.clone()).collect();
for conn_url in reconnecting_urls {
ret.insert(
@@ -282,7 +278,7 @@ impl ManualConnectorManager {
let remove_later = DashSet::new();
for it in data.removed_conn_urls.iter() {
let url = it.key();
if let Some(_) = data.connectors.remove(url) {
if data.connectors.remove(url).is_some() {
tracing::warn!("connector: {}, removed", url);
continue;
} else if data.reconnecting.contains(url) {
@@ -301,11 +297,7 @@ impl ManualConnectorManager {
async fn collect_dead_conns(data: Arc<ConnectorManagerData>) -> BTreeSet<String> {
Self::handle_remove_connector(data.clone());
let all_urls: BTreeSet<String> = data
.connectors
.iter()
.map(|x| x.key().clone().into())
.collect();
let all_urls: BTreeSet<String> = data.connectors.iter().map(|x| x.key().clone()).collect();
let mut ret = BTreeSet::new();
for url in all_urls.iter() {
if !data.alive_conn_urls.contains(url) {
@@ -400,21 +392,28 @@ impl ManualConnectorManager {
.await;
tracing::info!("reconnect: {} done, ret: {:?}", dead_url, ret);
if ret.is_ok() && ret.as_ref().unwrap().is_ok() {
reconn_ret = ret.unwrap();
break;
} else {
if ret.is_err() {
reconn_ret = Err(ret.unwrap_err().into());
} else if ret.as_ref().unwrap().is_err() {
reconn_ret = Err(ret.unwrap().unwrap_err());
match ret {
Ok(Ok(_)) => {
// 外层和内层都成功:解包并跳出
reconn_ret = ret.unwrap();
break;
}
Ok(Err(e)) => {
// 外层成功,内层失败
reconn_ret = Err(e);
}
Err(e) => {
// 外层失败
reconn_ret = Err(e.into());
}
data.global_ctx.issue_event(GlobalCtxEvent::ConnectError(
dead_url.clone(),
format!("{:?}", ip_version),
format!("{:?}", reconn_ret),
));
}
// 发送事件(只有在未 break 时才执行)
data.global_ctx.issue_event(GlobalCtxEvent::ConnectError(
dead_url.clone(),
format!("{:?}", ip_version),
format!("{:?}", reconn_ret),
));
}
reconn_ret

View File

@@ -260,7 +260,7 @@ impl PunchBothEasySymHoleClient {
)
.await;
let remote_ret = handle_rpc_result(remote_ret, dst_peer_id, self.blacklist.clone())?;
let remote_ret = handle_rpc_result(remote_ret, dst_peer_id, &self.blacklist)?;
if remote_ret.is_busy {
*is_busy = true;
@@ -389,7 +389,7 @@ pub mod tests {
let udp1 = Arc::new(UdpSocket::bind("0.0.0.0:40164").await.unwrap());
// 144 - DST_PORT_OFFSET = 124
let udp2 = Arc::new(UdpSocket::bind("0.0.0.0:40124").await.unwrap());
let udps = vec![udp1, udp2];
let udps = [udp1, udp2];
let counter = Arc::new(AtomicU32::new(0));

View File

@@ -67,9 +67,9 @@ impl From<NatType> for UdpNatType {
}
}
impl Into<NatType> for UdpNatType {
fn into(self) -> NatType {
match self {
impl From<UdpNatType> for NatType {
fn from(val: UdpNatType) -> Self {
match val {
UdpNatType::Unknown => NatType::Unknown,
UdpNatType::Open(nat_type) => nat_type,
UdpNatType::Cone(nat_type) => nat_type,
@@ -249,7 +249,7 @@ impl UdpSocketArray {
tracing::info!(?addr, ?tid, "got hole punching packet with intreast tid");
tid_to_socket
.entry(tid)
.or_insert_with(Vec::new)
.or_default()
.push(PunchedUdpSocket {
socket: socket.clone(),
tid,
@@ -556,7 +556,7 @@ impl PunchHoleServerCommon {
#[tracing::instrument(err, ret(level=Level::DEBUG), skip(ports))]
pub(crate) async fn send_symmetric_hole_punch_packet(
ports: &Vec<u16>,
ports: &[u16],
udp: Arc<UdpSocket>,
transaction_id: u32,
public_ips: &Vec<Ipv4Addr>,
@@ -628,5 +628,5 @@ pub(crate) async fn try_connect_with_socket(
connector
.try_connect_with_socket(socket, remote_mapped_addr)
.await
.map_err(|e| Error::from(e))
.map_err(Error::from)
}

View File

@@ -154,7 +154,7 @@ impl PunchConeHoleClient {
)
.await;
let resp = handle_rpc_result(resp, dst_peer_id, self.blacklist.clone())?;
let resp = handle_rpc_result(resp, dst_peer_id, &self.blacklist)?;
let remote_mapped_addr = resp.listener_mapped_addr.ok_or(anyhow::anyhow!(
"select_punch_listener response missing listener_mapped_addr"
@@ -172,7 +172,7 @@ impl PunchConeHoleClient {
udp_array
.send_with_all(
&new_hole_punch_packet(tid, HOLE_PUNCH_PACKET_BODY_LEN).into_bytes(),
remote_mapped_addr.clone().into(),
remote_mapped_addr.into(),
)
.await
.with_context(|| "failed to send hole punch packet from local")
@@ -188,7 +188,7 @@ impl PunchConeHoleClient {
..Default::default()
},
SendPunchPacketConeRequest {
listener_mapped_addr: Some(remote_mapped_addr.into()),
listener_mapped_addr: Some(remote_mapped_addr),
dest_addr: Some(local_mapped_addr.into()),
transaction_id: tid,
packet_count_per_batch: 2,

View File

@@ -39,7 +39,7 @@ pub(crate) mod cone;
pub(crate) mod sym_to_cone;
// sym punch should be serialized
static SYM_PUNCH_LOCK: Lazy<DashMap<PeerId, Arc<Mutex<()>>>> = Lazy::new(|| DashMap::new());
static SYM_PUNCH_LOCK: Lazy<DashMap<PeerId, Arc<Mutex<()>>>> = Lazy::new(DashMap::new);
pub static RUN_TESTING: Lazy<AtomicBool> = Lazy::new(|| AtomicBool::new(false));
// Blacklist timeout in seconds
@@ -183,7 +183,7 @@ impl BackOff {
pub fn handle_rpc_result<T>(
ret: Result<T, rpc_types::error::Error>,
dst_peer_id: PeerId,
blacklist: Arc<timedmap::TimedMap<PeerId, ()>>,
blacklist: &timedmap::TimedMap<PeerId, ()>,
) -> Result<T, rpc_types::error::Error> {
match ret {
Ok(ret) => Ok(ret),
@@ -223,7 +223,7 @@ impl UdpHoePunchConnectorData {
#[tracing::instrument(skip(self))]
async fn handle_punch_result(
self: &Self,
&self,
ret: Result<Option<Box<dyn Tunnel>>, Error>,
backoff: Option<&mut BackOff>,
round: Option<&mut u32>,
@@ -236,10 +236,8 @@ impl UdpHoePunchConnectorData {
if let Some(round) = round {
*round = round.saturating_sub(1);
}
} else {
if let Some(round) = round {
*round += 1;
}
} else if let Some(round) = round {
*round += 1;
}
};
@@ -464,7 +462,7 @@ impl PeerTaskLauncher for UdpHolePunchPeerTaskLauncher {
}
let conns = data.peer_mgr.list_peer_conns(peer_id).await;
if conns.is_some() && conns.unwrap().len() > 0 {
if conns.is_some() && !conns.unwrap().is_empty() {
continue;
}

View File

@@ -80,9 +80,9 @@ impl PunchSymToConeHoleServer {
let public_ips = request
.public_ips
.into_iter()
.map(|ip| std::net::Ipv4Addr::from(ip))
.map(std::net::Ipv4Addr::from)
.collect::<Vec<_>>();
if public_ips.len() == 0 {
if public_ips.is_empty() {
tracing::warn!("send_punch_packet_easy_sym got zero len public ip");
return Err(
anyhow::anyhow!("send_punch_packet_easy_sym got zero len public ip").into(),
@@ -158,9 +158,9 @@ impl PunchSymToConeHoleServer {
let public_ips = request
.public_ips
.into_iter()
.map(|ip| std::net::Ipv4Addr::from(ip))
.map(std::net::Ipv4Addr::from)
.collect::<Vec<_>>();
if public_ips.len() == 0 {
if public_ips.is_empty() {
tracing::warn!("try_punch_symmetric got zero len public ip");
return Err(anyhow::anyhow!("try_punch_symmetric got zero len public ip").into());
}
@@ -281,7 +281,7 @@ impl PunchSymToConeHoleClient {
return;
};
let req = SendPunchPacketEasySymRequest {
listener_mapped_addr: remote_mapped_addr.clone().into(),
listener_mapped_addr: remote_mapped_addr.into(),
public_ips: public_ips.clone().into_iter().map(|x| x.into()).collect(),
transaction_id: tid,
base_port_num: base_port_for_easy_sym.unwrap() as u32,
@@ -313,7 +313,7 @@ impl PunchSymToConeHoleClient {
port_index: u32,
) -> Option<u32> {
let req = SendPunchPacketHardSymRequest {
listener_mapped_addr: remote_mapped_addr.clone().into(),
listener_mapped_addr: remote_mapped_addr.into(),
public_ips: public_ips.clone().into_iter().map(|x| x.into()).collect(),
transaction_id: tid,
round,
@@ -333,16 +333,16 @@ impl PunchSymToConeHoleClient {
{
Err(e) => {
tracing::error!(?e, "failed to send punch packet for hard sym");
return None;
None
}
Ok(resp) => return Some(resp.next_port_index),
Ok(resp) => Some(resp.next_port_index),
}
}
async fn get_rpc_stub(
&self,
dst_peer_id: PeerId,
) -> Box<(dyn UdpHolePunchRpc<Controller = BaseController> + std::marker::Send + 'static)> {
) -> Box<dyn UdpHolePunchRpc<Controller = BaseController> + std::marker::Send + 'static> {
self.peer_mgr
.get_peer_rpc_mgr()
.rpc_client()
@@ -366,7 +366,7 @@ impl PunchSymToConeHoleClient {
let mut finish_time: Option<Instant> = None;
while finish_time.is_none() || finish_time.as_ref().unwrap().elapsed().as_millis() < 1000 {
udp_array
.send_with_all(&packet, remote_mapped_addr.into())
.send_with_all(packet, remote_mapped_addr.into())
.await?;
tokio::time::sleep(Duration::from_millis(200)).await;
@@ -437,7 +437,7 @@ impl PunchSymToConeHoleClient {
)
.await;
let resp = handle_rpc_result(resp, dst_peer_id, self.blacklist.clone())?;
let resp = handle_rpc_result(resp, dst_peer_id, &self.blacklist)?;
let remote_mapped_addr = resp.listener_mapped_addr.ok_or(anyhow::anyhow!(
"select_punch_listener response missing listener_mapped_addr"
@@ -484,7 +484,7 @@ impl PunchSymToConeHoleClient {
rpc_stub,
base_port_for_easy_sym,
my_nat_info,
remote_mapped_addr.clone(),
remote_mapped_addr,
public_ips.clone(),
tid,
))
@@ -494,7 +494,7 @@ impl PunchSymToConeHoleClient {
&udp_array,
&packet,
tid,
remote_mapped_addr.clone(),
remote_mapped_addr,
&scoped_punch_task,
)
.await?;
@@ -510,7 +510,7 @@ impl PunchSymToConeHoleClient {
let scoped_punch_task: ScopedTask<Option<u32>> =
tokio::spawn(Self::remote_send_hole_punch_packet_random(
rpc_stub,
remote_mapped_addr.clone(),
remote_mapped_addr,
public_ips.clone(),
tid,
round,
@@ -522,7 +522,7 @@ impl PunchSymToConeHoleClient {
&udp_array,
&packet,
tid,
remote_mapped_addr.clone(),
remote_mapped_addr,
&scoped_punch_task,
)
.await?;

View File

@@ -4,7 +4,6 @@ use std::{
net::{IpAddr, SocketAddr},
path::PathBuf,
str::FromStr,
sync::Mutex,
time::Duration,
vec,
};
@@ -30,15 +29,16 @@ use easytier::{
cli::{
list_peer_route_pair, AclManageRpc, AclManageRpcClientFactory, AddPortForwardRequest,
ConnectorManageRpc, ConnectorManageRpcClientFactory, DumpRouteRequest,
GetAclStatsRequest, GetVpnPortalInfoRequest, GetWhitelistRequest, ListConnectorRequest,
GetAclStatsRequest, GetPrometheusStatsRequest, GetStatsRequest,
GetVpnPortalInfoRequest, GetWhitelistRequest, ListConnectorRequest,
ListForeignNetworkRequest, ListGlobalForeignNetworkRequest, ListMappedListenerRequest,
ListPeerRequest, ListPeerResponse, ListPortForwardRequest, ListRouteRequest,
ListRouteResponse, ManageMappedListenerRequest, MappedListenerManageAction,
MappedListenerManageRpc, MappedListenerManageRpcClientFactory, NodeInfo, PeerManageRpc,
PeerManageRpcClientFactory, PortForwardManageRpc, PortForwardManageRpcClientFactory,
RemovePortForwardRequest, SetWhitelistRequest, ShowNodeInfoRequest, TcpProxyEntryState,
TcpProxyEntryTransportType, TcpProxyRpc, TcpProxyRpcClientFactory, VpnPortalRpc,
VpnPortalRpcClientFactory,
RemovePortForwardRequest, SetWhitelistRequest, ShowNodeInfoRequest, StatsRpc,
StatsRpcClientFactory, TcpProxyEntryState, TcpProxyEntryTransportType, TcpProxyRpc,
TcpProxyRpcClientFactory, VpnPortalRpc, VpnPortalRpcClientFactory,
},
common::{NatType, SocketType},
peer_rpc::{GetGlobalPeerMapRequest, PeerCenterRpc, PeerCenterRpcClientFactory},
@@ -102,6 +102,8 @@ enum SubCommand {
PortForward(PortForwardArgs),
#[command(about = "manage TCP/UDP whitelist")]
Whitelist(WhitelistArgs),
#[command(about = "show statistics information")]
Stats(StatsArgs),
#[command(about = t!("core_clap.generate_completions").to_string())]
GenAutocomplete { shell: Shell },
}
@@ -255,6 +257,20 @@ enum WhitelistSubCommand {
Show,
}
#[derive(Args, Debug)]
struct StatsArgs {
#[command(subcommand)]
sub_command: Option<StatsSubCommand>,
}
#[derive(Subcommand, Debug)]
enum StatsSubCommand {
/// Show general statistics
Show,
/// Show statistics in Prometheus format
Prometheus,
}
#[derive(Args, Debug)]
struct ServiceArgs {
#[arg(short, long, default_value = env!("CARGO_PKG_NAME"), help = "service name")]
@@ -309,7 +325,7 @@ struct InstallArgs {
type Error = anyhow::Error;
struct CommandHandler<'a> {
client: Mutex<RpcClient>,
client: tokio::sync::Mutex<RpcClient>,
verbose: bool,
output_format: &'a OutputFormat,
}
@@ -323,7 +339,7 @@ impl CommandHandler<'_> {
Ok(self
.client
.lock()
.unwrap()
.await
.scoped_client::<PeerManageRpcClientFactory<BaseController>>("".to_string())
.await
.with_context(|| "failed to get peer manager client")?)
@@ -335,7 +351,7 @@ impl CommandHandler<'_> {
Ok(self
.client
.lock()
.unwrap()
.await
.scoped_client::<ConnectorManageRpcClientFactory<BaseController>>("".to_string())
.await
.with_context(|| "failed to get connector manager client")?)
@@ -347,7 +363,7 @@ impl CommandHandler<'_> {
Ok(self
.client
.lock()
.unwrap()
.await
.scoped_client::<MappedListenerManageRpcClientFactory<BaseController>>("".to_string())
.await
.with_context(|| "failed to get mapped listener manager client")?)
@@ -359,7 +375,7 @@ impl CommandHandler<'_> {
Ok(self
.client
.lock()
.unwrap()
.await
.scoped_client::<PeerCenterRpcClientFactory<BaseController>>("".to_string())
.await
.with_context(|| "failed to get peer center client")?)
@@ -371,7 +387,7 @@ impl CommandHandler<'_> {
Ok(self
.client
.lock()
.unwrap()
.await
.scoped_client::<VpnPortalRpcClientFactory<BaseController>>("".to_string())
.await
.with_context(|| "failed to get vpn portal client")?)
@@ -383,7 +399,7 @@ impl CommandHandler<'_> {
Ok(self
.client
.lock()
.unwrap()
.await
.scoped_client::<AclManageRpcClientFactory<BaseController>>("".to_string())
.await
.with_context(|| "failed to get acl manager client")?)
@@ -396,7 +412,7 @@ impl CommandHandler<'_> {
Ok(self
.client
.lock()
.unwrap()
.await
.scoped_client::<TcpProxyRpcClientFactory<BaseController>>(transport_type.to_string())
.await
.with_context(|| "failed to get vpn portal client")?)
@@ -408,12 +424,24 @@ impl CommandHandler<'_> {
Ok(self
.client
.lock()
.unwrap()
.await
.scoped_client::<PortForwardManageRpcClientFactory<BaseController>>("".to_string())
.await
.with_context(|| "failed to get port forward manager client")?)
}
async fn get_stats_client(
&self,
) -> Result<Box<dyn StatsRpc<Controller = BaseController>>, Error> {
Ok(self
.client
.lock()
.await
.scoped_client::<StatsRpcClientFactory<BaseController>>("".to_string())
.await
.with_context(|| "failed to get stats client")?)
}
async fn list_peers(&self) -> Result<ListPeerResponse, Error> {
let client = self.get_peer_manager_client().await?;
let request = ListPeerRequest::default();
@@ -545,6 +573,18 @@ impl CommandHandler<'_> {
items.push(p.into());
}
// Sort items by ipv4 (using IpAddr for proper numeric comparison) first, then by hostname
items.sort_by(|a, b| {
use std::net::{IpAddr, Ipv4Addr};
use std::str::FromStr;
let a_ip = IpAddr::from_str(&a.ipv4).unwrap_or(IpAddr::V4(Ipv4Addr::UNSPECIFIED));
let b_ip = IpAddr::from_str(&b.ipv4).unwrap_or(IpAddr::V4(Ipv4Addr::UNSPECIFIED));
match a_ip.cmp(&b_ip) {
std::cmp::Ordering::Equal => a.hostname.cmp(&b.hostname),
other => other,
}
});
print_output(&items, self.output_format)?;
Ok(())
@@ -825,7 +865,7 @@ impl CommandHandler<'_> {
Ok(())
}
async fn handle_mapped_listener_add(&self, url: &String) -> Result<(), Error> {
async fn handle_mapped_listener_add(&self, url: &str) -> Result<(), Error> {
let url = Self::mapped_listener_validate_url(url)?;
let client = self.get_mapped_listener_manager_client().await?;
let request = ManageMappedListenerRequest {
@@ -838,7 +878,7 @@ impl CommandHandler<'_> {
Ok(())
}
async fn handle_mapped_listener_remove(&self, url: &String) -> Result<(), Error> {
async fn handle_mapped_listener_remove(&self, url: &str) -> Result<(), Error> {
let url = Self::mapped_listener_validate_url(url)?;
let client = self.get_mapped_listener_manager_client().await?;
let request = ManageMappedListenerRequest {
@@ -851,7 +891,7 @@ impl CommandHandler<'_> {
Ok(())
}
fn mapped_listener_validate_url(url: &String) -> Result<url::Url, Error> {
fn mapped_listener_validate_url(url: &str) -> Result<url::Url, Error> {
let url = url::Url::parse(url)?;
if url.scheme() != "tcp" && url.scheme() != "udp" {
return Err(anyhow::anyhow!(
@@ -885,8 +925,8 @@ impl CommandHandler<'_> {
cfg: Some(
PortForwardConfig {
proto: protocol.to_string(),
bind_addr: bind_addr.into(),
dst_addr: dst_addr.into(),
bind_addr,
dst_addr,
}
.into(),
),
@@ -921,11 +961,10 @@ impl CommandHandler<'_> {
cfg: Some(
PortForwardConfig {
proto: protocol.to_string(),
bind_addr: bind_addr.into(),
bind_addr,
dst_addr: dst_addr
.map(|s| s.parse::<SocketAddr>().unwrap())
.map(Into::into)
.unwrap_or("0.0.0.0:0".parse::<SocketAddr>().unwrap().into()),
.unwrap_or("0.0.0.0:0".parse::<SocketAddr>().unwrap()),
}
.into(),
),
@@ -1418,7 +1457,7 @@ async fn main() -> Result<(), Error> {
.unwrap(),
));
let handler = CommandHandler {
client: Mutex::new(client),
client: tokio::sync::Mutex::new(client),
verbose: cli.verbose,
output_format: &cli.output_format,
};
@@ -1676,16 +1715,10 @@ async fn main() -> Result<(), Error> {
format!("{:?}", stun_info.udp_nat_type()).as_str(),
]);
ip_list.interface_ipv4s.iter().for_each(|ip| {
builder.push_record(vec![
"Interface IPv4",
format!("{}", ip.to_string()).as_str(),
]);
builder.push_record(vec!["Interface IPv4", ip.to_string().as_str()]);
});
ip_list.interface_ipv6s.iter().for_each(|ip| {
builder.push_record(vec![
"Interface IPv6",
format!("{}", ip.to_string()).as_str(),
]);
builder.push_record(vec!["Interface IPv6", ip.to_string().as_str()]);
});
for (idx, l) in node_info.listeners.iter().enumerate() {
if l.starts_with("ring") {
@@ -1867,6 +1900,69 @@ async fn main() -> Result<(), Error> {
handler.handle_whitelist_show().await?;
}
},
SubCommand::Stats(stats_args) => match &stats_args.sub_command {
Some(StatsSubCommand::Show) | None => {
let client = handler.get_stats_client().await?;
let request = GetStatsRequest {};
let response = client.get_stats(BaseController::default(), request).await?;
if cli.output_format == OutputFormat::Json {
println!("{}", serde_json::to_string_pretty(&response.metrics)?);
} else {
#[derive(tabled::Tabled, serde::Serialize)]
struct StatsTableRow {
#[tabled(rename = "Metric Name")]
name: String,
#[tabled(rename = "Value")]
value: String,
#[tabled(rename = "Labels")]
labels: String,
}
let table_rows: Vec<StatsTableRow> = response
.metrics
.iter()
.map(|metric| {
let labels_str = if metric.labels.is_empty() {
"-".to_string()
} else {
metric
.labels
.iter()
.map(|(k, v)| format!("{}={}", k, v))
.collect::<Vec<_>>()
.join(", ")
};
let formatted_value = if metric.name.contains("bytes") {
format_size(metric.value, humansize::BINARY)
} else if metric.name.contains("duration") {
format!("{} ms", metric.value)
} else {
metric.value.to_string()
};
StatsTableRow {
name: metric.name.clone(),
value: formatted_value,
labels: labels_str,
}
})
.collect();
print_output(&table_rows, &cli.output_format)?
}
}
Some(StatsSubCommand::Prometheus) => {
let client = handler.get_stats_client().await?;
let request = GetPrometheusStatsRequest {};
let response = client
.get_prometheus_stats(BaseController::default(), request)
.await?;
println!("{}", response.prometheus_text);
}
},
SubCommand::GenAutocomplete { shell } => {
let mut cmd = Cli::command();
easytier::print_completions(shell, &mut cmd, "easytier-cli");

View File

@@ -4,7 +4,7 @@
extern crate rust_i18n;
use std::{
net::{Ipv4Addr, SocketAddr},
net::{IpAddr, SocketAddr},
path::PathBuf,
process::ExitCode,
sync::Arc,
@@ -18,8 +18,9 @@ use clap_complete::Shell;
use easytier::{
common::{
config::{
ConfigLoader, ConsoleLoggerConfig, FileLoggerConfig, LoggingConfigLoader,
NetworkIdentity, PeerConfig, PortForwardConfig, TomlConfigLoader, VpnPortalConfig,
get_avaliable_encrypt_methods, ConfigLoader, ConsoleLoggerConfig, FileLoggerConfig,
LoggingConfigLoader, NetworkIdentity, PeerConfig, PortForwardConfig, TomlConfigLoader,
VpnPortalConfig,
},
constants::EASYTIER_VERSION,
global_ctx::GlobalCtx,
@@ -52,10 +53,15 @@ use jemalloc_ctl::{epoch, stats, Access as _, AsName as _};
#[global_allocator]
static ALLOC: jemallocator::Jemalloc = jemallocator::Jemalloc;
#[cfg(feature = "jemalloc-prof")]
#[allow(non_upper_case_globals)]
#[export_name = "malloc_conf"]
pub static malloc_conf: &[u8] = b"prof:true,prof_active:true,lg_prof_sample:19\0";
fn set_prof_active(_active: bool) {
#[cfg(feature = "jemalloc-prof")]
{
const PROF_ACTIVE: &'static [u8] = b"prof.active\0";
const PROF_ACTIVE: &[u8] = b"prof.active\0";
let name = PROF_ACTIVE.name();
name.write(_active).expect("Should succeed to set prof");
}
@@ -64,7 +70,7 @@ fn set_prof_active(_active: bool) {
fn dump_profile(_cur_allocated: usize) {
#[cfg(feature = "jemalloc-prof")]
{
const PROF_DUMP: &'static [u8] = b"prof.dump\0";
const PROF_DUMP: &[u8] = b"prof.dump\0";
static mut PROF_DUMP_FILE_NAME: [u8; 128] = [0; 128];
let file_name_str = format!(
"profile-{}-{}.out",
@@ -278,6 +284,15 @@ struct NetworkOptions {
)]
disable_encryption: Option<bool>,
#[arg(
long,
env = "ET_ENCRYPTION_ALGORITHM",
help = t!("core_clap.encryption_algorithm").to_string(),
default_value = "aes-gcm",
value_parser = get_avaliable_encrypt_methods()
)]
encryption_algorithm: Option<String>,
#[arg(
long,
env = "ET_MULTI_THREAD",
@@ -333,7 +348,7 @@ struct NetworkOptions {
help = t!("core_clap.exit_nodes").to_string(),
num_args = 0..
)]
exit_nodes: Vec<Ipv4Addr>,
exit_nodes: Vec<IpAddr>,
#[arg(
long,
@@ -510,7 +525,7 @@ struct NetworkOptions {
#[arg(
long,
value_delimiter = ',',
help = "TCP port whitelist. Supports single ports (80) and ranges (8000-9000)",
help = t!("core_clap.tcp_whitelist").to_string(),
num_args = 0..
)]
tcp_whitelist: Vec<String>,
@@ -518,10 +533,28 @@ struct NetworkOptions {
#[arg(
long,
value_delimiter = ',',
help = "UDP port whitelist. Supports single ports (53) and ranges (5000-6000)",
help = t!("core_clap.udp_whitelist").to_string(),
num_args = 0..
)]
udp_whitelist: Vec<String>,
#[arg(
long,
env = "ET_DISABLE_RELAY_KCP",
help = t!("core_clap.disable_relay_kcp").to_string(),
num_args = 0..=1,
default_missing_value = "true"
)]
disable_relay_kcp: Option<bool>,
#[arg(
long,
env = "ET_ENABLE_RELAY_FOREIGN_NETWORK_KCP",
help = t!("core_clap.enable_relay_foreign_network_kcp").to_string(),
num_args = 0..=1,
default_missing_value = "true"
)]
enable_relay_foreign_network_kcp: Option<bool>,
}
#[derive(Parser, Debug)]
@@ -668,7 +701,7 @@ impl NetworkOptions {
.map(|s| s.parse().unwrap())
.collect(),
);
} else if cfg.get_listeners() == None {
} else if cfg.get_listeners().is_none() {
cfg.set_listeners(
Cli::parse_listeners(false, vec!["11010".to_string()])?
.into_iter()
@@ -707,7 +740,7 @@ impl NetworkOptions {
}
for n in self.proxy_networks.iter() {
add_proxy_network_to_config(n, &cfg)?;
add_proxy_network_to_config(n, cfg)?;
}
let rpc_portal = if let Some(r) = &self.rpc_portal {
@@ -721,9 +754,9 @@ impl NetworkOptions {
cfg.set_rpc_portal(rpc_portal);
if let Some(rpc_portal_whitelist) = &self.rpc_portal_whitelist {
let mut whitelist = cfg.get_rpc_portal_whitelist().unwrap_or_else(|| Vec::new());
let mut whitelist = cfg.get_rpc_portal_whitelist().unwrap_or_default();
for cidr in rpc_portal_whitelist {
whitelist.push((*cidr).clone());
whitelist.push(*cidr);
}
cfg.set_rpc_portal_whitelist(Some(whitelist));
}
@@ -792,18 +825,18 @@ impl NetworkOptions {
port_forward.port().expect("local bind port is missing")
)
.parse()
.expect(format!("failed to parse local bind addr {}", example_str).as_str());
.unwrap_or_else(|_| panic!("failed to parse local bind addr {}", example_str));
let dst_addr = format!(
"{}",
port_forward
.path_segments()
.expect(format!("remote destination addr is missing {}", example_str).as_str())
.next()
.expect(format!("remote destination addr is missing {}", example_str).as_str())
)
.parse()
.expect(format!("failed to parse remote destination addr {}", example_str).as_str());
let dst_addr = port_forward
.path_segments()
.unwrap_or_else(|| panic!("remote destination addr is missing {}", example_str))
.next()
.unwrap_or_else(|| panic!("remote destination addr is missing {}", example_str))
.to_string()
.parse()
.unwrap_or_else(|_| {
panic!("failed to parse remote destination addr {}", example_str)
});
let port_forward_item = PortForwardConfig {
bind_addr,
@@ -823,6 +856,9 @@ impl NetworkOptions {
if let Some(v) = self.disable_encryption {
f.enable_encryption = !v;
}
if let Some(algorithm) = &self.encryption_algorithm {
f.encryption_algorithm = algorithm.clone();
}
if let Some(v) = self.disable_ipv6 {
f.enable_ipv6 = !v;
}
@@ -870,6 +906,10 @@ impl NetworkOptions {
.foreign_relay_bps_limit
.unwrap_or(f.foreign_relay_bps_limit);
f.multi_thread_count = self.multi_thread_count.unwrap_or(f.multi_thread_count);
f.disable_relay_kcp = self.disable_relay_kcp.unwrap_or(f.disable_relay_kcp);
f.enable_relay_foreign_network_kcp = self
.enable_relay_foreign_network_kcp
.unwrap_or(f.enable_relay_foreign_network_kcp);
cfg.set_flags(f);
if !self.exit_nodes.is_empty() {
@@ -1101,7 +1141,7 @@ async fn run_main(cli: Cli) -> anyhow::Result<()> {
let mut cfg = TomlConfigLoader::default();
cli.network_options
.merge_into(&mut cfg)
.with_context(|| format!("failed to create config from cli"))?;
.with_context(|| "failed to create config from cli".to_string())?;
println!("Starting easytier from cli with config:");
println!("############### TOML ###############\n");
println!("{}", cfg.dump());
@@ -1116,7 +1156,7 @@ async fn run_main(cli: Cli) -> anyhow::Result<()> {
.into_values()
.filter_map(|info| info.error_msg)
.collect::<Vec<_>>();
if errs.len() > 0 {
if !errs.is_empty() {
return Err(anyhow::anyhow!("some instances stopped with errors"));
}
}

View File

@@ -294,7 +294,7 @@ pub fn new_udp_header<T: ToTargetAddr>(target_addr: T) -> Result<Vec<u8>> {
}
/// Parse data from UDP client on raw buffer, return (frag, target_addr, payload).
pub async fn parse_udp_request<'a>(mut req: &'a [u8]) -> Result<(u8, TargetAddr, &'a [u8])> {
pub async fn parse_udp_request(mut req: &[u8]) -> Result<(u8, TargetAddr, &[u8])> {
let rsv = read_exact!(req, [0u8; 2]).context("Malformed request")?;
if !rsv.eq(&[0u8; 2]) {

View File

@@ -455,16 +455,16 @@ impl<T: AsyncRead + AsyncWrite + Unpin, A: Authentication, C: AsyncTcpConnector>
info!("User logged successfully.");
return Ok(credentials);
Ok(credentials)
} else {
self.inner
.write_all(&[1, consts::SOCKS5_AUTH_METHOD_NOT_ACCEPTABLE])
.await
.context("Can't reply with auth method not acceptable.")?;
return Err(SocksError::AuthenticationRejected(format!(
"Authentication, rejected."
)));
Err(SocksError::AuthenticationRejected(
"Authentication, rejected.".to_string(),
))
}
}

View File

@@ -72,10 +72,7 @@ impl TargetAddr {
}
pub fn is_ip(&self) -> bool {
match self {
TargetAddr::Ip(_) => true,
_ => false,
}
matches!(self, TargetAddr::Ip(_))
}
pub fn is_domain(&self) -> bool {
@@ -104,7 +101,7 @@ impl TargetAddr {
}
TargetAddr::Domain(ref domain, port) => {
debug!("TargetAddr::Domain");
if domain.len() > u8::max_value() as usize {
if domain.len() > u8::MAX as usize {
return Err(SocksError::ExceededMaxDomainLen(domain.len()).into());
}
buf.extend_from_slice(&[consts::SOCKS5_ADDR_TYPE_DOMAIN_NAME, domain.len() as u8]);
@@ -125,8 +122,7 @@ impl std::net::ToSocketAddrs for TargetAddr {
fn to_socket_addrs(&self) -> io::Result<IntoIter<SocketAddr>> {
match *self {
TargetAddr::Ip(addr) => Ok(vec![addr].into_iter()),
TargetAddr::Domain(_, _) => Err(io::Error::new(
io::ErrorKind::Other,
TargetAddr::Domain(_, _) => Err(io::Error::other(
"Domain name has to be explicitly resolved, please use TargetAddr::resolve_dns().",
)),
}
@@ -149,7 +145,7 @@ pub trait ToTargetAddr {
fn to_target_addr(&self) -> io::Result<TargetAddr>;
}
impl<'a> ToTargetAddr for (&'a str, u16) {
impl ToTargetAddr for (&str, u16) {
fn to_target_addr(&self) -> io::Result<TargetAddr> {
// try to parse as an IP first
if let Ok(addr) = self.0.parse::<Ipv4Addr>() {

View File

@@ -23,6 +23,7 @@ use tracing::Instrument;
use crate::{
common::{error::Error, global_ctx::ArcGlobalCtx, PeerId},
gateway::ip_reassembler::ComposeIpv4PacketArgs,
peers::{peer_manager::PeerManager, PeerPacketFilter},
tunnel::packet_def::{PacketType, ZCPacket},
};
@@ -118,7 +119,7 @@ fn socket_recv_loop(
}
};
if len <= 0 {
if len == 0 {
tracing::error!("recv empty packet, len: {}", len);
return;
}
@@ -158,20 +159,18 @@ fn socket_recv_loop(
let payload_len = len - ipv4_packet.get_header_length() as usize * 4;
let id = ipv4_packet.get_identification();
let _ = compose_ipv4_packet(
&mut buf[..],
&v.mapped_dst_ip,
&dest_ip,
IpNextHeaderProtocols::Icmp,
payload_len,
1200,
id,
ComposeIpv4PacketArgs {
buf: &mut buf[..],
src_v4: &v.mapped_dst_ip,
dst_v4: &dest_ip,
next_protocol: IpNextHeaderProtocols::Icmp,
payload_len,
payload_mtu: 1200,
ip_id: id,
},
|buf| {
let mut p = ZCPacket::new_with_payload(buf);
p.fill_peer_manager_hdr(
v.my_peer_id.into(),
v.src_peer_id.into(),
PacketType::Data as u8,
);
p.fill_peer_manager_hdr(v.my_peer_id, v.src_peer_id, PacketType::Data as u8);
p.mut_peer_manager_header().unwrap().set_no_proxy(true);
if let Err(e) = sender.send(p) {
@@ -186,7 +185,7 @@ fn socket_recv_loop(
#[async_trait::async_trait]
impl PeerPacketFilter for IcmpProxy {
async fn try_process_packet_from_peer(&self, packet: ZCPacket) -> Option<ZCPacket> {
if let Some(_) = self.try_handle_peer_packet(&packet).await {
if self.try_handle_peer_packet(&packet).await.is_some() {
return None;
} else {
return Some(packet);
@@ -320,10 +319,7 @@ impl IcmpProxy {
.unwrap()
.as_ref()
.with_context(|| "icmp socket not created")?
.send_to(
icmp_packet.packet(),
&SocketAddrV4::new(dst_ip.into(), 0).into(),
)?;
.send_to(icmp_packet.packet(), &SocketAddrV4::new(dst_ip, 0).into())?;
Ok(())
}
@@ -349,13 +345,15 @@ impl IcmpProxy {
let len = buf.len() - 20;
let _ = compose_ipv4_packet(
&mut buf[..],
src_ip,
dst_ip,
IpNextHeaderProtocols::Icmp,
len,
1200,
rand::random(),
ComposeIpv4PacketArgs {
buf: &mut buf[..],
src_v4: src_ip,
dst_v4: dst_ip,
next_protocol: IpNextHeaderProtocols::Icmp,
payload_len: len,
payload_mtu: 1200,
ip_id: rand::random(),
},
|buf| {
let mut packet = ZCPacket::new_with_payload(buf);
packet.fill_peer_manager_hdr(src_peer_id, dst_peer_id, PacketType::Data as u8);
@@ -387,7 +385,7 @@ impl IcmpProxy {
return None;
};
let ipv4 = Ipv4Packet::new(&packet.payload())?;
let ipv4 = Ipv4Packet::new(packet.payload())?;
if ipv4.get_version() != 4 || ipv4.get_next_level_protocol() != IpNextHeaderProtocols::Icmp
{
@@ -396,17 +394,17 @@ impl IcmpProxy {
let mut real_dst_ip = ipv4.get_destination();
if !self
if !(self
.cidr_set
.contains_v4(ipv4.get_destination(), &mut real_dst_ip)
&& !is_exit_node
&& !(self.global_ctx.no_tun()
|| is_exit_node
|| (self.global_ctx.no_tun()
&& Some(ipv4.get_destination())
== self
.global_ctx
.get_ipv4()
.as_ref()
.map(cidr::Ipv4Inet::address))
.map(cidr::Ipv4Inet::address)))
{
return None;
}
@@ -416,12 +414,10 @@ impl IcmpProxy {
resembled_buf =
self.ip_resemmbler
.add_fragment(ipv4.get_source(), ipv4.get_destination(), &ipv4);
if resembled_buf.is_none() {
return None;
};
resembled_buf.as_ref()?;
icmp::echo_request::EchoRequestPacket::new(resembled_buf.as_ref().unwrap())?
} else {
icmp::echo_request::EchoRequestPacket::new(&ipv4.payload())?
icmp::echo_request::EchoRequestPacket::new(ipv4.payload())?
};
if icmp_packet.get_icmp_type() != IcmpTypes::EchoRequest {
@@ -484,10 +480,9 @@ impl Drop for IcmpProxy {
"dropping icmp proxy, {:?}",
self.socket.lock().unwrap().as_ref()
);
self.socket.lock().unwrap().as_ref().and_then(|s| {
if let Some(s) = self.socket.lock().unwrap().as_ref() {
tracing::info!("shutting down icmp socket");
let _ = s.shutdown(std::net::Shutdown::Both);
Some(())
});
}
}
}

View File

@@ -190,33 +190,36 @@ impl IpReassembler {
}
}
pub struct ComposeIpv4PacketArgs<'a> {
pub buf: &'a mut [u8],
pub src_v4: &'a Ipv4Addr,
pub dst_v4: &'a Ipv4Addr,
pub next_protocol: IpNextHeaderProtocol,
pub payload_len: usize,
pub payload_mtu: usize,
pub ip_id: u16,
}
// ip payload should be in buf[20..]
pub fn compose_ipv4_packet<F>(
buf: &mut [u8],
src_v4: &Ipv4Addr,
dst_v4: &Ipv4Addr,
next_protocol: IpNextHeaderProtocol,
payload_len: usize,
payload_mtu: usize,
ip_id: u16,
cb: F,
) -> Result<(), Error>
pub fn compose_ipv4_packet<F>(args: ComposeIpv4PacketArgs, cb: F) -> Result<(), Error>
where
F: Fn(&[u8]) -> Result<(), Error>,
{
let total_pieces = (payload_len + payload_mtu - 1) / payload_mtu;
let total_pieces = args.payload_len.div_ceil(args.payload_mtu);
let mut buf_offset = 0;
let mut fragment_offset = 0;
let mut cur_piece = 0;
while fragment_offset < payload_len {
let next_fragment_offset = std::cmp::min(fragment_offset + payload_mtu, payload_len);
while fragment_offset < args.payload_len {
let next_fragment_offset =
std::cmp::min(fragment_offset + args.payload_mtu, args.payload_len);
let fragment_len = next_fragment_offset - fragment_offset;
let mut ipv4_packet =
MutableIpv4Packet::new(&mut buf[buf_offset..buf_offset + fragment_len + 20]).unwrap();
MutableIpv4Packet::new(&mut args.buf[buf_offset..buf_offset + fragment_len + 20])
.unwrap();
ipv4_packet.set_version(4);
ipv4_packet.set_header_length(5);
ipv4_packet.set_total_length((fragment_len + 20) as u16);
ipv4_packet.set_identification(ip_id);
ipv4_packet.set_identification(args.ip_id);
if total_pieces > 1 {
if cur_piece != total_pieces - 1 {
ipv4_packet.set_flags(Ipv4Flags::MoreFragments);
@@ -232,9 +235,9 @@ where
ipv4_packet.set_ecn(0);
ipv4_packet.set_dscp(0);
ipv4_packet.set_ttl(32);
ipv4_packet.set_source(src_v4.clone());
ipv4_packet.set_destination(dst_v4.clone());
ipv4_packet.set_next_level_protocol(next_protocol);
ipv4_packet.set_source(*args.src_v4);
ipv4_packet.set_destination(*args.dst_v4);
ipv4_packet.set_next_level_protocol(args.next_protocol);
ipv4_packet.set_checksum(ipv4::checksum(&ipv4_packet.to_immutable()));
tracing::trace!(?ipv4_packet, "udp nat packet response send");
@@ -254,7 +257,7 @@ mod tests {
#[test]
fn resembler() {
let raw_packets = vec![
let raw_packets = [
// last packet
vec![
0x45, 0x00, 0x00, 0x1c, 0x1c, 0x46, 0x20, 0x01, 0x40, 0x06, 0xb1, 0xe6, 0xc0, 0xa8,
@@ -282,7 +285,7 @@ mod tests {
let resembler = IpReassembler::new(Duration::from_secs(1));
for (idx, raw_packet) in raw_packets.iter().enumerate() {
if let Some(packet) = Ipv4Packet::new(&raw_packet) {
if let Some(packet) = Ipv4Packet::new(raw_packet) {
let ret = resembler.add_fragment(source, destination, &packet);
if idx != 2 {
assert!(ret.is_none());

View File

@@ -70,7 +70,9 @@ impl PeerPacketFilter for KcpEndpointFilter {
async fn try_process_packet_from_peer(&self, packet: ZCPacket) -> Option<ZCPacket> {
let t = packet.peer_manager_header().unwrap().packet_type;
if t == PacketType::KcpSrc as u8 && !self.is_src {
// src packet, but we are dst
} else if t == PacketType::KcpDst as u8 && self.is_src {
// dst packet, but we are src
} else {
return Some(packet);
}
@@ -103,7 +105,7 @@ async fn handle_kcp_output(
PacketType::KcpDst as u8
};
let mut packet = ZCPacket::new_with_payload(&packet.inner().freeze());
packet.fill_peer_manager_hdr(peer_mgr.my_peer_id(), dst_peer_id, packet_type as u8);
packet.fill_peer_manager_hdr(peer_mgr.my_peer_id(), dst_peer_id, packet_type);
if let Err(e) = peer_mgr.send_msg(packet, dst_peer_id).await {
tracing::error!("failed to send kcp packet to peer: {:?}", e);
@@ -171,7 +173,7 @@ impl NatDstConnector for NatDstKcpConnector {
let kcp_endpoint = self.kcp_endpoint.clone();
let my_peer_id = peer_mgr.my_peer_id();
let conn_data_clone = conn_data.clone();
let conn_data_clone = conn_data;
connect_tasks.spawn(async move {
kcp_endpoint
@@ -182,9 +184,7 @@ impl NatDstConnector for NatDstKcpConnector {
Bytes::from(conn_data_clone.encode_to_vec()),
)
.await
.with_context(|| {
format!("failed to connect to nat dst: {}", nat_dst.to_string())
})
.with_context(|| format!("failed to connect to nat dst: {}", nat_dst))
});
}
@@ -203,7 +203,7 @@ impl NatDstConnector for NatDstKcpConnector {
_ipv4: &Ipv4Packet,
_real_dst_ip: &mut Ipv4Addr,
) -> bool {
return hdr.from_peer_id == hdr.to_peer_id && hdr.is_kcp_src_modified();
hdr.from_peer_id == hdr.to_peer_id && hdr.is_kcp_src_modified()
}
fn transport_type(&self) -> TcpProxyEntryTransportType {
@@ -230,15 +230,10 @@ impl TcpProxyForKcpSrcTrait for TcpProxyForKcpSrc {
}
async fn check_dst_allow_kcp_input(&self, dst_ip: &Ipv4Addr) -> bool {
let peer_map: Arc<crate::peers::peer_map::PeerMap> =
self.0.get_peer_manager().get_peer_map();
let Some(dst_peer_id) = peer_map.get_peer_id_by_ipv4(dst_ip).await else {
return false;
};
let Some(peer_info) = peer_map.get_route_peer_info(dst_peer_id).await else {
return false;
};
peer_info.feature_flag.map(|x| x.kcp_input).unwrap_or(false)
self.0
.get_peer_manager()
.check_allow_kcp_to_dst(&IpAddr::V4(*dst_ip))
.await
}
}
@@ -464,14 +459,11 @@ impl KcpProxyDst {
.into();
let src_socket: SocketAddr = parsed_conn_data.src.unwrap_or_default().into();
match dst_socket.ip() {
IpAddr::V4(dst_v4_ip) => {
let mut real_ip = dst_v4_ip;
if cidr_set.contains_v4(dst_v4_ip, &mut real_ip) {
dst_socket.set_ip(real_ip.into());
}
if let IpAddr::V4(dst_v4_ip) = dst_socket.ip() {
let mut real_ip = dst_v4_ip;
if cidr_set.contains_v4(dst_v4_ip, &mut real_ip) {
dst_socket.set_ip(real_ip.into());
}
_ => {}
};
let conn_id = kcp_stream.conn_id();
@@ -586,7 +578,7 @@ impl TcpProxyRpc for KcpProxyDstRpcService {
let mut reply = ListTcpProxyEntryResponse::default();
if let Some(tcp_proxy) = self.0.upgrade() {
for item in tcp_proxy.iter() {
reply.entries.push(item.value().clone());
reply.entries.push(*item.value());
}
}
Ok(reply)

View File

@@ -56,11 +56,11 @@ impl CidrSet {
cidr_set.lock().unwrap().clear();
for cidr in cidrs.iter() {
let real_cidr = cidr.cidr;
let mapped = cidr.mapped_cidr.unwrap_or(real_cidr.clone());
cidr_set.lock().unwrap().push(mapped.clone());
let mapped = cidr.mapped_cidr.unwrap_or(real_cidr);
cidr_set.lock().unwrap().push(mapped);
if mapped != real_cidr {
mapped_to_real.insert(mapped.clone(), real_cidr.clone());
mapped_to_real.insert(mapped, real_cidr);
}
}
}
@@ -70,11 +70,11 @@ impl CidrSet {
}
pub fn contains_v4(&self, ipv4: std::net::Ipv4Addr, real_ip: &mut std::net::Ipv4Addr) -> bool {
let ip = ipv4.into();
let ip = ipv4;
let s = self.cidr_set.lock().unwrap();
for cidr in s.iter() {
if cidr.contains(&ip) {
if let Some(real_cidr) = self.mapped_to_real.get(&cidr).map(|v| v.value().clone()) {
if let Some(real_cidr) = self.mapped_to_real.get(cidr).map(|v| *v.value()) {
let origin_network_bits = real_cidr.first().address().to_bits();
let network_mask = cidr.mask().to_bits();

View File

@@ -172,7 +172,7 @@ impl NatDstConnector for NatDstQUICConnector {
_ipv4: &Ipv4Packet,
_real_dst_ip: &mut Ipv4Addr,
) -> bool {
return hdr.from_peer_id == hdr.to_peer_id && !hdr.is_kcp_src_modified();
hdr.from_peer_id == hdr.to_peer_id && !hdr.is_kcp_src_modified()
}
fn transport_type(&self) -> TcpProxyEntryTransportType {
@@ -457,7 +457,7 @@ impl TcpProxyRpc for QUICProxyDstRpcService {
let mut reply = ListTcpProxyEntryResponse::default();
if let Some(tcp_proxy) = self.0.upgrade() {
for item in tcp_proxy.iter() {
reply.entries.push(item.value().clone());
reply.entries.push(*item.value());
}
}
Ok(reply)

View File

@@ -72,9 +72,9 @@ impl SocksUdpSocket {
}
enum SocksTcpStream {
TcpStream(tokio::net::TcpStream),
SmolTcpStream(super::tokio_smoltcp::TcpStream),
KcpStream(KcpStream),
Tcp(tokio::net::TcpStream),
SmolTcp(super::tokio_smoltcp::TcpStream),
Kcp(KcpStream),
}
impl AsyncRead for SocksTcpStream {
@@ -84,15 +84,11 @@ impl AsyncRead for SocksTcpStream {
buf: &mut tokio::io::ReadBuf<'_>,
) -> std::task::Poll<std::io::Result<()>> {
match self.get_mut() {
SocksTcpStream::TcpStream(ref mut stream) => {
std::pin::Pin::new(stream).poll_read(cx, buf)
}
SocksTcpStream::SmolTcpStream(ref mut stream) => {
std::pin::Pin::new(stream).poll_read(cx, buf)
}
SocksTcpStream::KcpStream(ref mut stream) => {
SocksTcpStream::Tcp(ref mut stream) => std::pin::Pin::new(stream).poll_read(cx, buf),
SocksTcpStream::SmolTcp(ref mut stream) => {
std::pin::Pin::new(stream).poll_read(cx, buf)
}
SocksTcpStream::Kcp(ref mut stream) => std::pin::Pin::new(stream).poll_read(cx, buf),
}
}
}
@@ -104,15 +100,11 @@ impl AsyncWrite for SocksTcpStream {
buf: &[u8],
) -> std::task::Poll<Result<usize, std::io::Error>> {
match self.get_mut() {
SocksTcpStream::TcpStream(ref mut stream) => {
std::pin::Pin::new(stream).poll_write(cx, buf)
}
SocksTcpStream::SmolTcpStream(ref mut stream) => {
std::pin::Pin::new(stream).poll_write(cx, buf)
}
SocksTcpStream::KcpStream(ref mut stream) => {
SocksTcpStream::Tcp(ref mut stream) => std::pin::Pin::new(stream).poll_write(cx, buf),
SocksTcpStream::SmolTcp(ref mut stream) => {
std::pin::Pin::new(stream).poll_write(cx, buf)
}
SocksTcpStream::Kcp(ref mut stream) => std::pin::Pin::new(stream).poll_write(cx, buf),
}
}
@@ -121,11 +113,9 @@ impl AsyncWrite for SocksTcpStream {
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), std::io::Error>> {
match self.get_mut() {
SocksTcpStream::TcpStream(ref mut stream) => std::pin::Pin::new(stream).poll_flush(cx),
SocksTcpStream::SmolTcpStream(ref mut stream) => {
std::pin::Pin::new(stream).poll_flush(cx)
}
SocksTcpStream::KcpStream(ref mut stream) => std::pin::Pin::new(stream).poll_flush(cx),
SocksTcpStream::Tcp(ref mut stream) => std::pin::Pin::new(stream).poll_flush(cx),
SocksTcpStream::SmolTcp(ref mut stream) => std::pin::Pin::new(stream).poll_flush(cx),
SocksTcpStream::Kcp(ref mut stream) => std::pin::Pin::new(stream).poll_flush(cx),
}
}
@@ -134,15 +124,9 @@ impl AsyncWrite for SocksTcpStream {
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), std::io::Error>> {
match self.get_mut() {
SocksTcpStream::TcpStream(ref mut stream) => {
std::pin::Pin::new(stream).poll_shutdown(cx)
}
SocksTcpStream::SmolTcpStream(ref mut stream) => {
std::pin::Pin::new(stream).poll_shutdown(cx)
}
SocksTcpStream::KcpStream(ref mut stream) => {
std::pin::Pin::new(stream).poll_shutdown(cx)
}
SocksTcpStream::Tcp(ref mut stream) => std::pin::Pin::new(stream).poll_shutdown(cx),
SocksTcpStream::SmolTcp(ref mut stream) => std::pin::Pin::new(stream).poll_shutdown(cx),
SocksTcpStream::Kcp(ref mut stream) => std::pin::Pin::new(stream).poll_shutdown(cx),
}
}
}
@@ -196,7 +180,7 @@ impl AsyncTcpConnector for SmolTcpConnector {
let modified_addr =
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), addr.port());
Ok(SocksTcpStream::TcpStream(
Ok(SocksTcpStream::Tcp(
tcp_connect_with_timeout(modified_addr, timeout_s).await?,
))
} else {
@@ -207,9 +191,9 @@ impl AsyncTcpConnector for SmolTcpConnector {
.await
.with_context(|| "connect to remote timeout")?;
Ok(SocksTcpStream::SmolTcpStream(remote_socket.map_err(
|e| super::fast_socks5::SocksError::Other(e.into()),
)?))
Ok(SocksTcpStream::SmolTcp(remote_socket.map_err(|e| {
super::fast_socks5::SocksError::Other(e.into())
})?))
}
}
}
@@ -249,7 +233,7 @@ impl AsyncTcpConnector for Socks5KcpConnector {
.connect(self.src_addr, addr)
.await
.map_err(|e| super::fast_socks5::SocksError::Other(e.into()))?;
Ok(SocksTcpStream::KcpStream(ret))
Ok(SocksTcpStream::Kcp(ret))
}
}
@@ -560,16 +544,16 @@ impl Socks5Server {
tcp_forward_task.lock().unwrap().abort_all();
udp_client_map.clear();
if cur_ipv4.is_none() {
let _ = net.lock().await.take();
} else {
if let Some(cur_ipv4) = cur_ipv4 {
net.lock().await.replace(Socks5ServerNet::new(
cur_ipv4.unwrap(),
cur_ipv4,
None,
peer_manager.clone(),
packet_recv.clone(),
entries.clone(),
));
} else {
let _ = net.lock().await.take();
}
}
@@ -621,7 +605,7 @@ impl Socks5Server {
let cfgs = self.global_ctx.config.get_port_forwards();
self.reload_port_forwards(&cfgs).await?;
need_start = need_start || cfgs.len() > 0;
need_start = need_start || !cfgs.is_empty();
if need_start {
self.peer_manager
@@ -751,20 +735,26 @@ impl Socks5Server {
continue;
};
let Some(peer_mgr_arc) = peer_mgr.upgrade() else {
tracing::error!("peer manager is dropped");
continue;
};
let dst_allow_kcp = peer_mgr_arc.check_allow_kcp_to_dst(&dst_addr.ip()).await;
tracing::debug!("dst_allow_kcp: {:?}", dst_allow_kcp);
let connector: Box<dyn AsyncTcpConnector<S = SocksTcpStream> + Send> =
if kcp_endpoint.is_none() {
Box::new(SmolTcpConnector {
match (&kcp_endpoint, dst_allow_kcp) {
(Some(kcp_endpoint), true) => Box::new(Socks5KcpConnector {
kcp_endpoint: kcp_endpoint.clone(),
peer_mgr: peer_mgr.clone(),
src_addr: addr,
}),
(_, _) => Box::new(SmolTcpConnector {
net: net.smoltcp_net.clone(),
entries: entries.clone(),
current_entry: std::sync::Mutex::new(None),
})
} else {
let kcp_endpoint = kcp_endpoint.as_ref().unwrap().clone();
Box::new(Socks5KcpConnector {
kcp_endpoint,
peer_mgr: peer_mgr.clone(),
src_addr: addr,
})
}),
};
forward_tasks
@@ -954,10 +944,10 @@ impl Socks5Server {
udp_client_map.retain(|_, client_info| {
now.duration_since(client_info.last_active.load()).as_secs() < 600
});
udp_forward_task.retain(|k, _| udp_client_map.contains_key(&k));
udp_forward_task.retain(|k, _| udp_client_map.contains_key(k));
entries.retain(|_, data| match data {
Socks5EntryData::Udp((_, udp_client_key)) => {
udp_client_map.contains_key(&udp_client_key)
udp_client_map.contains_key(udp_client_key)
}
_ => true,
});

View File

@@ -109,9 +109,9 @@ impl NatDstConnector for NatDstTcpConnector {
) -> bool {
let is_exit_node = hdr.is_exit_node();
if !cidr_set.contains_v4(ipv4.get_destination(), real_dst_ip)
&& !is_exit_node
&& !(global_ctx.no_tun()
if !(cidr_set.contains_v4(ipv4.get_destination(), real_dst_ip)
|| is_exit_node
|| global_ctx.no_tun()
&& Some(ipv4.get_destination())
== global_ctx.get_ipv4().as_ref().map(Ipv4Inet::address))
{
@@ -154,10 +154,10 @@ impl NatDstEntry {
}
}
fn into_pb(&self, transport_type: TcpProxyEntryTransportType) -> TcpProxyEntry {
fn parse_as_pb(&self, transport_type: TcpProxyEntryTransportType) -> TcpProxyEntry {
TcpProxyEntry {
src: Some(self.src.clone().into()),
dst: Some(self.real_dst.clone().into()),
src: Some(self.src.into()),
dst: Some(self.real_dst.into()),
start_time: self.start_time_local.timestamp() as u64,
state: self.state.load().into(),
transport_type: transport_type.into(),
@@ -332,16 +332,14 @@ pub struct TcpProxy<C: NatDstConnector> {
#[async_trait::async_trait]
impl<C: NatDstConnector> PeerPacketFilter for TcpProxy<C> {
async fn try_process_packet_from_peer(&self, mut packet: ZCPacket) -> Option<ZCPacket> {
if let Some(_) = self.try_handle_peer_packet(&mut packet).await {
if self.try_handle_peer_packet(&mut packet).await.is_some() {
if self.is_smoltcp_enabled() {
let smoltcp_stack_sender = self.smoltcp_stack_sender.as_ref().unwrap();
if let Err(e) = smoltcp_stack_sender.try_send(packet) {
tracing::error!("send to smoltcp stack failed: {:?}", e);
}
} else {
if let Err(e) = self.peer_manager.get_nic_channel().send(packet).await {
tracing::error!("send to nic failed: {:?}", e);
}
} else if let Err(e) = self.peer_manager.get_nic_channel().send(packet).await {
tracing::error!("send to nic failed: {:?}", e);
}
return None;
} else {
@@ -610,7 +608,7 @@ impl<C: NatDstConnector> TcpProxy<C> {
self.enable_smoltcp
.store(false, std::sync::atomic::Ordering::Relaxed);
return Ok(ProxyTcpListener::KernelTcpListener(tcp_listener));
Ok(ProxyTcpListener::KernelTcpListener(tcp_listener))
}
}
@@ -917,10 +915,10 @@ impl<C: NatDstConnector> TcpProxy<C> {
let mut entries: Vec<TcpProxyEntry> = Vec::new();
let transport_type = self.connector.transport_type();
for entry in self.syn_map.iter() {
entries.push(entry.value().as_ref().into_pb(transport_type));
entries.push(entry.value().as_ref().parse_as_pb(transport_type));
}
for entry in self.conn_map.iter() {
entries.push(entry.value().as_ref().into_pb(transport_type));
entries.push(entry.value().as_ref().parse_as_pb(transport_type));
}
entries
}

View File

@@ -17,11 +17,17 @@ pub struct ChannelDevice {
caps: DeviceCapabilities,
}
pub type ChannelDeviceNewRet = (
ChannelDevice,
Sender<io::Result<Vec<u8>>>,
Receiver<Vec<u8>>,
);
impl ChannelDevice {
/// Make a new `ChannelDevice` with the given `recv` and `send` channels.
///
/// The `caps` is used to determine the device capabilities. `DeviceCapabilities::max_transmission_unit` must be set.
pub fn new(caps: DeviceCapabilities) -> (Self, Sender<io::Result<Vec<u8>>>, Receiver<Vec<u8>>) {
pub fn new(caps: DeviceCapabilities) -> ChannelDeviceNewRet {
let (tx1, rx1) = channel(1000);
let (tx2, rx2) = channel(1000);
(
@@ -45,7 +51,7 @@ impl Stream for ChannelDevice {
}
fn map_err(e: PollSendError<Vec<u8>>) -> io::Error {
io::Error::new(io::ErrorKind::Other, e)
io::Error::other(e)
}
impl Sink<Vec<u8>> for ChannelDevice {

View File

@@ -46,8 +46,8 @@ impl RxToken for BufferRxToken {
F: FnOnce(&[u8]) -> R,
{
let p = &mut self.0;
let result = f(p);
result
f(p)
}
}
@@ -79,10 +79,9 @@ impl Device for BufferDevice {
Self: 'a;
fn receive(&mut self, _timestamp: Instant) -> Option<(Self::RxToken<'_>, Self::TxToken<'_>)> {
match self.recv_queue.pop_front() {
Some(p) => Some((BufferRxToken(p), BufferTxToken(self))),
None => None,
}
self.recv_queue
.pop_front()
.map(|p| (BufferRxToken(p), BufferTxToken(self)))
}
fn transmit(&mut self, _timestamp: Instant) -> Option<Self::TxToken<'_>> {

View File

@@ -4,7 +4,7 @@
use std::{
io,
net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
net::{IpAddr, SocketAddr},
sync::{
atomic::{AtomicU16, Ordering},
Arc,
@@ -34,7 +34,7 @@ mod socket_allocator;
/// Can be used to create a forever timestamp in neighbor.
// The 60_000 is the same as NeighborCache::ENTRY_LIFETIME.
pub const FOREVER: Instant =
Instant::from_micros_const(i64::max_value() - Duration::from_millis(60_000).micros() as i64);
Instant::from_micros_const(i64::MAX - Duration::from_millis(60_000).micros() as i64);
pub struct Neighbor {
pub protocol_addr: IpAddress,
@@ -173,8 +173,8 @@ impl Net {
fn set_address(&self, mut addr: SocketAddr) -> SocketAddr {
if addr.ip().is_unspecified() {
addr.set_ip(match self.ip_addr.address() {
IpAddress::Ipv4(ip) => Ipv4Addr::from(ip).into(),
IpAddress::Ipv6(ip) => Ipv6Addr::from(ip).into(),
IpAddress::Ipv4(ip) => ip.into(),
IpAddress::Ipv6(ip) => ip.into(),
#[allow(unreachable_patterns)]
_ => panic!("address must not be unspecified"),
});

View File

@@ -51,9 +51,7 @@ async fn run(
loop {
let packets = device.take_send_queue();
async_iface
.send_all(&mut iter(packets).map(|p| Ok(p)))
.await?;
async_iface.send_all(&mut iter(packets).map(Ok)).await?;
if recv_buf.is_empty() && device.need_wait() {
let start = Instant::now();
@@ -94,14 +92,10 @@ async fn run(
// wake up all closed sockets (smoltcp seems have a bug that it doesn't wake up closed sockets)
for (_, socket) in socket_allocator.sockets().lock().iter_mut() {
match socket {
Socket::Tcp(tcp) => {
if tcp.state() == smoltcp::socket::tcp::State::Closed {
tcp.abort();
}
if let Socket::Tcp(tcp) = socket {
if tcp.state() == smoltcp::socket::tcp::State::Closed {
tcp.abort();
}
#[allow(unreachable_patterns)]
_ => {}
}
}
}
@@ -164,10 +158,8 @@ impl Reactor {
impl Drop for Reactor {
fn drop(&mut self) {
for (_, socket) in self.socket_allocator.sockets().lock().iter_mut() {
match socket {
Socket::Tcp(tcp) => tcp.close(),
#[allow(unreachable_patterns)]
_ => {}
if let Socket::Tcp(tcp) = socket {
tcp.close()
}
}
}

Some files were not shown because too many files have changed in this diff Show More