Skip to content

Commit

Permalink
Adding Bearer check
Browse files Browse the repository at this point in the history
  • Loading branch information
sordina committed Sep 13, 2023
1 parent 702f949 commit cbedad7
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 6 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion rust-connector-sdk/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ serde = { version = "1.0.164", features = ["derive"] }
serde_json = { version = "1.0.97", features = ["raw_value"] }
thiserror = "1.0"
tokio = { version = "1.28.2", features = ["fs", "signal"] }
tower-http = { version = "0.4.1", features = ["cors", "trace"] }
tower-http = { version = "0.4.1", features = ["cors", "trace", "validate-request"] }
tracing = "0.1.37"
uuid = "1.3.4"
tracing-subscriber = { version = "0.3", default-features = false, features = [
Expand Down
27 changes: 22 additions & 5 deletions rust-connector-sdk/src/default_main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@ use std::net;

use axum::{
extract::State,
http::StatusCode,
http::{StatusCode, Request, HeaderValue},
routing::{get, post},
Json, Router,
Json, Router, body::Body, response::IntoResponse,
};
use tower_http::validate_request::ValidateRequestHeaderLayer;

use clap::{Parser, Subcommand};
use ndc_client::models::{
CapabilitiesResponse, ErrorResponse, ExplainResponse, MutationRequest, MutationResponse,
Expand Down Expand Up @@ -190,9 +192,24 @@ where

let server_state = init_server_state::<C>(serve_command.configuration).await;

let router = create_router::<C>(server_state).layer(
TraceLayer::new_for_http().make_span_with(DefaultMakeSpan::default().level(Level::INFO)),
);
let expected_auth_header: Option<HeaderValue> = serve_command.service_token_secret.and_then(|service_token_secret| {
let expected_bearer = format!("Bearer {}", service_token_secret); // TODO
HeaderValue::from_str(&expected_bearer).ok()
});

let router = create_router::<C>(server_state)
.layer(
TraceLayer::new_for_http().make_span_with(DefaultMakeSpan::default().level(Level::INFO)),
).layer(
ValidateRequestHeaderLayer::custom(move |request: &mut Request<Body>| {
// Validate the request
let auth_header = request.headers().get("Authorization")
.map(|v| v.clone());

if auth_header == expected_auth_header { return Ok(()); }
Err((StatusCode::UNAUTHORIZED, "").into_response())
})
);

let port = serve_command.port;
let address = net::SocketAddr::new(net::IpAddr::V4(net::Ipv4Addr::UNSPECIFIED), port);
Expand Down

0 comments on commit cbedad7

Please sign in to comment.