Skip to content

Commit

Permalink
Add support for external termination notifications for streams
Browse files Browse the repository at this point in the history
By adding a specialized unsubsribe request handler that lets the server terminates specifc streams
  • Loading branch information
Eligioo committed Jan 22, 2025
1 parent c8cbf5a commit b592346
Show file tree
Hide file tree
Showing 3 changed files with 126 additions and 33 deletions.
8 changes: 5 additions & 3 deletions derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -149,10 +149,12 @@ impl<'a> RpcMethod<'a> {
request,
move |params: #args_struct_ident| async move {
let stream = self.#method_ident(#(#method_args),*).await?;
let notifier = ::std::sync::Arc::new(::nimiq_jsonrpc_server::Notify::new());
let listener = notifier.clone();

let subscription = ::nimiq_jsonrpc_server::connect_stream(stream, tx, stream_id, #method_name.to_owned());
let subscription = ::nimiq_jsonrpc_server::connect_stream(stream, tx, stream_id, #method_name.to_owned(), listener);

Ok::<_, ::nimiq_jsonrpc_core::RpcError>(subscription)
Ok::<_, ::nimiq_jsonrpc_core::RpcError>((subscription, Some(notifier)))
}
).await
}
Expand All @@ -171,7 +173,7 @@ impl<'a> RpcMethod<'a> {
return ::nimiq_jsonrpc_server::dispatch_method_with_args(
request,
move |params: #args_struct_ident| async move {
Ok::<_, ::nimiq_jsonrpc_core::RpcError>(self.#method_ident(#(#method_args),*).await?)
Ok::<(_, Option<::std::sync::Arc<::nimiq_jsonrpc_server::Notify>>), ::nimiq_jsonrpc_core::RpcError>((self.#method_ident(#(#method_args),*).await?, None))
}
).await
}
Expand Down
2 changes: 1 addition & 1 deletion derive/src/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ fn impl_service(im: &mut ItemImpl, args: &ServiceMeta) -> TokenStream {
request: ::nimiq_jsonrpc_core::Request,
tx: Option<&::tokio::sync::mpsc::Sender<::nimiq_jsonrpc_server::Message>>,
stream_id: u64,
) -> Option<::nimiq_jsonrpc_core::Response> {
) -> Option<::nimiq_jsonrpc_server::ResponseAndSubScriptionNotifier> {
match request.method.as_str() {
#(#match_arms)*
_ => ::nimiq_jsonrpc_server::method_not_found(request),
Expand Down
149 changes: 120 additions & 29 deletions server/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
#![warn(rustdoc::missing_doc_code_examples)]

use std::{
collections::HashSet,
collections::{HashMap, HashSet},
error,
fmt::{self, Debug},
future::Future,
Expand Down Expand Up @@ -51,8 +51,12 @@ use nimiq_jsonrpc_core::{
};

pub use axum::extract::ws::Message;
pub use tokio::sync::Notify;
use tower_http::cors::{Any, CorsLayer};

/// Type defining a response and a possible notify handle used to terminate a subscription stream
pub type ResponseAndSubScriptionNotifier = (Response, Option<Arc<Notify>>);

/// A server error.
#[derive(Debug, Error)]
pub enum Error {
Expand Down Expand Up @@ -245,6 +249,7 @@ struct Inner<D: Dispatcher> {
config: Config,
dispatcher: RwLock<D>,
next_id: AtomicU64,
subscription_notifiers: RwLock<HashMap<SubscriptionId, Arc<Notify>>>,
}

/// A JSON-RPC server.
Expand All @@ -266,6 +271,7 @@ impl<D: Dispatcher> Server<D> {
config,
dispatcher: RwLock::new(dispatcher),
next_id: AtomicU64::new(1),
subscription_notifiers: RwLock::new(HashMap::new()),
}),
}
}
Expand Down Expand Up @@ -450,7 +456,7 @@ impl<D: Dispatcher> Server<D> {
match request {
SingleOrBatch::Single(request) => Self::handle_single_request(inner, request, tx)
.await
.map(SingleOrBatch::Single),
.map(|(response, _)| SingleOrBatch::Single(response)),

SingleOrBatch::Batch(requests) => {
let futures = requests
Expand All @@ -459,7 +465,7 @@ impl<D: Dispatcher> Server<D> {
.collect::<FuturesUnordered<_>>();

let responses = futures
.filter_map(|response_opt| async { response_opt })
.filter_map(|response_opt| async { response_opt.map(|(response, _)| response) })
.collect::<Vec<Response>>()
.await;

Expand All @@ -469,15 +475,15 @@ impl<D: Dispatcher> Server<D> {
}

/// Handles a single JSON RPC request
///
/// # TODO
///
/// - Handle subscriptions
async fn handle_single_request(
inner: Arc<Inner<D>>,
request: Request,
tx: Option<&mpsc::Sender<Message>>,
) -> Option<Response> {
) -> Option<ResponseAndSubScriptionNotifier> {
if request.method == "unsubscribe" {
return Self::handle_unsubscribe_stream(request, inner).await;
}

let mut dispatcher = inner.dispatcher.write().await;
// This ID is only used for streams
let id = inner.next_id.fetch_add(1, Ordering::SeqCst);
Expand All @@ -488,8 +494,67 @@ impl<D: Dispatcher> Server<D> {

log::debug!("response: {:#?}", response);

if let Some((_, Some(ref handler))) = response {
inner
.subscription_notifiers
.write()
.await
.insert(SubscriptionId::Number(id), handler.clone());
}

response
}

async fn handle_unsubscribe_stream(
request: Request,
inner: Arc<Inner<D>>,
) -> Option<ResponseAndSubScriptionNotifier> {
let params = if let Some(params) = request.params {
params
} else {
return error_response(request.id, || {
RpcError::invalid_request(Some(
"Missing request parameter containing a list of subscription ids".to_owned(),
))
});
};

let subscription_ids =
if let Ok(ids) = serde_json::from_value::<Vec<SubscriptionId>>(params) {
ids
} else {
return error_response(request.id, || {
RpcError::invalid_params(Some(
"A list of subscription ids is not provided".to_owned(),
))
});
};

if subscription_ids.is_empty() {
return error_response(request.id, || {
RpcError::invalid_params(Some(
"Empty list of subscription ids is provided".to_owned(),
))
});
}

let mut terminated_streams = vec![];
let mut subscription_notifiers = inner.subscription_notifiers.write().await;
for id in subscription_ids.iter() {
if let Some(notifier) = subscription_notifiers.remove(id) {
notifier.notify_one();
terminated_streams.push(id);
}
}

Some((
Response::new_success(
serde_json::to_value(request.id.unwrap_or_default()).unwrap(),
serde_json::to_value(terminated_streams).unwrap(),
),
None,
))
}
}

/// A method dispatcher. These take a request and handle the method execution. Can be generated from an `impl` block
Expand All @@ -502,7 +567,7 @@ pub trait Dispatcher: Send + Sync + 'static {
request: Request,
tx: Option<&mpsc::Sender<Message>>,
id: u64,
) -> Option<Response>;
) -> Option<ResponseAndSubScriptionNotifier>;

/// Returns whether a method should be dispatched with this dispatcher.
///
Expand Down Expand Up @@ -542,7 +607,7 @@ impl Dispatcher for ModularDispatcher {
request: Request,
tx: Option<&mpsc::Sender<Message>>,
id: u64,
) -> Option<Response> {
) -> Option<ResponseAndSubScriptionNotifier> {
for dispatcher in &mut self.dispatchers {
let m = dispatcher.match_method(&request.method);
log::debug!("Matching '{}' against dispatcher -> {}", request.method, m);
Expand Down Expand Up @@ -611,7 +676,7 @@ where
request: Request,
tx: Option<&mpsc::Sender<Message>>,
id: u64,
) -> Option<Response> {
) -> Option<ResponseAndSubScriptionNotifier> {
if self.is_allowed(&request.method) {
log::debug!("Dispatching method: {}", request.method);
self.inner.dispatch(request, tx, id).await
Expand Down Expand Up @@ -649,13 +714,16 @@ where
/// - Currently this always expects an object with named parameters. Do we want to accept a list too?
/// - Merge with it's other variant, as a function call without arguments is just one with `()` as request parameter.
///
pub async fn dispatch_method_with_args<P, R, E, F, Fut>(request: Request, f: F) -> Option<Response>
pub async fn dispatch_method_with_args<P, R, E, F, Fut>(
request: Request,
f: F,
) -> Option<ResponseAndSubScriptionNotifier>
where
P: for<'de> Deserialize<'de> + Send,
R: Serialize,
RpcError: From<E>,
F: FnOnce(P) -> Fut + Send,
Fut: Future<Output = Result<R, E>> + Send,
Fut: Future<Output = Result<(R, Option<Arc<Notify>>), E>> + Send,
{
let params = match request.params {
Some(params) => params,
Expand Down Expand Up @@ -683,12 +751,15 @@ where
///
/// This is a helper function used by implementations of `Dispatcher`.
///
pub async fn dispatch_method_without_args<R, E, F, Fut>(request: Request, f: F) -> Option<Response>
pub async fn dispatch_method_without_args<R, E, F, Fut>(
request: Request,
f: F,
) -> Option<ResponseAndSubScriptionNotifier>
where
R: Serialize,
RpcError: From<E>,
F: FnOnce() -> Fut + Send,
Fut: Future<Output = Result<R, E>> + Send,
Fut: Future<Output = Result<(R, Option<Arc<Notify>>), E>> + Send,
{
let result = f().await;

Expand All @@ -707,17 +778,20 @@ where
}

/// Constructs a [`Response`] if necessary (i.e., if the request ID was set).
fn response<R, E>(id_opt: Option<Value>, result: Result<R, E>) -> Option<Response>
fn response<R, E>(
id_opt: Option<Value>,
result: Result<(R, Option<Arc<Notify>>), E>,
) -> Option<ResponseAndSubScriptionNotifier>
where
R: Serialize,
RpcError: From<E>,
{
let response = match (id_opt, result) {
(Some(id), Ok(retval)) => {
let retval = serde_json::to_value(retval).expect("Failed to serialize return value");
Some(Response::new_success(id, retval))
(Some(id), Ok((value, subscription))) => {
let retval = serde_json::to_value(value).expect("Failed to serialize return value");
Some((Response::new_success(id, retval), subscription))
}
(Some(id), Err(e)) => Some(Response::new_error(id, RpcError::from(e))),
(Some(id), Err(e)) => Some((Response::new_error(id, RpcError::from(e)), None)),
(None, _) => None,
};

Expand All @@ -733,22 +807,22 @@ where
/// - `id_opt`: The ID field from the request.
/// - `e`: A function that returns the error. This is only called, if we actually can respond with an error.
///
pub fn error_response<E>(id_opt: Option<Value>, e: E) -> Option<Response>
pub fn error_response<E>(id_opt: Option<Value>, e: E) -> Option<ResponseAndSubScriptionNotifier>
where
E: FnOnce() -> RpcError,
{
if let Some(id) = id_opt {
let e = e();
log::error!("Error response: {:?}", e);
Some(Response::new_error(id, e))
Some((Response::new_error(id, e), None))
} else {
None
}
}

/// Returns an error response for a method that was not found. This returns `None`, if the request doesn't expect a
/// response.
pub fn method_not_found(request: Request) -> Option<Response> {
pub fn method_not_found(request: Request) -> Option<ResponseAndSubScriptionNotifier> {
let ::nimiq_jsonrpc_core::Request { id, method, .. } = request;

error_response(id, || {
Expand Down Expand Up @@ -798,6 +872,7 @@ pub fn connect_stream<T, S>(
tx: &mpsc::Sender<Message>,
stream_id: u64,
method: String,
notify_handler: Arc<Notify>,
) -> SubscriptionId
where
T: Serialize + Debug + Send + Sync,
Expand All @@ -811,14 +886,30 @@ where
tokio::spawn(async move {
pin_mut!(stream);

while let Some(item) = stream.next().await {
if let Err(e) = forward_notification(item, &mut tx, &id, &method).await {
// Break the loop when the channel is closed
if let Error::Mpsc(_) = e {
let notify_future = notify_handler.notified();
pin_mut!(notify_future);

loop {
tokio::select! {
item = stream.next() => {
match item {
Some(notification) => {
if let Err(e) = forward_notification(notification, &mut tx, &id, &method).await {
// Break the loop when the channel is closed
if let Error::Mpsc(_) = e {
break;
}

log::error!("{}", e);
}
},
None => break,
}
}
_ = &mut notify_future => {
// Break the loop when a unsubscribe notification is received
break;
}

log::error!("{}", e);
}
}
});
Expand Down

0 comments on commit b592346

Please sign in to comment.