Skip to content

Commit

Permalink
feat: add graceful shutdown functionality to the service
Browse files Browse the repository at this point in the history
  • Loading branch information
cfanbo committed Aug 29, 2024
1 parent e22d92b commit 455d7fc
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 28 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ tokio = { version = "1.39.3", features = [
"macros",
"rt-multi-thread",
"tracing",
"signal"
] }
tonic = "0.12.1"
prost = "0.13.1"
Expand Down
82 changes: 54 additions & 28 deletions src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,19 @@ use redis_protocol::resp2::{
types::{OwnedFrame, Resp2Frame},
};
use std::net::SocketAddr;
use std::net::TcpListener;
// use std::net::TcpListener;
use std::path::PathBuf;
use std::sync::mpsc;
use std::sync::Arc;
use std::{
io::{Read, Write},
// io::{Read, Write},
sync::RwLock,
};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpListener;
use tokio::net::TcpStream;
use tokio::signal;
use tokio::sync::Notify;
use tokio::task::JoinSet;

pub struct Server {
Expand All @@ -30,12 +35,12 @@ impl Server {
Server { config, store }
}

async fn handle_client_connection(&self, mut stream: std::net::TcpStream) {
async fn handle_client_connection(&self, mut stream: TcpStream) {
let mut buffer = [0u8; 4096];

loop {
// 读取客户端发送的数据
let n = match stream.read(&mut buffer) {
let n = match stream.read(&mut buffer).await {
Ok(size) => size,
Err(e) => {
error!("Failed to read from stream: {}", e);
Expand All @@ -53,11 +58,11 @@ impl Server {
Ok(resp) => {
let mut buf = vec![0; resp.encode_len()];
encode(&mut buf, &resp).unwrap();
stream.write_all(&buf).unwrap();
stream.write_all(&buf).await.unwrap();
}
Err(e) => {
let error_message = format!("-ERR {}\r\n", e);
stream.write_all(error_message.as_bytes()).unwrap();
stream.write_all(error_message.as_bytes()).await.unwrap();
}
},
Ok(None) => {
Expand Down Expand Up @@ -851,33 +856,41 @@ impl Server {
}
}

pub fn server_start(self: Arc<Self>) -> anyhow::Result<()> {
pub async fn server_start(self: Arc<Self>, notify: Arc<Notify>) -> anyhow::Result<()> {
let addr = self.config.get_addr()?;
let listener = TcpListener::bind(addr)?;

let listener = TcpListener::bind(addr).await?;
println!("Listening on {}", addr);

for stream in listener.incoming() {
match stream {
Ok(stream) => {
debug!("New connection: {}", stream.peer_addr().unwrap());
let server_clone = Arc::clone(&self);

tokio::spawn(async move {
server_clone.handle_client_connection(stream).await;
});
}
Err(e) => {
error!("Error: {}", e);
loop {
tokio::select! {
accept_result = listener.accept() => {
match accept_result {
Ok((stream, _)) => {
debug!("New connection: {}", stream.peer_addr().unwrap());
let server_clone = Arc::clone(&self);
tokio::spawn(async move {
server_clone.handle_client_connection(stream).await;
});
},
Err(e) => {
error!("Failed to accept connection: {}", e);
}
}
},
_ = notify.notified() => {
println!("Shutdown signal received. Stopping server...");
break;
}
}
}

Ok(())
}
}

pub async fn start_server(option: &Option<PathBuf>) -> anyhow::Result<()> {
// 创建一个 Notify 对象,用于通知所有任务停止
let notify = Arc::new(Notify::new());

let conf = if let Some(file) = option {
config::Config::try_from(file.as_path())?
} else {
Expand All @@ -903,22 +916,31 @@ pub async fn start_server(option: &Option<PathBuf>) -> anyhow::Result<()> {
// server.server_start();
// });

join_set.spawn(async {
let notify_clone = Arc::clone(&notify);
join_set.spawn(async move {
// 使用 block_in_place 处理阻塞操作
tokio::task::block_in_place(|| {
server.server_start().unwrap();
});
server.server_start(notify_clone).await.unwrap();
});

// 添加 gRPC 服务器任务(如果配置存在)
if let Some(grpc_config) = the_config.get_grpc() {
let store_clone = Arc::clone(&store);
let addr = grpc_config.get_addr()?;
let notify_clone = Arc::clone(&notify);
join_set.spawn(async move {
run_grpc_server(addr, store_clone).await.unwrap();
run_grpc_server(addr, store_clone, notify_clone)
.await
.unwrap();
});
}

// 监听 Ctrl+C 信号
signal::ctrl_c().await.expect("Failed to listen for Ctrl+C");
info!("Received Ctrl+C, shutting down...");

// 通知所有任务停止
notify.notify_waiters();

// 等待所有任务完成
while let Some(Ok(_)) = join_set.join_next().await {}

Expand All @@ -929,12 +951,16 @@ pub async fn start_server(option: &Option<PathBuf>) -> anyhow::Result<()> {
async fn run_grpc_server(
addr: SocketAddr,
store: Arc<RwLock<dyn db_store::Op>>,
notify: Arc<Notify>,
) -> anyhow::Result<()> {
println!("gRPC Server Listening on {:?}", addr);

tonic::transport::Server::builder()
.add_service(StoreServer::new(StoreImpl::new(store)))
.serve(addr)
.serve_with_shutdown(addr, async {
notify.notified().await;
})
// .serve(addr)
.await?;

Ok(())
Expand Down

0 comments on commit 455d7fc

Please sign in to comment.