Skip to content

Commit

Permalink
Add test for JWKS producer (#20)
Browse files Browse the repository at this point in the history
In addition, I've renamed the list of consumers, and a bunch of
arguments to better reflect the naming of the traits themselves
(Producer/Consumer).
  • Loading branch information
Dunklas authored Oct 31, 2024
1 parent e26f8dc commit 0ce8693
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 7 deletions.
102 changes: 96 additions & 6 deletions tower-oauth2-resource-server/src/jwks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use tokio::time;
use crate::error::JwkError;

pub trait JwksProducer {
fn add_receiver(&mut self, receiver: Arc<dyn JwksConsumer>);
fn add_consumer(&mut self, receiver: Arc<dyn JwksConsumer>);
fn start(&self);
}

Expand All @@ -34,8 +34,8 @@ impl TimerJwksProducer {
}

impl JwksProducer for TimerJwksProducer {
fn add_receiver(&mut self, receiver: Arc<dyn JwksConsumer>) {
self.receivers.push(receiver);
fn add_consumer(&mut self, consumer: Arc<dyn JwksConsumer>) {
self.receivers.push(consumer);
}

fn start(&self) {
Expand All @@ -50,15 +50,15 @@ impl JwksProducer for TimerJwksProducer {
async fn fetch_jwks_job(
jwks_url: Url,
refresh_interval: Duration,
receivers: Vec<Arc<dyn JwksConsumer>>,
consumers: Vec<Arc<dyn JwksConsumer>>,
) {
let mut interval = time::interval(refresh_interval);
loop {
interval.tick().await;
match fetch_jwks(jwks_url.clone()).await {
Ok(jwks) => {
for receiver in &receivers {
receiver.receive_jwks(jwks.clone()).await;
for consumer in &consumers {
consumer.receive_jwks(jwks.clone()).await;
}
}
Err(e) => {
Expand All @@ -78,3 +78,93 @@ async fn fetch_jwks(jwks_url: Url) -> Result<JwkSet, JwkError> {
.map_err(|_| JwkError::ParseFailed)?;
Ok(parsed)
}

#[cfg(test)]
mod tests {
use std::time::Instant;

use base64::{
alphabet,
engine::{general_purpose, GeneralPurpose},
Engine,
};
use jsonwebtoken::jwk::Jwk;
use rsa::{traits::PublicKeyParts, RsaPrivateKey, RsaPublicKey};
use serde_json::json;
use tokio::sync::RwLock;
use wiremock::{
matchers::{method, path},
Mock, MockServer, ResponseTemplate,
};

use super::*;

struct TestConsumer {
jwks: Arc<RwLock<Option<JwkSet>>>,
}

impl TestConsumer {
pub fn new() -> Self {
Self {
jwks: Arc::new(RwLock::new(None)),
}
}
pub async fn has_jwks(&self) -> bool {
self.jwks.read().await.is_some()
}
}

#[async_trait]
impl JwksConsumer for TestConsumer {
async fn receive_jwks(&self, jwks: JwkSet) {
self.jwks.write().await.replace(jwks);
}
}

#[tokio::test]
async fn test_should_notify_consumers() {
let mock_server = MockServer::start().await;
mock_jwks(&mock_server, "/jwks.json").await;

let consumer = Arc::new(TestConsumer::new());
let mut producer = TimerJwksProducer::new(
format!("{}/jwks.json", &mock_server.uri())
.parse::<Url>()
.unwrap(),
Duration::from_millis(5),
);
producer.add_consumer(consumer.clone());
producer.start();

let mut success = false;
let start = Instant::now();
while start.elapsed() < Duration::from_millis(500) {
if consumer.has_jwks().await {
success = true;
break;
}
tokio::time::sleep(Duration::from_millis(10)).await;
}
assert!(success, "Consumer did not receive JWKS in time");
}

async fn mock_jwks(server: &MockServer, jwks_path: &str) {
let private = RsaPrivateKey::new(&mut rand::thread_rng(), 2048).unwrap();
let public = RsaPublicKey::from(private);
let base64_engine = GeneralPurpose::new(&alphabet::URL_SAFE, general_purpose::NO_PAD);
let jwk: Jwk = serde_json::from_value(json!({
"kty": "RSA",
"use_": "sig",
"alg": "RS256",
"kid": "test-kid",
"n": base64_engine.encode(public.n().to_bytes_be()),
"e": base64_engine.encode(public.e().to_bytes_be())
}))
.unwrap();
Mock::given(method("GET"))
.and(path(jwks_path))
.respond_with(ResponseTemplate::new(200).set_body_json(JwkSet { keys: vec![jwk] }))
.mount(server)
.await
}
}
2 changes: 1 addition & 1 deletion tower-oauth2-resource-server/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ where
let validator = Arc::new(OnlyJwtValidator::new(claims_validation_spec));

let mut jwks_producer = TimerJwksProducer::new(jwks_url.clone(), jwk_set_refresh_interval);
jwks_producer.add_receiver(validator.clone());
jwks_producer.add_consumer(validator.clone());
jwks_producer.start();

Ok(OAuth2ResourceServer {
Expand Down

0 comments on commit 0ce8693

Please sign in to comment.