Skip to content
This repository has been archived by the owner on Aug 30, 2024. It is now read-only.

Commit

Permalink
flush buffers and add a DISCONNECT command.
Browse files Browse the repository at this point in the history
  • Loading branch information
alessiodam committed Jun 19, 2024
1 parent 7dc1143 commit 49e61b3
Showing 1 changed file with 33 additions and 14 deletions.
47 changes: 33 additions & 14 deletions src/conn_handler.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::sync::{RwLock, Mutex};
use std::sync::Arc;
use tracing::{info, warn, error, debug};
use tracing::{info, warn, error};
use std::collections::HashMap;
use chrono::Utc;

Expand Down Expand Up @@ -30,12 +30,17 @@ pub async fn handle_connection(
break;
}
let message = String::from_utf8_lossy(&buf[..n]).to_string();
debug!(message="Received message", msg=%message);

if message.trim() == "DISCONNECT" {
socket_guard.write_all(b"DISCONNECTED\n").await.unwrap();
socket_guard.flush().await.unwrap();
break;
}
if server_password_correct {
if message.starts_with("AUTH:") {
if authenticated {
let _ = socket_guard.write_all(b"ALREADY_AUTHENTICATED\n").await;
socket_guard.write_all(b"ALREADY_AUTHENTICATED\n").await.unwrap();
socket_guard.flush().await.unwrap();
} else {
let auth_parts: Vec<&str> = message.splitn(3, ':').collect();
if auth_parts.len() == 3 {
Expand All @@ -48,19 +53,22 @@ pub async fn handle_connection(
match verify_session(&config, &username, session_token).await {
Ok(is_valid_session) => {
if !is_valid_session {
let _ = socket_guard.write_all(b"AUTH_FAILED\n").await;
socket_guard.write_all(b"AUTH_FAILED\n").await.unwrap();
socket_guard.flush().await.unwrap();
} else {
authenticated = true;
{
let mut users = active_users.write().await;
users.insert(username.clone(), Arc::clone(&socket));
}
let _ = socket_guard.write_all(b"AUTH_SUCCESS\n").await;
socket_guard.write_all(b"AUTH_SUCCESS\n").await.unwrap();
socket_guard.flush().await.unwrap();
}
},
Err(e) => {
let error_message = format!("AUTH_ERROR:{}\n", e);
let _ = socket_guard.write_all(error_message.as_bytes()).await;
socket_guard.write_all(error_message.as_bytes()).await.unwrap();
socket_guard.flush().await.unwrap();
},
}
} else {
Expand All @@ -70,16 +78,19 @@ pub async fn handle_connection(
let mut users = active_users.write().await;
users.insert(username.clone(), Arc::clone(&socket));
}
let _ = socket_guard.write_all(b"AUTH_SUCCESS\n").await;
socket_guard.write_all(b"AUTH_SUCCESS\n").await.unwrap();
socket_guard.flush().await.unwrap();
}
} else {
warn!(target: "auth", "Invalid AUTH message");
let _ = socket_guard.write_all(b"AUTH_INVALID\n").await;
socket_guard.write_all(b"AUTH_INVALID\n").await.unwrap();
socket_guard.flush().await.unwrap();
}
}
} else if authenticated {
if message.len() > 256 {
let _ = socket_guard.write_all(b"MESSAGE_TOO_LONG\n").await;
socket_guard.write_all(b"MESSAGE_TOO_LONG\n").await.unwrap();
socket_guard.flush().await.unwrap();
continue;
}
if let Some((recipient, message)) = message.split_once(':') {
Expand All @@ -91,22 +102,27 @@ pub async fn handle_connection(
send_direct_message(&active_users, recipient, &full_message).await;
}
} else {
let _ = socket_guard.write_all(b"INVALID_MESSAGE_FORMAT\n").await;
socket_guard.write_all(b"INVALID_MESSAGE_FORMAT\n").await.unwrap();
socket_guard.flush().await.unwrap();
}
} else {
let _ = socket_guard.write_all(b"NOT_AUTHENTICATED\n").await;
socket_guard.write_all(b"NOT_AUTHENTICATED\n").await.unwrap();
socket_guard.flush().await.unwrap();
}
} else if !server_password_correct && config.server.protect_server {
if message.starts_with("SERVER_PASS:") {
let server_password = message.trim_start_matches("SERVER_PASS:").trim();
if server_password == config.server.server_password {
server_password_correct = true;
let _ = socket_guard.write_all(b"SERVER_PASS_CORRECT\n").await;
socket_guard.write_all(b"SERVER_PASS_CORRECT\n").await.unwrap();
socket_guard.flush().await.unwrap();
} else {
let _ = socket_guard.write_all(b"SERVER_PASS_INCORRECT\n").await;
socket_guard.write_all(b"SERVER_PASS_INCORRECT\n").await.unwrap();
socket_guard.flush().await.unwrap();
}
} else {
let _ = socket_guard.write_all(b"SERVER_PASS_REQUIRED\n").await;
socket_guard.write_all(b"SERVER_PASS_REQUIRED\n").await.unwrap();
socket_guard.flush().await.unwrap();
}
}
}
Expand Down Expand Up @@ -150,6 +166,7 @@ async fn broadcast_message(
if let Err(e) = client.write_all(message.as_bytes()).await {
error!(target: "server", "Failed to send message: {}", e);
} else {
let _ = client.flush().await;
info!(target: "server", "Broadcasted message: {}", message);
}
});
Expand All @@ -161,11 +178,13 @@ async fn send_direct_message(active_users: &ActiveUsers, target: &str, message:
if let Some(client) = active_users.get(target) {
let client = client.clone();
let message = message.to_string();
info!(target: "server", "Sending direct message: {}", message);
tokio::spawn(async move {
let mut client = client.lock().await;
if let Err(e) = client.write_all(message.as_bytes()).await {
error!(target: "server", "Failed to send direct message: {}", e);
} else {
let _ = client.flush().await;
info!(target: "server", "Sent direct message: {}", message);
}
});
Expand Down

0 comments on commit 49e61b3

Please sign in to comment.