Skip to content

Commit

Permalink
Gracefully shutdown connector with task manager (#1655) (#1660)
Browse files Browse the repository at this point in the history
  • Loading branch information
pronebird authored Nov 26, 2024
1 parent 776da76 commit d3bafc5
Show file tree
Hide file tree
Showing 4 changed files with 187 additions and 36 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ impl TunnelStateHandler for ConnectingState {
#[cfg(any(target_os = "linux", target_os = "macos", target_os = "windows"))]
shared_state.route_handler.remove_routes().await;

NextTunnelState::NewState(ConnectingState::enter( self.retry_attempt.saturating_add(1), self.selected_gateways, shared_state))
NextTunnelState::NewState(ConnectingState::enter(self.retry_attempt.saturating_add(1), self.selected_gateways, shared_state))
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@ use nym_task::TaskManager;
use super::connected_tunnel::ConnectedTunnel;
use crate::{
mixnet::SharedMixnetClient,
tunnel_state_machine::tunnel::{gateway_selector::SelectedGateways, Error, Result},
tunnel_state_machine::tunnel::{
self, gateway_selector::SelectedGateways, AnyConnector, ConnectorError, Error, Result,
},
};

/// Struct holding addresses assigned by mixnet upon connect.
Expand Down Expand Up @@ -47,11 +49,41 @@ impl Connector {
self,
selected_gateways: SelectedGateways,
nym_ips: Option<IpPair>,
) -> Result<ConnectedTunnel> {
let mixnet_client_address = self.mixnet_client.nym_address().await;
) -> Result<ConnectedTunnel, ConnectorError> {
let result = Self::connect_inner(
selected_gateways,
nym_ips,
self.mixnet_client.clone(),
&self.gateway_directory_client,
)
.await;

match result {
Ok(assigned_addresses) => Ok(ConnectedTunnel::new(
self.task_manager,
self.mixnet_client,
assigned_addresses,
)),
Err(e) => Err(ConnectorError::new(
e,
AnyConnector::Mixnet(Self::new(
self.task_manager,
self.mixnet_client,
self.gateway_directory_client,
)),
)),
}
}

async fn connect_inner(
selected_gateways: SelectedGateways,
nym_ips: Option<IpPair>,
mixnet_client: SharedMixnetClient,
gateway_directory_client: &GatewayClient,
) -> Result<AssignedAddresses> {
let mixnet_client_address = mixnet_client.nym_address().await;
let gateway_used = mixnet_client_address.gateway().to_base58_string();
let entry_mixnet_gateway_ip: IpAddr = self
.gateway_directory_client
let entry_mixnet_gateway_ip: IpAddr = gateway_directory_client
.lookup_gateway_ip(&gateway_used)
.await
.map_err(|source| Error::LookupGatewayIp {
Expand All @@ -61,30 +93,30 @@ impl Connector {

let exit_mix_addresses = selected_gateways.exit.ipr_address.unwrap();

let mut ipr_client = IprClientConnect::new_from_inner(self.mixnet_client.inner()).await;
let mut ipr_client = IprClientConnect::new_from_inner(mixnet_client.inner()).await;
let interface_addresses = ipr_client
.connect(exit_mix_addresses.0, nym_ips)
.await
.map_err(Error::ConnectToIpPacketRouter)?;

let assigned_addresses = AssignedAddresses {
entry_mixnet_gateway_ip,
mixnet_client_address,
exit_mix_addresses,
interface_addresses,
};
if let Some(exit_country_code) = selected_gateways.exit.two_letter_iso_country_code() {
self.mixnet_client
mixnet_client
.send_stats_event(
ConnectionStatsEvent::MixCountry(exit_country_code.to_string()).into(),
)
.await;
}

Ok(ConnectedTunnel::new(
self.task_manager,
self.mixnet_client,
assigned_addresses,
))
Ok(AssignedAddresses {
entry_mixnet_gateway_ip,
mixnet_client_address,
exit_mix_addresses,
interface_addresses,
})
}

/// Gracefully shutdown task manager and consume the struct.
pub async fn dispose(self) {
tunnel::shutdown_task_manager(self.task_manager).await;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ pub mod mixnet;
mod status_listener;
pub mod wireguard;

use std::{path::PathBuf, time::Duration};
use std::{error::Error as StdError, fmt, path::PathBuf, time::Duration};

pub use gateway_selector::SelectedGateways;
use nym_gateway_directory::{EntryPoint, ExitPoint, GatewayClient};
Expand Down Expand Up @@ -60,9 +60,17 @@ impl ConnectedMixnet {
self.mixnet_client,
self.gateway_directory_client,
);
connector

match connector
.connect(self.selected_gateways, interface_addresses)
.await
{
Ok(connected_tunnel) => Ok(connected_tunnel),
Err(connector_error) => {
connector_error.connector.dispose().await;
Err(connector_error.error)
}
}
}

/// Creates a tunnel over WireGuard.
Expand All @@ -75,13 +83,21 @@ impl ConnectedMixnet {
self.mixnet_client,
self.gateway_directory_client,
);
connector

match connector
.connect(
enable_credentials_mode,
self.selected_gateways,
self.data_path,
)
.await
{
Ok(connected_tunnel) => Ok(connected_tunnel),
Err(connector_error) => {
connector_error.connector.dispose().await;
Err(connector_error.error)
}
}
}

/// Gracefully shutdown the mixnet client and consume the struct.
Expand Down Expand Up @@ -248,3 +264,53 @@ pub enum Error {
}

pub type Result<T, E = Error> = std::result::Result<T, E>;

/// Tunnel connector container.
pub enum AnyConnector {
Mixnet(mixnet::connector::Connector),
Wireguard(wireguard::connector::Connector),
}

impl AnyConnector {
pub async fn dispose(self) {
match self {
Self::Mixnet(connector) => connector.dispose().await,
Self::Wireguard(connector) => connector.dispose().await,
}
}
}

/// Error returned when connector is unable to connect the tunnel.
pub struct ConnectorError {
/// The error returned during the attempt to connect the tunnel.
pub error: Error,

/// The source connector.
pub connector: AnyConnector,
}

impl ConnectorError {
fn new(error: Error, connector: AnyConnector) -> Self {
Self { error, connector }
}
}

impl StdError for ConnectorError {
fn source(&self) -> Option<&(dyn StdError + 'static)> {
Some(&self.error)
}
}

impl fmt::Debug for ConnectorError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("ConnectorError")
.field("error", &self.error)
.finish_non_exhaustive()
}
}

impl fmt::Display for ConnectorError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
self.error.fmt(f)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

use std::path::PathBuf;

use tokio::task::JoinHandle;

use nym_authenticator_client::AuthClient;
use nym_credentials_interface::TicketType;
use nym_gateway_directory::{AuthAddresses, Gateway, GatewayClient};
Expand All @@ -14,7 +16,9 @@ use super::connected_tunnel::ConnectedTunnel;
use crate::{
bandwidth_controller::BandwidthController,
mixnet::SharedMixnetClient,
tunnel_state_machine::tunnel::{gateway_selector::SelectedGateways, Error, Result},
tunnel_state_machine::tunnel::{
self, gateway_selector::SelectedGateways, AnyConnector, ConnectorError, Error, Result,
},
};

pub struct ConnectionData {
Expand All @@ -40,13 +44,49 @@ impl Connector {
gateway_directory_client,
}
}

pub async fn connect(
self,
enable_credentials_mode: bool,
selected_gateways: SelectedGateways,
data_path: Option<PathBuf>,
) -> Result<ConnectedTunnel> {
) -> Result<ConnectedTunnel, ConnectorError> {
let result = Self::connect_inner(
&self.task_manager,
self.mixnet_client.clone(),
&self.gateway_directory_client,
enable_credentials_mode,
selected_gateways,
data_path,
)
.await;

match result {
Ok(connect_result) => Ok(ConnectedTunnel::new(
self.task_manager,
connect_result.entry_gateway_client,
connect_result.exit_gateway_client,
connect_result.connection_data,
connect_result.bandwidth_controller_handle,
)),
Err(e) => Err(ConnectorError::new(
e,
AnyConnector::Wireguard(Self::new(
self.task_manager,
self.mixnet_client,
self.gateway_directory_client,
)),
)),
}
}

async fn connect_inner(
task_manager: &TaskManager,
mixnet_client: SharedMixnetClient,
gateway_directory_client: &GatewayClient,
enable_credentials_mode: bool,
selected_gateways: SelectedGateways,
data_path: Option<PathBuf>,
) -> Result<ConnectResult> {
let auth_addresses =
Self::setup_auth_addresses(&selected_gateways.entry, &selected_gateways.exit)?;
let (Some(entry_auth_recipient), Some(exit_auth_recipient)) =
Expand All @@ -56,7 +96,7 @@ impl Connector {
};
let entry_version = selected_gateways.entry.version.clone().into();
let exit_version = selected_gateways.exit.version.clone().into();
let auth_client = AuthClient::new_from_inner(self.mixnet_client.inner()).await;
let auth_client = AuthClient::new_from_inner(mixnet_client.inner()).await;

let mut wg_entry_gateway_client = if enable_credentials_mode {
WgGatewayClient::new_free_entry(
Expand Down Expand Up @@ -89,7 +129,7 @@ impl Connector {
)
};

let shutdown = self.task_manager.subscribe_named("bandwidth controller");
let shutdown = task_manager.subscribe_named("bandwidth controller");
let (connection_data, bandwidth_controller_handle) = if let Some(data_path) =
data_path.as_ref()
{
Expand All @@ -108,15 +148,15 @@ impl Connector {
.get_initial_bandwidth(
enable_credentials_mode,
TicketType::V1WireguardEntry,
&self.gateway_directory_client,
gateway_directory_client,
&mut wg_entry_gateway_client,
)
.await?;
let exit = bw
.get_initial_bandwidth(
enable_credentials_mode,
TicketType::V1WireguardExit,
&self.gateway_directory_client,
gateway_directory_client,
&mut wg_exit_gateway_client,
)
.await?;
Expand All @@ -136,15 +176,15 @@ impl Connector {
.get_initial_bandwidth(
enable_credentials_mode,
TicketType::V1WireguardEntry,
&self.gateway_directory_client,
gateway_directory_client,
&mut wg_entry_gateway_client,
)
.await?;
let exit = bw
.get_initial_bandwidth(
enable_credentials_mode,
TicketType::V1WireguardExit,
&self.gateway_directory_client,
gateway_directory_client,
&mut wg_exit_gateway_client,
)
.await?;
Expand All @@ -153,18 +193,19 @@ impl Connector {

(ConnectionData { entry, exit }, bandwidth_controller_handle)
};

if let Some(exit_country_code) = selected_gateways.exit.two_letter_iso_country_code() {
auth_client.send_stats_event(
ConnectionStatsEvent::WgCountry(exit_country_code.to_string()).into(),
);
}
Ok(ConnectedTunnel::new(
self.task_manager,
wg_entry_gateway_client,
wg_exit_gateway_client,

Ok(ConnectResult {
entry_gateway_client: wg_entry_gateway_client,
exit_gateway_client: wg_exit_gateway_client,
connection_data,
bandwidth_controller_handle,
))
})
}

fn setup_auth_addresses(entry: &Gateway, exit: &Gateway) -> Result<AuthAddresses> {
Expand All @@ -179,4 +220,16 @@ impl Connector {
exit_authenticator_address,
))
}

/// Gracefully shutdown task manager and consume the struct.
pub async fn dispose(self) {
tunnel::shutdown_task_manager(self.task_manager).await;
}
}

struct ConnectResult {
entry_gateway_client: WgGatewayClient,
exit_gateway_client: WgGatewayClient,
connection_data: ConnectionData,
bandwidth_controller_handle: JoinHandle<()>,
}

0 comments on commit d3bafc5

Please sign in to comment.