Skip to content

Commit

Permalink
Allow to limit the number of concurrent requests made by the sdk
Browse files Browse the repository at this point in the history
  • Loading branch information
gnunicorn committed Jun 28, 2024
1 parent 6464d21 commit db3d4d7
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 3 deletions.
24 changes: 22 additions & 2 deletions crates/matrix-sdk/src/config/request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,18 +44,21 @@ pub struct RequestConfig {
pub(crate) timeout: Duration,
pub(crate) retry_limit: Option<u64>,
pub(crate) retry_timeout: Option<Duration>,
pub(crate) max_concurrent_requests: usize,
pub(crate) force_auth: bool,
}

#[cfg(not(tarpaulin_include))]
impl Debug for RequestConfig {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
let Self { timeout, retry_limit, retry_timeout, force_auth } = self;
let Self { timeout, retry_limit, retry_timeout, force_auth, max_concurrent_requests } =
self;

let mut res = fmt.debug_struct("RequestConfig");
res.field("timeout", timeout)
.maybe_field("retry_limit", retry_limit)
.maybe_field("retry_timeout", retry_timeout);
.maybe_field("retry_timeout", retry_timeout)
.field("max_concurrent_requests", max_concurrent_requests);

if *force_auth {
res.field("force_auth", &true);
Expand All @@ -71,6 +74,7 @@ impl Default for RequestConfig {
timeout: DEFAULT_REQUEST_TIMEOUT,
retry_limit: Default::default(),
retry_timeout: Default::default(),
max_concurrent_requests: 0,
force_auth: false,
}
}
Expand Down Expand Up @@ -106,6 +110,22 @@ impl RequestConfig {
self
}

/// The total limit of request that are pending or run concurrently.
/// Any additional request beyond that number will be waiting until another
/// concurrent requests finished. Requests are queued fairly.
#[must_use]
pub fn max_concurrent_requests(mut self, limit: usize) -> Self {
self.max_concurrent_requests = limit;
self
}

/// Disable the limit of concurrent requests. Setting the limit to 0
/// has the same effect.
#[must_use]
pub fn disable_max_concurrent_requests(mut self) -> Self {
self.max_concurrent_requests = 0;
self
}
/// Set the timeout duration for all HTTP requests.
#[must_use]
pub fn timeout(mut self, timeout: Duration) -> Self {
Expand Down
35 changes: 34 additions & 1 deletion crates/matrix-sdk/src/http_client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ use ruma::api::{
error::{FromHttpResponseError, IntoHttpError},
AuthScheme, MatrixVersion, OutgoingRequest, SendAccessToken,
};
use tokio::sync::{Semaphore, SemaphorePermit};
use tracing::{debug, field::debug, instrument, trace};

use crate::{config::RequestConfig, error::HttpError};
Expand All @@ -48,16 +49,45 @@ pub(crate) use native::HttpSettings;

pub(crate) const DEFAULT_REQUEST_TIMEOUT: Duration = Duration::from_secs(30);

#[derive(Clone, Debug)]
struct MaybeSemaphore(Arc<Option<Semaphore>>);

#[allow(dead_code)] // holding this until drop is all we are doing
struct MaybeSemaphorePermit<'a>(Option<SemaphorePermit<'a>>);

impl MaybeSemaphore {
fn new(max: usize) -> Self {
let inner = if max > 0 { Some(Semaphore::new(max)) } else { None };
MaybeSemaphore(Arc::new(inner))
}

async fn acquire(&self) -> MaybeSemaphorePermit {
match self.0.as_ref() {
Some(inner) => {
// ignoring errors as we never close this
MaybeSemaphorePermit(inner.acquire().await.ok())
}
None => MaybeSemaphorePermit(None),
}
}
}

#[derive(Clone, Debug)]
pub(crate) struct HttpClient {
pub(crate) inner: reqwest::Client,
pub(crate) request_config: RequestConfig,
queue: MaybeSemaphore,
next_request_id: Arc<AtomicU64>,
}

impl HttpClient {
pub(crate) fn new(inner: reqwest::Client, request_config: RequestConfig) -> Self {
HttpClient { inner, request_config, next_request_id: AtomicU64::new(0).into() }
HttpClient {
inner,
request_config,
queue: MaybeSemaphore::new(request_config.max_concurrent_requests),
next_request_id: AtomicU64::new(0).into(),
}
}

fn get_request_id(&self) -> String {
Expand Down Expand Up @@ -184,6 +214,9 @@ impl HttpClient {
request
};

// will be automatically dropped at the end of this function
let _handle = self.queue.acquire().await;

debug!("Sending request");

// There's a bunch of state in send_request, factor out a pinned inner
Expand Down

0 comments on commit db3d4d7

Please sign in to comment.