Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
NishantJoshi00 committed Jan 29, 2024
2 parents fee135f + 358cdb8 commit 2861935
Show file tree
Hide file tree
Showing 11 changed files with 619 additions and 452 deletions.
851 changes: 480 additions & 371 deletions Cargo.lock

Large diffs are not rendered by default.

22 changes: 13 additions & 9 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,22 +20,26 @@ aws-config = { version = "1.0.1", optional = true }
aws-sdk-kms = { version = "1.3.0", optional = true }
base64 = "0.21.2"
futures = "0.3.28"
tracing = { version = "0.1.40" }
tracing-appender = { version = "0.2.2" }
tracing-attributes = "0.1.27"
tracing-subscriber = { version = "0.3.17", default-features = true, features = ["env-filter", "json", "registry"] }
gethostname = "0.4.3"
rustc-hash = "1.1"
once_cell = "1.18.0"
vaultrs = { version = "0.7.0", optional = true }

# Tokio Dependencies
tokio = { version = "1.33.0", features = ["macros", "rt-multi-thread"] }
axum = "0.6.20"
hyper = "0.14.27"
axum = "0.7.3"
hyper = "1.0.1"
tower = { version = "0.4.13", features = ["limit", "buffer", "load-shed"] }
tower-http = { version = "0.4.4", features = ["trace"] }

tower-http = { version = "0.5.0", features = ["trace"] }
tracing = { version = "0.1.40" }
tracing-appender = { version = "0.2.2" }
tracing-attributes = "0.1.27"
tracing-subscriber = { version = "0.3.17", default-features = true, features = [
"env-filter",
"json",
"registry",
] }
http-body-util = "0.1.0"

diesel = { version = "2.1.3", features = ["postgres", "serde_json", "time"] }
diesel-async = { version = "0.4.1", features = ["postgres", "deadpool"] }
Expand All @@ -61,8 +65,8 @@ argh = "0.1.12"

[dev-dependencies]
rand = "0.8.5"
axum-test = "13.0.1"
criterion = "0.5.1"
axum-test = "14.2.2"

[build-dependencies]
cargo_metadata = "0.15.4"
Expand Down
13 changes: 7 additions & 6 deletions src/app.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use axum::routing;
use error_stack::ResultExt;
use hyper::server::conn;
use masking::PeekInterface;
#[cfg(feature = "key_custodian")]
use tokio::sync::{mpsc::Sender, RwLock};
Expand Down Expand Up @@ -52,7 +51,7 @@ pub async fn server1_builder(
state: Arc<RwLock<AppState>>,
server_tx: Sender<()>,
) -> Result<
hyper::Server<conn::AddrIncoming, routing::IntoMakeService<axum::Router>>,
axum::serve::Serve<routing::IntoMakeService<axum::Router>, axum::Router>,
error::ConfigurationError,
>
where
Expand All @@ -71,7 +70,9 @@ where
.with_state(shared_state)
.route("/health", routing::get(routes::health::health));

let server = axum::Server::try_bind(&socket_addr)?.serve(router.into_make_service());
let tcp_listener = tokio::net::TcpListener::bind(&socket_addr).await?;
let server = axum::serve(tcp_listener, router.into_make_service());

Ok(server)
}

Expand All @@ -82,7 +83,7 @@ where
pub async fn server2_builder(
state: &AppState,
) -> Result<
hyper::Server<conn::AddrIncoming, routing::IntoMakeService<axum::Router>>,
axum::serve::Serve<routing::IntoMakeService<axum::Router>, axum::Router>,
error::ConfigurationError,
>
where
Expand Down Expand Up @@ -121,8 +122,8 @@ where
.level(tracing::Level::ERROR),
),
);

let server = axum::Server::try_bind(&socket_addr)?.serve(router.into_make_service());
let tcp_listener = tokio::net::TcpListener::bind(&socket_addr).await?;
let server = axum::serve(tcp_listener, router.into_make_service());
Ok(server)
}

Expand Down
6 changes: 3 additions & 3 deletions src/bin/locker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
{
let state_lock = state.clone();

let (server1_tx, mut server1_rx) = tokio::sync::mpsc::channel::<()>(1);
let (server1_tx, server1_rx) = tokio::sync::mpsc::channel::<()>(1);

let server1 = tartarus::app::server1_builder(state_lock, server1_tx.clone())
.await?
.with_graceful_shutdown(graceful_shutdown_server1(&mut server1_rx));
.with_graceful_shutdown(graceful_shutdown_server1(server1_rx));

logger::info!(
"Key Custodian started [{:?}] [{:?}]",
Expand All @@ -42,7 +42,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
}

#[cfg(feature = "key_custodian")]
async fn graceful_shutdown_server1(recv: &mut tokio::sync::mpsc::Receiver<()>) {
async fn graceful_shutdown_server1(mut recv: tokio::sync::mpsc::Receiver<()>) {
recv.recv().await;
logger::info!("Shutting down the server1 gracefully.");
}
20 changes: 16 additions & 4 deletions src/bin/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,13 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
let pub_key = read_file_to_string(
&public_key.ok_or(error::CryptoError::InvalidData("public key not found"))?,
)?;
jwe_operation(|x| {
JWEncryption::new(priv_key, pub_key, jwe::RSA_OAEP_256, jwe::RSA_OAEP).encrypt(x)
jwe_operation(|payload| {
JWEncryption::new(priv_key, pub_key, jwe::RSA_OAEP_256, jwe::RSA_OAEP)
.encrypt(payload)
.and_then(|payload| {
Ok(serde_json::to_vec(&payload)
.map_err(error::CryptoError::SerdeJsonError)?)
})
})?;
}
SubCommand::JweDecrypt(JweD {
Expand All @@ -94,8 +99,15 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
let pub_key = read_file_to_string(
&public_key.ok_or(error::CryptoError::InvalidData("private key not found"))?,
)?;
jwe_operation(|x| {
JWEncryption::new(priv_key, pub_key, jwe::RSA_OAEP_256, jwe::RSA_OAEP).decrypt(x)
jwe_operation(|payload| {
serde_json::from_slice(&payload)
.map_err(error::CryptoError::SerdeJsonError)
.map_err(Into::into)
.and_then(|payload| {
JWEncryption::new(priv_key, pub_key, jwe::RSA_OAEP_256, jwe::RSA_OAEP)
.decrypt(payload)
})
// (x)
})?;
}
}
Expand Down
14 changes: 6 additions & 8 deletions src/crypto/jw.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ impl JwsBody {
}
}

#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize, PartialEq, Eq)]
#[serde(rename_all = "camelCase")]
pub struct JweBody {
pub header: String,
Expand Down Expand Up @@ -90,10 +90,10 @@ impl JweBody {
}
}

impl super::Encryption<Vec<u8>, Vec<u8>> for JWEncryption {
impl super::Encryption<Vec<u8>, JweBody> for JWEncryption {
type ReturnType<'a, T> = Result<T, ContainerError<error::CryptoError>>;

fn encrypt(&self, input: Vec<u8>) -> Self::ReturnType<'_, Vec<u8>> {
fn encrypt(&self, input: Vec<u8>) -> Self::ReturnType<'_, JweBody> {
let payload = input;
let jws_encoded = jws_sign_payload(&payload, self.private_key.peek().as_bytes())?;
let jws_body = JwsBody::from_dotted_str(&jws_encoded).ok_or(
Expand All @@ -107,13 +107,11 @@ impl super::Encryption<Vec<u8>, Vec<u8>> for JWEncryption {
)?;
let jwe_body = JweBody::from_str(&jwe_encrypted)
.ok_or(error::CryptoError::InvalidData("JWE data incomplete"))?;
Ok(serde_json::to_vec(&jwe_body).map_err(error::CryptoError::from)?)
Ok(jwe_body)
}

fn decrypt(&self, input: Vec<u8>) -> Self::ReturnType<'_, Vec<u8>> {
let jwe_body: JweBody = serde_json::from_slice(&input).map_err(error::CryptoError::from)?;
let jwe_encoded = jwe_body.get_dotted_jwe();
// let algo = jwe::RSA_OAEP_256;
fn decrypt(&self, input: JweBody) -> Self::ReturnType<'_, Vec<u8>> {
let jwe_encoded = input.get_dotted_jwe();
let jwe_decrypted =
decrypt_jwe(&jwe_encoded, self.private_key.peek(), self.decryption_algo)?;

Expand Down
2 changes: 2 additions & 0 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ pub enum ConfigurationError {
ServerError(#[from] hyper::Error),
#[error("invalid host for socket")]
AddressError(#[from] std::net::AddrParseError),
#[error("invalid host for socket")]
IOError(#[from] std::io::Error),
#[error("Error while connecting/creating database pool")]
DatabaseError,
#[error("Failed to KMS decrypt: {0}")]
Expand Down
65 changes: 26 additions & 39 deletions src/middleware.rs
Original file line number Diff line number Diff line change
@@ -1,63 +1,50 @@
use crate::app::AppState;
use crate::crypto::jw::JWEncryption;
use crate::crypto::jw::{self, JWEncryption};
use crate::crypto::Encryption;
use crate::error::{self, ContainerError, ResultContainerExt};
use axum::{
body::BoxBody,
extract,
http::{Request, Response},
middleware::Next,
};
use error_stack::ResultExt;
use hyper::body::HttpBody;
use hyper::Body;
use axum::body::Body;
use axum::http::{request, response};
use axum::{extract, http::Request, middleware::Next};

use http_body_util::BodyExt;
use josekit::jwe;

/// Middleware providing implementation to perform JWE + JWS encryption and decryption around the
/// card APIs
pub async fn middleware(
extract::State(state): extract::State<AppState>,
request: Request<Body>,
next: Next<Body>,
) -> Result<Response<BoxBody>, ContainerError<error::ApiError>> {
let (parts, body) = request.into_parts();

let request_body =
hyper::body::to_bytes(body)
.await
.change_error(error::ApiError::RequestMiddlewareError(
"Failed to read request body for jwe decryption",
))?;

parts: request::Parts,
axum::Json(jwe_body): axum::Json<jw::JweBody>,
next: Next,
) -> Result<(response::Parts, axum::Json<jw::JweBody>), ContainerError<error::ApiError>> {
let keys = JWEncryption {
private_key: state.config.secrets.locker_private_key,
public_key: state.config.secrets.tenant_public_key,
encryption_algo: jwe::RSA_OAEP,
decryption_algo: jwe::RSA_OAEP_256,
};

let jwe_decrypted = keys.decrypt(request_body.to_vec())?;
let jwe_decrypted = keys.decrypt(jwe_body)?;

let next_layer_payload = Request::from_parts(parts, Body::from(jwe_decrypted));

let response = next.run(next_layer_payload).await;

let (parts, body) = response.into_parts();
let (mut parts, body) = next.run(next_layer_payload).await.into_parts();

let response_body = hyper::body::to_bytes(body).await.change_error(
error::ApiError::ResponseMiddlewareError("Failed to read response body for jws signing"),
)?;
let response_body = body
.collect()
.await
.change_error(error::ApiError::ResponseMiddlewareError(
"Failed to read response body for jws signing",
))?
.to_bytes();

let jws_signed = keys.encrypt(response_body.to_vec())?;
let jwe_payload = keys.encrypt(response_body.to_vec())?;

let jwt = String::from_utf8(jws_signed).change_error(
error::ApiError::ResponseMiddlewareError("Could not convert to UTF-8"),
)?;
parts.headers = hyper::HeaderMap::new();
parts.headers.append(
hyper::header::CONTENT_TYPE,
axum::http::HeaderValue::from_static("application/json"),
);

Ok(axum::http::response::Builder::new()
.status(parts.status)
.body(jwt.map_err(axum::Error::new).boxed_unsync())
.change_context(error::ApiError::ResponseMiddlewareError(
"failed while generating the response",
))?)
Ok((parts, axum::Json(jwe_payload)))
}
21 changes: 15 additions & 6 deletions src/routes/data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ use axum::middleware;

use masking::ExposeInterface;

use types::StoreCardResponse;

use crate::{
app::AppState,
crypto::{aes::GcmAes256, sha::Sha512},
Expand Down Expand Up @@ -94,7 +96,7 @@ pub async fn add_card(

let optional_hash_table = state.db.find_by_data_hash(&hash_data).await?;

let output = match optional_hash_table {
let (duplication_check, output) = match optional_hash_table {
Some(hash_table) => {
let stored_data = state
.db
Expand All @@ -107,7 +109,10 @@ pub async fn add_card(
)
.await?;

match stored_data {
let duplication_check =
transformers::validate_card_metadata(stored_data.as_ref(), &request.data)?;

let output = match stored_data {
Some(data) => data,
None => {
state
Expand All @@ -123,12 +128,14 @@ pub async fn add_card(
)
.await?
}
}
};

(duplication_check, output)
}
None => {
let hash_table = state.db.insert_hash(hash_data).await?;

state
let output = state
.db
.insert_or_get_from_locker(
(
Expand All @@ -139,11 +146,13 @@ pub async fn add_card(
.try_into()?,
&merchant_dek,
)
.await?
.await?;

(None, output)
}
};

Ok(Json(output.into()))
Ok(Json(StoreCardResponse::from((duplication_check, output))))
}

/// `/data/delete` handling the requirement of deleting cards
Expand Down
Loading

0 comments on commit 2861935

Please sign in to comment.