diff --git a/tower-oauth2-resource-server/src/jwks.rs b/tower-oauth2-resource-server/src/jwks.rs index a0e92ef..90d3d8d 100644 --- a/tower-oauth2-resource-server/src/jwks.rs +++ b/tower-oauth2-resource-server/src/jwks.rs @@ -8,7 +8,7 @@ use tokio::time; use crate::error::JwkError; pub trait JwksProducer { - fn add_receiver(&mut self, receiver: Arc); + fn add_consumer(&mut self, receiver: Arc); fn start(&self); } @@ -34,8 +34,8 @@ impl TimerJwksProducer { } impl JwksProducer for TimerJwksProducer { - fn add_receiver(&mut self, receiver: Arc) { - self.receivers.push(receiver); + fn add_consumer(&mut self, consumer: Arc) { + self.receivers.push(consumer); } fn start(&self) { @@ -50,15 +50,15 @@ impl JwksProducer for TimerJwksProducer { async fn fetch_jwks_job( jwks_url: Url, refresh_interval: Duration, - receivers: Vec>, + consumers: Vec>, ) { let mut interval = time::interval(refresh_interval); loop { interval.tick().await; match fetch_jwks(jwks_url.clone()).await { Ok(jwks) => { - for receiver in &receivers { - receiver.receive_jwks(jwks.clone()).await; + for consumer in &consumers { + consumer.receive_jwks(jwks.clone()).await; } } Err(e) => { @@ -78,3 +78,93 @@ async fn fetch_jwks(jwks_url: Url) -> Result { .map_err(|_| JwkError::ParseFailed)?; Ok(parsed) } + +#[cfg(test)] +mod tests { + use std::time::Instant; + + use base64::{ + alphabet, + engine::{general_purpose, GeneralPurpose}, + Engine, + }; + use jsonwebtoken::jwk::Jwk; + use rsa::{traits::PublicKeyParts, RsaPrivateKey, RsaPublicKey}; + use serde_json::json; + use tokio::sync::RwLock; + use wiremock::{ + matchers::{method, path}, + Mock, MockServer, ResponseTemplate, + }; + + use super::*; + + struct TestConsumer { + jwks: Arc>>, + } + + impl TestConsumer { + pub fn new() -> Self { + Self { + jwks: Arc::new(RwLock::new(None)), + } + } + pub async fn has_jwks(&self) -> bool { + self.jwks.read().await.is_some() + } + } + + #[async_trait] + impl JwksConsumer for TestConsumer { + async fn receive_jwks(&self, jwks: JwkSet) { + self.jwks.write().await.replace(jwks); + } + } + + #[tokio::test] + async fn test_should_notify_consumers() { + let mock_server = MockServer::start().await; + mock_jwks(&mock_server, "/jwks.json").await; + + let consumer = Arc::new(TestConsumer::new()); + let mut producer = TimerJwksProducer::new( + format!("{}/jwks.json", &mock_server.uri()) + .parse::() + .unwrap(), + Duration::from_millis(5), + ); + producer.add_consumer(consumer.clone()); + producer.start(); + + let mut success = false; + let start = Instant::now(); + while start.elapsed() < Duration::from_millis(500) { + if consumer.has_jwks().await { + success = true; + break; + } + tokio::time::sleep(Duration::from_millis(10)).await; + } + assert!(success, "Consumer did not receive JWKS in time"); + } + + async fn mock_jwks(server: &MockServer, jwks_path: &str) { + let private = RsaPrivateKey::new(&mut rand::thread_rng(), 2048).unwrap(); + let public = RsaPublicKey::from(private); + let base64_engine = GeneralPurpose::new(&alphabet::URL_SAFE, general_purpose::NO_PAD); + let jwk: Jwk = serde_json::from_value(json!({ + "kty": "RSA", + "use_": "sig", + "alg": "RS256", + "kid": "test-kid", + "n": base64_engine.encode(public.n().to_bytes_be()), + "e": base64_engine.encode(public.e().to_bytes_be()) + })) + .unwrap(); + Mock::given(method("GET")) + .and(path(jwks_path)) + .respond_with(ResponseTemplate::new(200).set_body_json(JwkSet { keys: vec![jwk] })) + .mount(server) + .await + } +} diff --git a/tower-oauth2-resource-server/src/server.rs b/tower-oauth2-resource-server/src/server.rs index c67bfa1..e6fa5ba 100644 --- a/tower-oauth2-resource-server/src/server.rs +++ b/tower-oauth2-resource-server/src/server.rs @@ -57,7 +57,7 @@ where let validator = Arc::new(OnlyJwtValidator::new(claims_validation_spec)); let mut jwks_producer = TimerJwksProducer::new(jwks_url.clone(), jwk_set_refresh_interval); - jwks_producer.add_receiver(validator.clone()); + jwks_producer.add_consumer(validator.clone()); jwks_producer.start(); Ok(OAuth2ResourceServer {