diff --git a/easytier/src/gateway/tcp_proxy.rs b/easytier/src/gateway/tcp_proxy.rs index 57828aec..4630dcf9 100644 --- a/easytier/src/gateway/tcp_proxy.rs +++ b/easytier/src/gateway/tcp_proxy.rs @@ -12,7 +12,7 @@ use std::net::{IpAddr, Ipv4Addr, SocketAddr, SocketAddrV4}; use std::sync::atomic::{AtomicBool, AtomicU16}; use std::sync::{Arc, Weak}; use std::time::{Duration, Instant}; -use tokio::io::{copy_bidirectional, AsyncRead, AsyncWrite}; +use tokio::io::{copy_bidirectional, AsyncRead, AsyncWrite, AsyncWriteExt}; use tokio::net::{TcpListener, TcpSocket, TcpStream}; use tokio::sync::{mpsc, Mutex}; use tokio::task::JoinSet; @@ -158,6 +158,20 @@ impl ProxyTcpStream { } } + pub async fn shutdown(&mut self) -> Result<()> { + match self { + Self::KernelTcpStream(stream) => { + stream.shutdown().await?; + Ok(()) + } + #[cfg(feature = "smoltcp")] + Self::SmolTcpStream(stream) => { + stream.shutdown().await?; + Ok(()) + } + } + } + pub async fn copy_bidirectional( &mut self, dst: &mut D, @@ -692,7 +706,15 @@ impl TcpProxy { let ret = src_tcp_stream.copy_bidirectional(&mut dst_tcp_stream).await; tracing::info!(nat_entry = ?nat_entry_clone, ret = ?ret, "nat tcp connection closed"); nat_entry_clone.state.store(NatDstEntryState::Closed); + + let ret = src_tcp_stream.shutdown().await; + tracing::info!(nat_entry = ?nat_entry_clone, ret = ?ret, "src tcp stream shutdown"); + + let ret = dst_tcp_stream.shutdown().await; + tracing::info!(nat_entry = ?nat_entry_clone, ret = ?ret, "dst tcp stream shutdown"); + drop(src_tcp_stream); + drop(dst_tcp_stream); // sleep later so the fin packet can be processed tokio::time::sleep(Duration::from_secs(10)).await;