Skip to content

Commit

Permalink
Refactor server from warp into axum
Browse files Browse the repository at this point in the history
  • Loading branch information
Eligioo committed Jan 20, 2025
1 parent 251db16 commit 826b534
Show file tree
Hide file tree
Showing 3 changed files with 113 additions and 149 deletions.
11 changes: 5 additions & 6 deletions server/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,26 +11,25 @@ categories.workspace = true
keywords.workspace = true

[dependencies]
axum = { version = "0.8.1", features = ["ws"] }
axum-extra = { version = "0.10.0", features = ["typed-header"] }
async-trait = "0.1"
blake2 = "0.10"
bytes = "1.4"
futures = "0.3"
headers = "0.3"
http = "0.2"
log = "0.4"
serde = "1.0"
serde_json = "1.0"
subtle = "2.5"
thiserror = "1.0"
tokio = { version = "1.25", features = ["sync"] }
warp = "0.3"
tokio = { version = "1.43.0", features = ["sync"] }
tower-http = { version = "0.6.2", features = ["auth", "cors"] }

nimiq-jsonrpc-core = { workspace = true }

[dev-dependencies]
anyhow = "1.0"
pretty_env_logger = "0.5.0"
tokio = { version = "1.25", features = ["macros", "rt-multi-thread"] }
tokio = { version = "1.43.0", features = ["macros", "rt-multi-thread"] }

nimiq-jsonrpc-client = { workspace = true }
nimiq-jsonrpc-derive = { workspace = true }
53 changes: 0 additions & 53 deletions server/src/auth_filter.rs

This file was deleted.

198 changes: 108 additions & 90 deletions server/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,11 @@
#![warn(missing_docs)]
#![warn(rustdoc::missing_doc_code_examples)]

mod auth_filter;

use std::{
collections::HashSet,
error,
fmt::{self, Debug},
future::{self, Future},
future::Future,
net::{IpAddr, SocketAddr},
sync::{
atomic::{AtomicU64, Ordering},
Expand All @@ -19,33 +17,48 @@ use std::{
};

use async_trait::async_trait;
use axum::{
body::{Body, Bytes},
extract::{DefaultBodyLimit, State, WebSocketUpgrade},
http::{header::CONTENT_TYPE, response::Builder, HeaderValue, Method, StatusCode},
middleware::Next,
response::{IntoResponse as _, Response as HttpResponse},
routing::{any, post},
Router,
};
use axum_extra::{
headers::{authorization::Basic, Authorization},
TypedHeader,
};
use blake2::{digest::consts::U32, Blake2b, Digest};
use bytes::Bytes;
use futures::{
pin_mut,
sink::SinkExt,
stream::{FuturesUnordered, StreamExt},
Stream,
};
use headers::{authorization::Basic, Authorization};
use serde::{de::Deserialize, ser::Serialize};
use serde_json::Value;
use subtle::ConstantTimeEq;
use thiserror::Error;
use tokio::sync::{mpsc, RwLock, RwLockReadGuard, RwLockWriteGuard};
pub use warp::filters::ws::Message;
use warp::{filters::cors::Builder, Filter};
use tokio::{
net::TcpListener,
sync::{mpsc, RwLock, RwLockReadGuard, RwLockWriteGuard},
};

use nimiq_jsonrpc_core::{
Request, Response, RpcError, Sensitive, SingleOrBatch, SubscriptionId, SubscriptionMessage,
};

pub use axum::extract::ws::Message;
use tower_http::cors::{Any, CorsLayer};

/// A server error.
#[derive(Debug, Error)]
pub enum Error {
/// Error returned by warp
/// Error returned by axum
#[error("HTTP error: {0}")]
Warp(#[from] warp::Error),
Axum(#[from] axum::Error),

/// Error from the message queues, that are used internally.
#[error("Queue error: {0}")]
Expand Down Expand Up @@ -100,33 +113,63 @@ fn blake2b(bytes: &[u8]) -> [u8; 32] {
*Blake2b::<U32>::digest(bytes).as_ref()
}

async fn basic_auth_middleware<D: Dispatcher>(
State(state): State<Arc<Inner<D>>>,
basic_auth_header: Option<TypedHeader<Authorization<Basic>>>,
request: axum::extract::Request,
next: Next,
) -> HttpResponse {
if let Some(auth_config) = &state.config.basic_auth {
if let Some(auth_header) = basic_auth_header {
if auth_config
.verify(auth_header.username(), auth_header.password())
.is_ok()
{
return next.run(request).await;
}
return StatusCode::UNAUTHORIZED.into_response();
} else {
return StatusCode::UNAUTHORIZED.into_response();
}
}

next.run(request).await
}

#[derive(Clone, Debug)]
/// CORS configuration
pub struct Cors(Builder);
pub struct Cors(CorsLayer);

impl Cors {
/// Create a new instance with `Content-Type` as mandatory header and `POST` as mandatory method.
pub fn new() -> Self {
Self(
warp::cors()
.allow_header("Content-Type")
.allow_method("POST"),
CorsLayer::new()
.allow_headers([CONTENT_TYPE])
.allow_methods([Method::POST]),
)
}

/// Configure CORS to only allow specific origins.
/// Note that multiple calls to this method will override any previous origin-related calls.
pub fn with_origins(mut self, origins: Vec<&str>) -> Self {
self.0 = self.0.allow_origins(origins);
self.0 = self.0.allow_origin::<Vec<HeaderValue>>(
origins
.iter()
.map(|o| o.parse::<HeaderValue>().unwrap())
.collect(),
);
self
}

/// Configure CORS to allow every origin. Also known as the `*` wildcard.
/// Note that multiple calls to this method will override any previous origin-related calls.
pub fn with_any_origin(mut self) -> Self {
self.0 = self.0.allow_any_origin();
self.0 = self.0.allow_origin(Any);
self
}

pub(crate) fn into_wrapper(self) -> Builder {
pub(crate) fn into_layer(self) -> CorsLayer {
self.0
}
}
Expand Down Expand Up @@ -235,75 +278,48 @@ impl<D: Dispatcher> Server<D> {

/// Runs the server forever.
pub async fn run(&self) {
// Route to use JSON-RPC over websocket
let inner = Arc::clone(&self.inner);
let ws_route = warp::path("ws")
.and(warp::path::end())
.and(warp::ws())
.map(move |ws| Self::upgrade_to_ws(Arc::clone(&inner), ws));
let http_router = Router::new().route(
"/",
post(|body: Bytes| async move {
let data = Self::handle_raw_request(inner, &Message::binary(body), None)
.await
.unwrap_or(Message::Binary(Bytes::new()));

Builder::new()
.status(StatusCode::OK)
.header(CONTENT_TYPE, "application/json")
.body(Body::from(data.into_data().to_owned()))
.unwrap() // As long as the hard-coded status code and content-type is correct, this won't fail.
}),
);

// Route for backwards-compatibility to use JSON-RPC over HTTP at /
let inner = Arc::clone(&self.inner);
let post_route = warp::path::end()
.and(warp::post())
.and(warp::body::content_length_limit(1024 * 1024))
.and(warp::body::bytes())
.and_then(move |body: Bytes| {
let inner = Arc::clone(&inner);
async move {
let data = Self::handle_raw_request(inner, &Message::binary(body), None)
.await
.unwrap_or(Message::binary([]));

let response = http::response::Builder::new()
.status(200)
.header("Content-Type", "application/json")
.body(data.as_bytes().to_owned())
.unwrap(); // As long as the hard-coded status code and content-type is correct, this won't fail.

Ok::<_, warp::Rejection>(response)
}
});

let json_rpc_route = ws_route.or(post_route);

let root = if self.inner.config.basic_auth.is_some() {
let inner = Arc::clone(&self.inner);
let realm = "JSON-RPC";
auth_filter::basic_auth_filter(realm)
.and_then(move |auth_header: Authorization<Basic>| {
let inner = Arc::clone(&inner);

let basic_auth = inner.config.basic_auth.as_ref().unwrap();
future::ready(
basic_auth
.verify(auth_header.0.username(), auth_header.0.password())
.map_err(|CredentialsVerificationError(())| {
warp::reject::custom(auth_filter::Unauthorized {
realm: realm.to_string(),
})
}),
)
})
.untuple_one()
.boxed()
} else {
warp::any().boxed()
};

warp::serve(
root.and(json_rpc_route)
.with(
self.inner
.config
.cors
.clone()
.map_or(warp::cors(), |cors| cors.into_wrapper()),
)
.recover(auth_filter::handle_auth_rejection),
)
.run(self.inner.config.bind_to)
.await;
let ws_router = Router::new().route(
"/ws",
any(|ws: WebSocketUpgrade| async move { Self::upgrade_to_ws(inner, ws) }),
);

let app = Router::new()
.merge(http_router)
.merge(ws_router)
.route_layer(axum::middleware::from_fn_with_state(
Arc::clone(&self.inner),
basic_auth_middleware,
))
.layer(DefaultBodyLimit::max(1024 * 1024 /* 1MB */))
.layer(
self.inner
.config
.cors
.clone()
.unwrap_or_default()
.into_layer(),
)
.with_state(Arc::clone(&self.inner));

let listener = TcpListener::bind(self.inner.config.bind_to).await.unwrap();
axum::serve(listener, app).await.unwrap();
}

/// Upgrades a connection to websocket. This creates message queues and tasks to forward messages between them.
Expand All @@ -316,7 +332,7 @@ impl<D: Dispatcher> Server<D> {
/// - This sends stuff as binary websocket frames. It should really use text frames.
/// - Make the queue size configurable
///
fn upgrade_to_ws(inner: Arc<Inner<D>>, ws: warp::ws::Ws) -> impl warp::Reply {
fn upgrade_to_ws(inner: Arc<Inner<D>>, ws: WebSocketUpgrade) -> HttpResponse<Body> {
ws.on_upgrade(move |websocket| {
let (mut tx, mut rx) = websocket.split();

Expand All @@ -326,7 +342,7 @@ impl<D: Dispatcher> Server<D> {
let forward_fut = async move {
while let Some(data) = multiplex_rx.recv().await {
// Close the sink if we get a close message (don't echo the message since this is not permitted)
if data.is_close() {
if matches!(data, Message::Close(_)) {
tx.close().await?;
} else {
tx.send(data).await?;
Expand All @@ -339,11 +355,13 @@ impl<D: Dispatcher> Server<D> {
let handle_fut = {
async move {
while let Some(message) = rx.next().await.transpose()? {
if message.is_ping() || message.is_pong() {
if matches!(message, Message::Ping(_))
|| matches!(message, Message::Pong(_))
{
// Do nothing - these messages are handled automatically
} else if message.is_close() {
} else if matches!(message, Message::Close(_)) {
// We received the close message, so we need to send a close message to close the sink
multiplex_tx.send(warp::ws::Message::close()).await?;
multiplex_tx.send(Message::Close(None)).await?;
// Then we exit the loop which closes the connection
break;
} else if let Some(response) = Self::handle_raw_request(
Expand Down Expand Up @@ -382,7 +400,7 @@ impl<D: Dispatcher> Server<D> {
request: &Message,
tx: Option<&mpsc::Sender<Message>>,
) -> Option<Message> {
match serde_json::from_slice(request.as_bytes()) {
match serde_json::from_slice(request.clone().into_data().as_ref()) {
Ok(request) => Self::handle_request(inner, request, tx).await,
Err(_e) => {
log::error!("Received invalid JSON from client");
Expand All @@ -393,7 +411,7 @@ impl<D: Dispatcher> Server<D> {
}
}
.map(|response| {
if request.is_text() {
if matches!(&request, Message::Text(_)) {
Message::text(
serde_json::to_string(&response)
.expect("Failed to serialize JSON RPC response"),
Expand Down

0 comments on commit 826b534

Please sign in to comment.