Skip to content

Commit

Permalink
refine quit logic
Browse files Browse the repository at this point in the history
  • Loading branch information
ssrlive committed Feb 15, 2024
1 parent 142abb7 commit f6f5edb
Show file tree
Hide file tree
Showing 9 changed files with 117 additions and 131 deletions.
5 changes: 3 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ async-shared-timeout = "0.2"
base64 = "0.21"
bytes = "1.5"
chrono = "0.4"
clap = { version = "4.4", features = ["derive"] }
clap = { version = "4.5", features = ["derive"] }
ctrlc2 = { version = "3.5", features = ["tokio", "termination"] }
dotenvy = "0.15"
env_logger = "0.11"
Expand All @@ -35,9 +35,10 @@ serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
socks5-impl = "0.5"
thiserror = "1.0"
tokio = { version = "1.35", features = ["full"] }
tokio = { version = "1.36", features = ["full"] }
tokio-rustls = "0.25"
tokio-tungstenite = { version = "0.21", features = ["rustls-tls-webpki-roots"] }
tokio-util = "0.7"
trust-dns-proto = "0.23"
tungstenite = { version = "0.21", features = ["rustls-tls-webpki-roots"] }
url = "2.5"
Expand Down
33 changes: 17 additions & 16 deletions src/android.rs
Original file line number Diff line number Diff line change
Expand Up @@ -142,10 +142,7 @@ pub mod native {
}
}

lazy_static::lazy_static! {
pub static ref EXITING_FLAG: Arc<AtomicBool> = Arc::new(AtomicBool::new(false));
pub static ref LISTEN_ADDR: Arc<Mutex<SocketAddr>> = Arc::new(Mutex::new(SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8080)));
}
static EXITING_FLAG: std::sync::Mutex<Option<crate::CancellationToken>> = std::sync::Mutex::new(None);

/// # Safety
///
Expand All @@ -159,6 +156,16 @@ pub mod native {
stat_path: JString,
verbosity: jint,
) -> jint {
let shutdown_token = crate::CancellationToken::new();
{
let mut lock = EXITING_FLAG.lock().unwrap();
if lock.is_some() {
log::error!("tun2proxy already started");
return -1;
}
*lock = Some(shutdown_token.clone());
}

let mut env = env;

let log_level = ArgVerbosity::try_from(verbosity).unwrap().to_string();
Expand Down Expand Up @@ -194,8 +201,7 @@ pub mod native {

let rt = tokio::runtime::Builder::new_multi_thread().enable_all().build()?;
rt.block_on(async {
EXITING_FLAG.store(false, Ordering::SeqCst);
crate::client::run_client(&config, Some(EXITING_FLAG.clone()), Some(callback)).await?;
crate::client::run_client(&config, shutdown_token, Some(callback)).await?;
Ok::<(), Error>(())
})
};
Expand All @@ -218,16 +224,11 @@ pub mod native {
pub unsafe extern "C" fn Java_com_github_shadowsocks_bg_OverTlsWrapper_stopClient(_: JNIEnv, _: JClass) -> jint {
stop_protect_socket();

EXITING_FLAG.store(true, Ordering::SeqCst);

let l_addr = *LISTEN_ADDR.lock().unwrap();
let addr = if l_addr.is_ipv6() {
SocketAddr::from((Ipv6Addr::LOCALHOST, l_addr.port()))
} else {
SocketAddr::from((Ipv4Addr::LOCALHOST, l_addr.port()))
};
let _ = std::net::TcpStream::connect(addr);
log::trace!("stopClient on listen address {l_addr}");
if let Ok(mut token) = EXITING_FLAG.lock() {
if let Some(token) = token.take() {
token.cancel();
}
}

SocketProtector::release();
Jni::release();
Expand Down
41 changes: 18 additions & 23 deletions src/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,8 @@ use crate::{
ArgVerbosity,
};
use std::{
net::{Ipv4Addr, Ipv6Addr, SocketAddr},
net::SocketAddr,
os::raw::{c_char, c_int, c_void},
sync::{
atomic::{AtomicBool, Ordering},
Arc, Mutex,
},
};

#[derive(Clone)]
Expand All @@ -27,10 +23,7 @@ impl CCallback {
unsafe impl Send for CCallback {}
unsafe impl Sync for CCallback {}

lazy_static::lazy_static! {
static ref EXITING_FLAG: Arc<AtomicBool> = Arc::new(AtomicBool::new(false));
static ref LISTEN_ADDR: Arc<Mutex<SocketAddr>> = Arc::new(Mutex::new(SocketAddr::from((Ipv4Addr::LOCALHOST, 0))));
}
static EXITING_FLAG: std::sync::Mutex<Option<crate::CancellationToken>> = std::sync::Mutex::new(None);

/// # Safety
///
Expand All @@ -53,6 +46,16 @@ unsafe fn _over_tls_client_run(
callback: Option<unsafe extern "C" fn(c_int, *mut c_void)>,
ctx: *mut c_void,
) -> c_int {
let shutdown_token = crate::CancellationToken::new();
{
let mut lock = EXITING_FLAG.lock().unwrap();
if lock.is_some() {
log::error!("tun2proxy already started");
return -1;
}
*lock = Some(shutdown_token.clone());
}

let ccb = CCallback(callback, ctx);

let block = || -> Result<()> {
Expand All @@ -61,8 +64,6 @@ unsafe fn _over_tls_client_run(
let cb = |addr: SocketAddr| {
log::trace!("Listening on {}", addr);
let port = addr.port();
let addr = SocketAddr::from((Ipv4Addr::LOCALHOST, port));
*LISTEN_ADDR.lock().unwrap() = addr;
unsafe {
ccb.call(port as c_int);
}
Expand All @@ -72,8 +73,7 @@ unsafe fn _over_tls_client_run(
config.check_correctness(false)?;
let rt = tokio::runtime::Builder::new_multi_thread().enable_all().build()?;
rt.block_on(async {
EXITING_FLAG.store(false, Ordering::SeqCst);
crate::client::run_client(&config, Some(EXITING_FLAG.clone()), Some(cb)).await?;
crate::client::run_client(&config, shutdown_token, Some(cb)).await?;
Ok::<(), Error>(())
})
};
Expand All @@ -89,15 +89,10 @@ unsafe fn _over_tls_client_run(
/// Shutdown the client.
#[no_mangle]
pub unsafe extern "C" fn over_tls_client_stop() -> c_int {
EXITING_FLAG.store(true, Ordering::SeqCst);

let l_addr = *LISTEN_ADDR.lock().unwrap();
let addr = if l_addr.is_ipv6() {
SocketAddr::from((Ipv6Addr::LOCALHOST, l_addr.port()))
} else {
SocketAddr::from((Ipv4Addr::LOCALHOST, l_addr.port()))
};
let _ = std::net::TcpStream::connect(addr);
log::trace!("Client stop on listen address {}", l_addr);
if let Ok(mut token) = EXITING_FLAG.lock() {
if let Some(token) = token.take() {
token.cancel();
}
}
0
}
25 changes: 6 additions & 19 deletions src/bin/overtls.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,4 @@
use overtls::{client, config, server, CmdOpt, Error, Result};
use std::{
net::{Ipv4Addr, Ipv6Addr, SocketAddr},
sync::{atomic::AtomicBool, Arc},
};

fn main() -> Result<()> {
let opt = CmdOpt::parse_cmd();
Expand Down Expand Up @@ -43,40 +39,31 @@ fn main() -> Result<()> {
}

async fn async_main(config: config::Config) -> Result<()> {
let exiting_flag = Arc::new(AtomicBool::new(false));
let exiting_flag_clone = exiting_flag.clone();
let shutdown_token = overtls::CancellationToken::new();
let shutdown_token_clone = shutdown_token.clone();

let main_body = async {
if config.is_server {
if config.exist_server() {
server::run_server(&config, Some(exiting_flag_clone)).await?;
server::run_server(&config, shutdown_token_clone).await?;
} else {
return Err(Error::from("Config is not a server config"));
}
} else if config.exist_client() {
let callback = |addr| {
log::trace!("Listening on {}", addr);
};
client::run_client(&config, Some(exiting_flag_clone), Some(callback)).await?;
client::run_client(&config, shutdown_token_clone, Some(callback)).await?;
} else {
return Err("Config is not a client config".into());
}

Ok(())
};

let local_addr = config.listen_addr()?;

ctrlc2::set_async_handler(async move {
exiting_flag.store(true, std::sync::atomic::Ordering::Relaxed);

let addr = if local_addr.is_ipv6() {
SocketAddr::from((Ipv6Addr::LOCALHOST, local_addr.port()))
} else {
SocketAddr::from((Ipv4Addr::LOCALHOST, local_addr.port()))
};
let _ = std::net::TcpStream::connect(addr);
log::info!("");
log::info!("Ctrl-C received, exiting...");
shutdown_token.cancel();
})
.await;

Expand Down
43 changes: 20 additions & 23 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,7 @@ use socks5_impl::{
AuthAdaptor, ClientConnection, Connect, IncomingConnection, Server,
},
};
use std::{
net::SocketAddr,
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
};
use std::{net::SocketAddr, sync::Arc};
use tokio::{
io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt},
net::TcpStream,
Expand All @@ -38,7 +32,7 @@ use tungstenite::{
protocol::{Message, Role},
};

pub async fn run_client<F>(config: &Config, exiting_flag: Option<Arc<AtomicBool>>, callback: Option<F>) -> Result<()>
pub async fn run_client<F>(config: &Config, quit: crate::CancellationToken, callback: Option<F>) -> Result<()>
where
F: FnOnce(SocketAddr) + Send + Sync + 'static,
{
Expand All @@ -52,14 +46,14 @@ where
if let Some(user) = listen_user {
let listen_password = client.listen_password.as_deref().unwrap_or("");
let key = UserKeyAuth::new(user, listen_password);
_run_client(config, Arc::new(key), exiting_flag, callback).await?;
_run_client(config, Arc::new(key), quit, callback).await?;
} else {
_run_client(config, Arc::new(NoAuth), exiting_flag, callback).await?;
_run_client(config, Arc::new(NoAuth), quit, callback).await?;
}
Ok(())
}

async fn _run_client<F, O>(config: &Config, auth: AuthAdaptor<O>, exiting_flag: Option<Arc<AtomicBool>>, callback: Option<F>) -> Result<()>
async fn _run_client<F, O>(config: &Config, auth: AuthAdaptor<O>, quit: crate::CancellationToken, callback: Option<F>) -> Result<()>
where
F: FnOnce(SocketAddr) + Send + Sync + 'static,
O: Send + Sync + 'static,
Expand All @@ -74,23 +68,26 @@ where
}

let (udp_tx, _, incomings) = udprelay::create_udp_tunnel();
udprelay::udp_handler_watchdog(config, &incomings, &udp_tx, exiting_flag.clone()).await?;
udprelay::udp_handler_watchdog(config, &incomings, &udp_tx, quit.clone()).await?;

while let Ok((conn, _)) = server.accept().await {
if let Some(exiting_flag) = &exiting_flag {
if exiting_flag.load(Ordering::Relaxed) {
loop {
tokio::select! {
_ = quit.cancelled() => {
log::info!("exiting...");
break;
}
}
let config = config.clone();
let udp_tx = udp_tx.clone();
let incomings = incomings.clone();
tokio::spawn(async move {
if let Err(e) = handle_incoming(conn, config, Some(udp_tx), incomings).await {
log::debug!("{}", e);
result = server.accept() => {
let (conn, _) = result?;
let config = config.clone();
let udp_tx = udp_tx.clone();
let incomings = incomings.clone();
tokio::spawn(async move {
if let Err(e) = handle_incoming(conn, config, Some(udp_tx), incomings).await {
log::debug!("{}", e);
}
});
}
});
}
}

Ok(())
Expand Down
2 changes: 1 addition & 1 deletion src/cmdopt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ impl std::fmt::Display for ArgVerbosity {
}

/// Proxy tunnel over tls
#[derive(clap::Parser, Debug, Clone, PartialEq, Eq)]
#[derive(clap::Parser, Debug, Clone, PartialEq, Eq, Default)]
#[command(author, version, about = "Proxy tunnel over tls.", long_about = None)]
pub struct CmdOpt {
/// Role of server or client
Expand Down
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ use bytes::BytesMut;
pub use cmdopt::{ArgVerbosity, CmdOpt, Role};
pub use error::{Error, Result};
use socks5_impl::protocol::{Address, StreamOperation};
pub use tokio_util::sync::CancellationToken;

#[cfg(target_os = "windows")]
pub(crate) const STREAM_BUFFER_SIZE: usize = 1024 * 32;
Expand Down
52 changes: 26 additions & 26 deletions src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,7 @@ use socks5_impl::protocol::{Address, StreamOperation};
use std::{
collections::HashMap,
net::{Ipv4Addr, Ipv6Addr, SocketAddr, ToSocketAddrs},
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
sync::Arc,
};
use tokio::{
io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt},
Expand All @@ -33,7 +30,7 @@ use tungstenite::{
const WS_HANDSHAKE_LEN: usize = 1024;
const WS_MSG_HEADER_LEN: usize = 14;

pub async fn run_server(config: &Config, exiting_flag: Option<Arc<AtomicBool>>) -> Result<()> {
pub async fn run_server(config: &Config, exiting_flag: crate::CancellationToken) -> Result<()> {
log::info!("starting {} server...", env!("CARGO_PKG_NAME"));
log::trace!("with following settings:");
log::trace!("{}", serde_json::to_string_pretty(config)?);
Expand Down Expand Up @@ -86,33 +83,36 @@ pub async fn run_server(config: &Config, exiting_flag: Option<Arc<AtomicBool>>)
let listener = TcpListener::bind(&addr).await?;

loop {
let (stream, peer_addr) = listener.accept().await?;
if let Some(exiting_flag) = &exiting_flag {
if exiting_flag.load(Ordering::Relaxed) {
tokio::select! {
_ = exiting_flag.cancelled() => {
log::info!("exiting...");
break;
}
}
let acceptor = acceptor.clone();
let config = config.clone();
let traffic_audit = traffic_audit.clone();

let incoming_task = async move {
if let Some(acceptor) = acceptor {
let stream = acceptor.accept(stream).await?;
handle_incoming(stream, peer_addr, config, traffic_audit).await?;
} else {
handle_incoming(stream, peer_addr, config, traffic_audit).await?;
}
Ok::<_, Error>(())
};
ret = listener.accept() => {
let (stream, peer_addr) = ret?;
let acceptor = acceptor.clone();
let config = config.clone();
let traffic_audit = traffic_audit.clone();

let incoming_task = async move {
if let Some(acceptor) = acceptor {
let stream = acceptor.accept(stream).await?;
handle_incoming(stream, peer_addr, config, traffic_audit).await?;
} else {
handle_incoming(stream, peer_addr, config, traffic_audit).await?;
}
Ok::<_, Error>(())
};

tokio::spawn(async move {
if let Err(e) = incoming_task.await {
log::debug!("{peer_addr}: {e}");
tokio::spawn(async move {
if let Err(e) = incoming_task.await {
log::debug!("{peer_addr}: {e}");
}
});
}
});
}
}

Ok(())
}

Expand Down
Loading

0 comments on commit f6f5edb

Please sign in to comment.