Skip to content

Commit

Permalink
Put JWT extraction in separate module (#19)
Browse files Browse the repository at this point in the history
In addition:

- Add tests for JWT extraction
- Write a debug log when JWT extraction fails (it's logged when
validation fails, extraction should be consistent)
- Rename JWT module that performs validation
  • Loading branch information
Dunklas authored Oct 30, 2024
1 parent f5bd1ff commit e26f8dc
Show file tree
Hide file tree
Showing 5 changed files with 89 additions and 24 deletions.
14 changes: 14 additions & 0 deletions tower-oauth2-resource-server/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,17 @@ impl Display for StartupError {
}
}
impl Error for StartupError {}

impl Display for JwkError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{:?}", self)
}
}
impl Error for JwkError {}

impl Display for AuthError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{:?}", self)
}
}
impl Error for AuthError {}
63 changes: 63 additions & 0 deletions tower-oauth2-resource-server/src/jwt_extract.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
use http::HeaderMap;

use crate::error::AuthError;

pub trait JwtExtractor {
fn extract_jwt(&self, headers: &HeaderMap) -> Result<String, AuthError>;
}

pub struct BearerTokenJwtExtractor;

impl JwtExtractor for BearerTokenJwtExtractor {
fn extract_jwt(&self, headers: &HeaderMap) -> Result<String, AuthError> {
Ok(headers
.get(http::header::AUTHORIZATION)
.ok_or(AuthError::MissingAuthorizationHeader)?
.to_str()
.map_err(|_| AuthError::InvalidAuthorizationHeader)?
.strip_prefix("Bearer ")
.ok_or(AuthError::InvalidAuthorizationHeader)?
.to_owned())
}
}

#[cfg(test)]
mod tests {
use http::HeaderValue;

use super::*;

#[test]
fn test_missing_authorization() {
let headers = HeaderMap::new();
let result = BearerTokenJwtExtractor {}.extract_jwt(&headers);

assert!(result.is_err());
assert_eq!(result.unwrap_err(), AuthError::MissingAuthorizationHeader);
}

#[test]
fn test_missing_bearer_prefix() {
let mut headers = HeaderMap::new();
headers.insert(
"Authorization",
HeaderValue::from_str("Boarer XXX").unwrap(),
);
let result = BearerTokenJwtExtractor {}.extract_jwt(&headers);

assert!(result.is_err());
assert_eq!(result.unwrap_err(), AuthError::InvalidAuthorizationHeader);
}

#[test]
fn test_ok() {
let mut headers = HeaderMap::new();
headers.insert(
"Authorization",
HeaderValue::from_str("Bearer XXX").unwrap(),
);
let result = BearerTokenJwtExtractor {}.extract_jwt(&headers);

assert!(result.is_ok());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ use std::{
};

use async_trait::async_trait;
use http::HeaderMap;
use jsonwebtoken::{
decode, decode_header,
jwk::{AlgorithmParameters, Jwk, JwkSet, KeyAlgorithm},
Expand All @@ -20,25 +19,6 @@ use crate::{
validation::ClaimsValidationSpec,
};

pub trait JwtExtractor {
fn extract_jwt(&self, headers: &HeaderMap) -> Result<String, AuthError>;
}

pub struct BearerTokenJwtExtractor;

impl JwtExtractor for BearerTokenJwtExtractor {
fn extract_jwt(&self, headers: &HeaderMap) -> Result<String, AuthError> {
Ok(headers
.get(http::header::AUTHORIZATION)
.ok_or(AuthError::MissingAuthorizationHeader)?
.to_str()
.map_err(|_| AuthError::InvalidAuthorizationHeader)?
.strip_prefix("Bearer ")
.ok_or(AuthError::InvalidAuthorizationHeader)?
.to_owned())
}
}

#[async_trait]
pub trait JwtValidator<Claims> {
async fn validate(&self, jwt: &str) -> Result<Claims, AuthError>;
Expand Down
3 changes: 2 additions & 1 deletion tower-oauth2-resource-server/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,6 @@ pub mod validation;

mod error;
mod jwks;
mod jwt;
mod jwt_extract;
mod jwt_validate;
mod oidc;
13 changes: 10 additions & 3 deletions tower-oauth2-resource-server/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ use crate::{
claims::DefaultClaims,
error::{AuthError, StartupError},
jwks::{JwksProducer, TimerJwksProducer},
jwt::{BearerTokenJwtExtractor, JwtExtractor, JwtValidator, OnlyJwtValidator},
jwt_extract::{BearerTokenJwtExtractor, JwtExtractor},
jwt_validate::{JwtValidator, OnlyJwtValidator},
layer::OAuth2ResourceServerLayer,
validation::ClaimsValidationSpec,
};
Expand Down Expand Up @@ -70,15 +71,21 @@ where
&self,
mut request: Request<Body>,
) -> Result<Request<Body>, AuthError> {
let token = self.jwt_extractor.extract_jwt(request.headers())?;
let token = match self.jwt_extractor.extract_jwt(request.headers()) {
Ok(token) => token,
Err(e) => {
debug!("JWT extraction failed: {}", e);
return Err(e);
}
};
match self.jwt_validator.validate(&token).await {
Ok(res) => {
debug!("JWT validation successful");
request.extensions_mut().insert(res);
Ok(request)
}
Err(e) => {
debug!("JWT validation failed due to: {:?}", e);
debug!("JWT validation failed: {}", e);
Err(e)
}
}
Expand Down

0 comments on commit e26f8dc

Please sign in to comment.