From 92cb6a8bbc9baf5cefb5e693a1cf2d875fe6a1df Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Bardon?= Date: Tue, 13 Aug 2024 21:58:47 +0200 Subject: [PATCH] =?UTF-8?q?=F0=9F=9A=A7=20WIP=20Migrate=20from=20Rocket=20?= =?UTF-8?q?to=20Axum?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/orangutan-server/Cargo.toml | 10 +- src/orangutan-server/src/main.rs | 125 +++++----- src/orangutan-server/src/request_guards.rs | 189 ++++++++++++++-- .../src/routes/auth_routes.rs | 177 --------------- .../src/routes/debug_routes.rs | 79 +++---- src/orangutan-server/src/routes/main_route.rs | 214 ++++-------------- src/orangutan-server/src/routes/mod.rs | 16 +- .../src/routes/update_content_routes.rs | 22 +- src/orangutan-server/src/util/mod.rs | 44 +++- src/orangutan-server/src/util/templating.rs | 9 +- src/orangutan-server/src/util/website_root.rs | 40 ++-- 11 files changed, 404 insertions(+), 521 deletions(-) delete mode 100644 src/orangutan-server/src/routes/auth_routes.rs diff --git a/src/orangutan-server/Cargo.toml b/src/orangutan-server/Cargo.toml index af180f2..66ef348 100644 --- a/src/orangutan-server/Cargo.toml +++ b/src/orangutan-server/Cargo.toml @@ -1,23 +1,29 @@ [package] name = "orangutan-server" -version = "0.4.13" +version = "0.5.0" edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] +axum = { workspace = true } +axum-extra = { workspace = true } base64 = { workspace = true } biscuit-auth = { workspace = true } chrono = { workspace = true } hex = { workspace = true } lazy_static = { workspace = true } +mime = { workspace = true } orangutan-helpers = { path = "../helpers" } orangutan-refresh-token = { path = "../orangutan-refresh-token" } -rocket = { workspace = true } +serde = { workspace = true } serde_json = { workspace = true } tera = { workspace = true, optional = true } thiserror = { workspace = true } time = { workspace = true } +tokio = { workspace = true } +tower = { workspace = true } +tower-http = { workspace = true } tracing = { workspace = true } tracing-subscriber = { workspace = true } urlencoding = { workspace = true } diff --git a/src/orangutan-server/src/main.rs b/src/orangutan-server/src/main.rs index ecfefe5..81fe42f 100644 --- a/src/orangutan-server/src/main.rs +++ b/src/orangutan-server/src/main.rs @@ -3,24 +3,22 @@ mod request_guards; mod routes; mod util; -use object_reader::ObjectReader; +use axum::{ + http::{Response, StatusCode}, + response::IntoResponse, + routing::get, + Router, +}; use orangutan_helpers::{ generate::{self, *}, - readers::object_reader, website_id::WebsiteId, }; -use rocket::{ - catch, catchers, - fairing::AdHoc, - fs::NamedFile, - http::Status, - response::{self, Responder}, - Request, -}; -use routes::auth_routes::REVOKED_TOKENS; +use request_guards::{handle_refresh_token, REVOKED_TOKENS}; +use tera::Tera; +use tower_http::{services::ServeFile, trace::TraceLayer}; #[cfg(feature = "templating")] use tracing::debug; -use tracing::warn; +use tracing::{info, warn}; use tracing_subscriber::{EnvFilter, FmtSubscriber}; #[cfg(feature = "templating")] @@ -31,45 +29,50 @@ use crate::{ util::error, }; -#[rocket::launch] -fn rocket() -> _ { - let rocket = rocket::build() - .mount("/", routes::routes()) - .register("/", catchers![unauthorized, forbidden, not_found]) - .manage(ObjectReader::new()) - .attach(AdHoc::on_ignite("Tracing subsciber", |rocket| async move { - let subscriber = FmtSubscriber::builder() - .with_env_filter(EnvFilter::from_default_env()) - .finish(); - tracing::subscriber::set_global_default(subscriber) - .expect("Failed to set tracing subscriber."); - rocket - })) - .attach(AdHoc::on_liftoff("Website generation", |rocket| { - Box::pin(async move { - if let Err(err) = liftoff() { - // We drop the error to get a Rocket-formatted panic. - drop(err); - rocket.shutdown().notify(); - } - }) - })); +#[derive(Clone, Default)] +struct AppState { + #[cfg(feature = "templating")] + tera: Tera, +} + +#[tokio::main] +async fn main() { + // build our application with a single route + let app = Router::new().nest("/", routes::router()).layer( + ServiceBuilder::new() + .layer(TraceLayer::new_for_http()) + .layer(handle_refresh_token), + ); + // .register("/", catchers![unauthorized, forbidden, not_found]) + + let mut app_state = AppState::default(); + + info!("Setting up tracing…"); + let subscriber = FmtSubscriber::builder() + .with_env_filter(EnvFilter::from_default_env()) + .finish(); + tracing::subscriber::set_global_default(subscriber).expect("Failed to set tracing subscriber."); // Add support for templating if needed #[cfg(feature = "templating")] - let rocket = rocket.attach(AdHoc::on_ignite( - "Initialize templating engine", - |rocket| async move { - let mut tera = tera::Tera::default(); - if let Err(err) = tera.add_raw_templates(routes::templates()) { - tracing::error!("{err}"); - std::process::exit(1) - } - rocket.manage(tera) - }, - )); + { + info!("Initializing templating engine…"); + if let Err(err) = app_state.tera.add_raw_templates(routes::templates()) { + tracing::error!("{err}"); + std::process::exit(1) + } + } - rocket + info!("Generating default website"); + if let Err(err) = liftoff() { + panic!("{err}"); + } + + let app = app.with_state(app_state); + + // Run our app with hyper, listening globally on port 8080. + let listener = tokio::net::TcpListener::bind("0.0.0.0:8080").await.unwrap(); + axum::serve(listener, app).await.unwrap(); } fn liftoff() -> Result<(), Error> { @@ -81,20 +84,20 @@ fn liftoff() -> Result<(), Error> { Ok(()) } -#[catch(401)] -async fn unauthorized() -> Result { +// #[catch(401)] +async fn unauthorized() -> Result { not_found().await } -#[catch(403)] +// #[catch(403)] async fn forbidden() -> &'static str { "403 Forbidden. Token revoked." } /// TODO: Re-enable Basic authentication /// (`.raw_header("WWW-Authenticate", "Basic realm=\"This page is protected. Please log in.\"")`). -#[catch(404)] -async fn not_found() -> Result { +// #[catch(404)] +async fn not_found() -> Result { let website_id = WebsiteId::default(); let website_dir = match generate_website_if_needed(&website_id) { Ok(dir) => dir, @@ -104,7 +107,7 @@ async fn not_found() -> Result { }, }; let file_path = website_dir.join(NOT_FOUND_FILE); - NamedFile::open(file_path.clone()).await.map_err(|err| { + ServeFile::open(file_path.clone()).await.map_err(|err| { error(format!( "Could not read \"not found\" file at <{}>: {}", file_path.display(), @@ -118,8 +121,6 @@ async fn not_found() -> Result { enum Error { #[error(transparent)] WebsiteGenerationError(#[from] generate::Error), - #[error(transparent)] - MainRouteError(#[from] main_route::Error), #[error("Could not update content: {0}")] UpdateContentError(#[from] update_content_routes::Error), #[error("Unauthorized")] @@ -137,25 +138,21 @@ enum Error { ClientError(String), } -#[rocket::async_trait] -impl<'r> Responder<'r, 'static> for Error { - fn respond_to( - self, - _: &'r Request<'_>, - ) -> response::Result<'static> { +impl IntoResponse for Error { + fn into_response(self) -> Response { match self { Self::Unauthorized => { warn!("{self}"); - Err(Status::Unauthorized) + StatusCode::UNAUTHORIZED.into() }, #[cfg(feature = "templating")] Self::ClientError(_) => { debug!("{self}"); - Err(Status::BadRequest) + StatusCode::BAD_REQUEST.into() }, _ => { error(format!("{self}")); - Err(Status::InternalServerError) + StatusCode::INTERNAL_SERVER_ERROR.into() }, } } diff --git a/src/orangutan-server/src/request_guards.rs b/src/orangutan-server/src/request_guards.rs index 2f1974c..ea931c7 100644 --- a/src/orangutan-server/src/request_guards.rs +++ b/src/orangutan-server/src/request_guards.rs @@ -1,14 +1,30 @@ -use std::ops::Deref; +use std::{ + collections::{HashMap, HashSet}, + ops::Deref, + sync::RwLock, + time::SystemTime, +}; -use biscuit_auth::Biscuit; -use rocket::{http::Status, outcome::Outcome, request, request::FromRequest, Request}; +use axum::{ + extract::{rejection::QueryRejection, FromRequestParts, Query, Request}, + http::{request, HeaderMap, StatusCode, Uri}, + middleware::Next, + response::{IntoResponse, Redirect, Response}, +}; +use axum_extra::extract::PrivateCookieJar; +use biscuit_auth::{macros::authorizer, Biscuit}; +use lazy_static::lazy_static; use tracing::{debug, trace}; use crate::{ config::*, - util::{add_cookie, add_padding, profiles}, + util::{add_cookie, add_padding, error, profiles}, }; +lazy_static! { + pub static ref REVOKED_TOKENS: RwLock>> = RwLock::default(); +} + pub struct Token { pub biscuit: Biscuit, } @@ -27,17 +43,40 @@ impl Deref for Token { } } -#[derive(Debug)] +#[derive(Debug, thiserror::Error)] pub enum TokenError { // TODO: Re-enable Basic authentication // Invalid, + #[error("Invalid query: {0}")] + InvalidQuery(#[from] QueryRejection), + #[error("Unauthorized")] + Unauthorized, +} +impl IntoResponse for TokenError { + fn into_response(self) -> Response { + match self { + Self::InvalidQuery(err) => err.into_response(), + Self::Unauthorized => StatusCode::UNAUTHORIZED.into(), + } + } } -#[rocket::async_trait] -impl<'r> FromRequest<'r> for Token { - type Error = TokenError; +#[axum::async_trait] +impl FromRequestParts for Token +where + S: Send + Sync, +{ + type Rejection = TokenError; - async fn from_request(req: &'r Request<'_>) -> request::Outcome { + async fn from_request_parts( + parts: &mut request::Parts, + state: &S, + ) -> Result { + // if let Some(user_agent) = parts.headers.get(USER_AGENT) { + // Ok(ExtractUserAgent(user_agent.clone())) + // } else { + // Err((StatusCode::BAD_REQUEST, "`User-Agent` header is missing")) + // } let mut biscuit: Option = None; let mut should_save: bool = false; @@ -74,7 +113,8 @@ impl<'r> FromRequest<'r> for Token { } // Check cookies - if let Some(cookie) = req.cookies().get(TOKEN_COOKIE_NAME) { + let cookies = PrivateCookieJar::from_request_parts(parts, state).await?; + if let Some(cookie) = cookies.get(TOKEN_COOKIE_NAME) { debug!("Found token cookie"); let token: &str = cookie.value(); // NOTE: We don't want to send a `Set-Cookie` header after finding a token in a cookie, @@ -87,7 +127,8 @@ impl<'r> FromRequest<'r> for Token { } // Check authorization headers - let authorization_headers: Vec<&str> = req.headers().get("Authorization").collect(); + let headers = HeaderMap::from_request_parts(parts, state).await?; + let authorization_headers: Vec<&str> = headers.get("Authorization").collect(); debug!( "{} 'Authorization' headers provided", authorization_headers.len() @@ -103,10 +144,8 @@ impl<'r> FromRequest<'r> for Token { } // Check query params - if let Some(token) = req - .query_value::(TOKEN_QUERY_PARAM_NAME) - .and_then(Result::ok) - { + let query = Query::>::from_request_parts(parts, state).await?; + if let Some(token) = query.get(TOKEN_QUERY_PARAM_NAME) { debug!("Found token query param"); process_token(&token, "token query param", &mut biscuit, &mut should_save); } @@ -114,15 +153,129 @@ impl<'r> FromRequest<'r> for Token { match biscuit { Some(biscuit) => { if should_save { - add_cookie(&biscuit, req.cookies()); + add_cookie(&biscuit, cookies); } - Outcome::Success(Token { biscuit }) + Ok(Token { biscuit }) }, - None => Outcome::Forward(Status::Unauthorized), + None => Err(TokenError::Unauthorized), } } } +pub fn handle_refresh_token( + uri: Uri, + cookies: PrivateCookieJar, + Query(refresh_token): Query<&str>, + Query(force): Query>, + token: Option, + next: Next, +) -> Result { + // URL-decode the string. + let mut refresh_token: String = urlencoding::decode(refresh_token).unwrap().to_string(); + + // Because tokens can be passed as URL query params, + // they might have the "=" padding characters removed. + // We need to add them back. + refresh_token = add_padding(&refresh_token); + + let refresh_biscuit: Biscuit = match Biscuit::from_base64(refresh_token, ROOT_KEY.public()) { + Ok(biscuit) => biscuit, + Err(err) => { + debug!("Error decoding biscuit from base64: {}", err); + return Err(StatusCode::UNAUTHORIZED); + }, + }; + + // NOTE: This is just a hotfix. I had to quickly revoke a token. I'll improve this one day. + trace!("Checking if refresh token is revoked…"); + trace!( + "Revocation identifiers: {}", + refresh_biscuit + .revocation_identifiers() + .into_iter() + .map(hex::encode) + .collect::>() + .join(", "), + ); + let revoked_id = refresh_biscuit + .revocation_identifiers() + .into_iter() + .collect::>>() + .intersection(&REVOKED_TOKENS.read().unwrap()) + .next() + .cloned(); + if let Some(revoked_id) = revoked_id { + debug!( + "Refresh token has been revoked ({})", + String::from_utf8(revoked_id).unwrap_or("".to_string()), + ); + return Err(StatusCode::FORBIDDEN); + } + + trace!("Checking if refresh token is valid or not"); + let authorizer = authorizer!( + r#" + time({now}); + allow if true; + "#, + now = SystemTime::now(), + ); + if let Err(err) = refresh_biscuit.authorize(&authorizer) { + debug!("Refresh token is invalid: {}", err); + return Err(StatusCode::UNAUTHORIZED); + } + + fn redirect_to_same_page_without_query_param(uri: &Uri) -> Result { + let query_segs: Vec = uri + .query() + .unwrap_or_default() + .raw_segments() + .filter(|s| !s.starts_with(format!("{REFRESH_TOKEN_QUERY_PARAM_NAME}=").as_str())) + .map(ToString::to_string) + .collect(); + match Uri::parse_owned(format!("{}?{}", uri.path(), query_segs.join("&"))) { + Ok(redirect_to) => { + debug!("Redirecting to <{redirect_to}> from <{uri}>…"); + Ok(Redirect::found(redirect_to.path().to_string())) + }, + Err(err) => { + error(format!("{err}")); + Err(StatusCode::InternalServerError) + }, + } + } + + if let Some(token) = token { + if token.profiles().contains(&"*".to_owned()) && !force.unwrap_or(false) { + // NOTE: If a super admin generates an access link and accidentally opens it, + // they loose their super admin profile. Then we must regenerate a super admin + // access link and send it to the super admin's device, which increases the potential + // for such a sensitive link to be intercepted. As a safety measure, we don't do anything + // if a super admin uses a refresh token link. + return redirect_to_same_page_without_query_param(&uri); + } + } + + trace!("Baking new biscuit from refresh token"); + let block_0 = refresh_biscuit.print_block_source(0).unwrap(); + let mut builder = Biscuit::builder(); + builder.add_code(block_0).unwrap(); + let new_biscuit = match builder.build(&ROOT_KEY) { + Ok(biscuit) => biscuit, + Err(err) => { + error(format!("Error: Could not append block to biscuit: {err}")); + return Err(StatusCode::InternalServerError); + }, + }; + debug!("Successfully created new biscuit from refresh token"); + + // Save token to a HTTP Cookie + add_cookie(&new_biscuit, cookies); + + // Redirect to the same page without the refresh token query param + redirect_to_same_page_without_query_param(&uri) +} + fn merge_biscuits( b1: &Biscuit, b2: &Biscuit, diff --git a/src/orangutan-server/src/routes/auth_routes.rs b/src/orangutan-server/src/routes/auth_routes.rs deleted file mode 100644 index 172dc61..0000000 --- a/src/orangutan-server/src/routes/auth_routes.rs +++ /dev/null @@ -1,177 +0,0 @@ -use std::{collections::HashSet, sync::RwLock, time::SystemTime}; - -use biscuit_auth::{macros::authorizer, Biscuit}; -use lazy_static::lazy_static; -use rocket::{ - get, - http::{uri::Origin, CookieJar, Status}, - response::Redirect, - routes, Route, -}; -use tracing::{debug, trace}; - -use crate::{ - config::*, - error, - request_guards::Token, - util::{add_cookie, add_padding}, -}; - -lazy_static! { - pub static ref REVOKED_TOKENS: RwLock>> = RwLock::default(); -} - -pub(super) fn routes() -> Vec { - routes![handle_refresh_token] -} - -#[get("/<_..>?&")] -fn handle_refresh_token( - origin: &Origin, - cookies: &CookieJar<'_>, - refresh_token: &str, - token: Option, - force: Option, -) -> Result { - // URL-decode the string. - let mut refresh_token: String = urlencoding::decode(refresh_token).unwrap().to_string(); - - // Because tokens can be passed as URL query params, - // they might have the "=" padding characters removed. - // We need to add them back. - refresh_token = add_padding(&refresh_token); - - let refresh_biscuit: Biscuit = match Biscuit::from_base64(refresh_token, ROOT_KEY.public()) { - Ok(biscuit) => biscuit, - Err(err) => { - debug!("Error decoding biscuit from base64: {}", err); - return Err(Status::Unauthorized); - }, - }; - - // NOTE: This is just a hotfix. I had to quickly revoke a token. I'll improve this one day. - trace!("Checking if refresh token is revoked…"); - trace!( - "Revocation identifiers: {}", - refresh_biscuit - .revocation_identifiers() - .into_iter() - .map(hex::encode) - .collect::>() - .join(", "), - ); - let revoked_id = refresh_biscuit - .revocation_identifiers() - .into_iter() - .collect::>>() - .intersection(&REVOKED_TOKENS.read().unwrap()) - .next() - .cloned(); - if let Some(revoked_id) = revoked_id { - debug!( - "Refresh token has been revoked ({})", - String::from_utf8(revoked_id).unwrap_or("".to_string()), - ); - return Err(Status::Forbidden); - } - - trace!("Checking if refresh token is valid or not"); - let authorizer = authorizer!( - r#" - time({now}); - allow if true; - "#, - now = SystemTime::now(), - ); - if let Err(err) = refresh_biscuit.authorize(&authorizer) { - debug!("Refresh token is invalid: {}", err); - return Err(Status::Unauthorized); - } - - fn redirect_to_same_page_without_query_param(origin: &Origin) -> Result { - let query_segs: Vec = origin - .query() - .unwrap() - .raw_segments() - .filter(|s| !s.starts_with(format!("{REFRESH_TOKEN_QUERY_PARAM_NAME}=").as_str())) - .map(ToString::to_string) - .collect(); - match Origin::parse_owned(format!("{}?{}", origin.path(), query_segs.join("&"))) { - Ok(redirect_to) => { - debug!("Redirecting to <{redirect_to}> from <{origin}>…"); - Ok(Redirect::found(redirect_to.path().to_string())) - }, - Err(err) => { - error(format!("{err}")); - Err(Status::InternalServerError) - }, - } - } - - if let Some(token) = token { - if token.profiles().contains(&"*".to_owned()) && !force.unwrap_or(false) { - // NOTE: If a super admin generates an access link and accidentally opens it, - // they loose their super admin profile. Then we must regenerate a super admin - // access link and send it to the super admin's device, which increases the potential - // for such a sensitive link to be intercepted. As a safety measure, we don't do anything - // if a super admin uses a refresh token link. - return redirect_to_same_page_without_query_param(origin); - } - } - - trace!("Baking new biscuit from refresh token"); - let block_0 = refresh_biscuit.print_block_source(0).unwrap(); - let mut builder = Biscuit::builder(); - builder.add_code(block_0).unwrap(); - let new_biscuit = match builder.build(&ROOT_KEY) { - Ok(biscuit) => biscuit, - Err(err) => { - error(format!("Error: Could not append block to biscuit: {err}")); - return Err(Status::InternalServerError); - }, - }; - debug!("Successfully created new biscuit from refresh token"); - - // Save token to a HTTP Cookie - add_cookie(&new_biscuit, cookies); - - // Redirect to the same page without the refresh token query param - redirect_to_same_page_without_query_param(origin) -} - -#[cfg(test)] -mod tests { - use super::add_padding; - - #[test] - fn test_base64_padding() { - assert_eq!(add_padding("a"), "a===".to_string()); - assert_eq!(add_padding("ab"), "ab==".to_string()); - assert_eq!(add_padding("abc"), "abc=".to_string()); - assert_eq!(add_padding("abcd"), "abcd".to_string()); - - assert_eq!(add_padding("a==="), "a===".to_string()); - assert_eq!(add_padding("ab=="), "ab==".to_string()); - assert_eq!(add_padding("abc="), "abc=".to_string()); - assert_eq!(add_padding("abcd"), "abcd".to_string()); - } - - // #[test] - // fn test_should_force_token_refresh() { - // assert_eq!(should_force_token_refresh(None), false); - // assert_eq!(should_force_token_refresh(Some(Ok(true))), true); - // assert_eq!(should_force_token_refresh(Some(Ok(false))), false); - // assert_eq!( - // should_force_token_refresh(Some(Err(Errors::new().with_name("yes")))), - // true - // ); - // assert_eq!( - // should_force_token_refresh(Some(Err(Errors::new().with_name("no")))), - // true - // ); - // assert_eq!( - // should_force_token_refresh(Some(Err(Errors::new().with_name("")))), - // true - // ); - // } -} diff --git a/src/orangutan-server/src/routes/debug_routes.rs b/src/orangutan-server/src/routes/debug_routes.rs index 2f558a8..3537c33 100644 --- a/src/orangutan-server/src/routes/debug_routes.rs +++ b/src/orangutan-server/src/routes/debug_routes.rs @@ -1,11 +1,14 @@ use std::sync::{Arc, RwLock}; +use axum::{routing::get, Router}; +use axum_extra::extract::PrivateCookieJar; use chrono::{DateTime, Utc}; use lazy_static::lazy_static; -use rocket::{get, http::CookieJar, routes, Route}; -use super::auth_routes::REVOKED_TOKENS; -use crate::{request_guards::Token, Error}; +use crate::{ + request_guards::{Token, REVOKED_TOKENS}, + Error, +}; lazy_static! { /// A list of runtime errors, used to show error logs in an admin page @@ -19,21 +22,21 @@ lazy_static! { pub(crate) static ref ACCESS_LOGS: Arc>> = Arc::default(); } -pub(super) fn routes() -> Vec { - let routes = routes![ - clear_cookies, - get_user_info, - errors, - access_logs, - revoked_tokens, - ]; +pub(super) fn router() -> Router { + let mut router = Router::new() + .route("/clear-cookies", get(clear_cookies).put(clear_cookies)) + .route("/_info", get(get_user_info)) + .route("/_errors", get(errors)) + .route("/_access-logs", get(access_logs)) + .route("/_revoked-tokens", get(revoked_tokens)); #[cfg(feature = "token-generator")] - let routes = vec![routes, routes![ - token_generator::token_generation_form, - token_generator::generate_token, - ]] - .concat(); - routes + { + router = router.route( + "/_generate-token", + get(token_generator::token_generation_form).post(token_generator::generate_token), + ); + } + router } #[cfg(feature = "templating")] @@ -44,8 +47,7 @@ pub(super) fn templates() -> Vec<(&'static str, &'static str)> { )] } -#[get("/clear-cookies")] -fn clear_cookies(cookies: &CookieJar<'_>) -> &'static str { +fn clear_cookies(cookies: PrivateCookieJar) -> &'static str { for cookie in cookies.iter().map(Clone::clone) { cookies.remove(cookie); } @@ -53,7 +55,6 @@ fn clear_cookies(cookies: &CookieJar<'_>) -> &'static str { "Success" } -#[get("/_info")] fn get_user_info(token: Option) -> String { match token { Some(Token { biscuit, .. }) => format!( @@ -73,7 +74,6 @@ pub struct ErrorLog { pub line: String, } -#[get("/_errors")] fn errors(token: Token) -> Result { if !token.profiles().contains(&"*".to_owned()) { Err(Error::Unauthorized)? @@ -100,7 +100,6 @@ pub struct AccessLog { pub path: String, } -#[get("/_access-logs")] fn access_logs(token: Token) -> Result { if !token.profiles().contains(&"*".to_owned()) { Err(Error::Unauthorized)? @@ -139,7 +138,6 @@ pub fn log_access( }) } -#[get("/_revoked-tokens")] fn revoked_tokens(token: Token) -> Result { if !token.profiles().contains(&"*".to_owned()) { Err(Error::Forbidden)? @@ -156,49 +154,45 @@ fn revoked_tokens(token: Token) -> Result { #[cfg(feature = "token-generator")] pub mod token_generator { + use axum::{extract::State, Form}; + use axum_extra::response::Html; use orangutan_refresh_token::RefreshToken; - use rocket::{ - form::{Form, Strict}, - get, post, - response::content::RawHtml, - FromForm, State, - }; + use serde::Deserialize; use crate::{ context, request_guards::Token, util::{templating::render, WebsiteRoot}, - Error, + AppState, Error, }; fn token_generation_form_( - tera: &State, + tera: &tera::Tera, link: Option, base_url: &str, - ) -> Result, Error> { + ) -> Result, Error> { let html = render( tera, "generate-token.html", context! { page_title: "Access token generator", link, base_url }, )?; - Ok(RawHtml(html)) + Ok(Html(html)) } - #[get("/_generate-token")] pub fn token_generation_form( token: Token, - tera: &State, + State(app_state): State, website_root: WebsiteRoot, - ) -> Result, Error> { + ) -> Result, Error> { if !token.profiles().contains(&"*".to_owned()) { Err(Error::Unauthorized)? } - token_generation_form_(tera, None, &website_root) + token_generation_form_(&app_state.tera, None, &website_root) } - #[derive(FromForm)] + #[derive(Deserialize)] pub struct GenerateTokenForm { ttl: String, name: String, @@ -206,13 +200,12 @@ pub mod token_generator { url: String, } - #[post("/_generate-token", data = "
")] pub fn generate_token( token: Token, - tera: &State, - form: Form>, + State(app_state): State, + Form(form): Form, website_root: WebsiteRoot, - ) -> Result, Error> { + ) -> Result, Error> { if !token.profiles().contains(&"*".to_owned()) { Err(Error::Unauthorized)? } @@ -229,6 +222,6 @@ pub mod token_generator { let token_base64 = token.as_base64()?; let link = format!("{}?refresh_token={token_base64}", form.url); - token_generation_form_(tera, Some(link), &website_root) + token_generation_form_(&app_state.tera, Some(link), &website_root) } } diff --git a/src/orangutan-server/src/routes/main_route.rs b/src/orangutan-server/src/routes/main_route.rs index bd91661..6bf230a 100644 --- a/src/orangutan-server/src/routes/main_route.rs +++ b/src/orangutan-server/src/routes/main_route.rs @@ -1,84 +1,73 @@ -use std::{path::Path, time::SystemTime}; +use std::{path::PathBuf, str::FromStr, time::SystemTime}; +use axum::{ + extract::Path, + http::{header::ACCEPT, HeaderMap}, + routing::get, + Router, +}; use biscuit_auth::macros::authorizer; -use object_reader::{ObjectReader, ReadObjectResponse}; -use orangutan_helpers::{data_file, read_allowed, readers::object_reader, website_id::WebsiteId}; -use rocket::{ - get, - http::{uri::Origin, Accept}, - routes, Route, State, +use mime::Mime; +use orangutan_helpers::{ + page_metadata, + website_id::{website_dir, WebsiteId}, }; +use tower_http::services::{ServeDir, ServeFile}; use tracing::{debug, trace}; use crate::{config::*, request_guards::Token, routes::debug_routes::log_access, util::error}; -pub(super) fn routes() -> Vec { - routes![handle_request] +pub(super) fn router() -> Router { + Router::new() + .route("/", get(handle_request)) + .route("/*path", get(handle_request)) } -#[get("/<_..>")] async fn handle_request( - origin: &Origin<'_>, + Path(path): Path, token: Option, - object_reader: &State, - accept: Option<&Accept>, -) -> Result, crate::Error> { + headers: HeaderMap, +) -> Result, crate::Error> { // FIXME: Handle error - let path = urlencoding::decode(origin.path().as_str()) - .unwrap() - .into_owned(); + let path = path; trace!("GET {}", &path); let user_profiles: Vec = token.as_ref().map(Token::profiles).unwrap_or_default(); debug!("User has profiles {user_profiles:?}"); let website_id = WebsiteId::from(&user_profiles); + let website_dir = website_dir(&website_id); // Log access only if the page is HTML. // WARN: This solution is far from perfect as someone requesting a page without setting the `Accept` header // would not be logged even though they'd get the file back. - if accept.is_some_and(|a| a.media_types().find(|t| t.is_html()).is_some()) { + let accept = headers + .get(ACCEPT) + .map(|value| value.parse::().ok()) + .flatten(); + if accept.is_some_and(|m| m.type_() == mime::HTML) { log_access(user_profiles.to_owned(), path.to_owned()); } - let stored_objects: Vec = - object_reader - .list_objects(&path, &website_id) - .map_err(|err| Error::CannotListObjects { - path: path.to_owned(), - err, - })?; - let Some(object_key) = matching_files(&path, &stored_objects) - .first() - .map(|o| o.to_owned()) - else { - error(format!( - "No file matching '{}' found in stored objects", - &path - )); - return Ok(None); + let page_relpath = PathBuf::from_str(&path).unwrap(); + let Some(page_metadata) = page_metadata(&page_relpath)? else { + // If metadata can't be found, it means it's a static file + trace!("File <{path> did not explicitly allow profiles, serving static file"); + // TODO: Un-hardcode this value. + return ServeDir::new("static") + .not_found_service(ServeFile::new(website_dir.join(NOT_FOUND_FILE))); }; - let allowed_profiles = allowed_profiles(&object_key); - let Some(allowed_profiles) = allowed_profiles else { - // If allowed profiles is empty, it means it's a static file - trace!( - "File <{}> did not explicitly allow profiles, serving static file", - &path - ); + let allowed_profiles = page_metadata.read_allowed; + // debug!( + // "Page <{}> can be read by {}", + // &path, + // allowed_profiles + // .iter() + // .map(|p| format!("'{}'", p)) + // .collect::>() + // .join(", ") + // ); - return Ok(Some( - object_reader.read_object(&object_key, &website_id).await, - )); - }; - debug!( - "Page <{}> can be read by {}", - &path, - allowed_profiles - .iter() - .map(|p| format!("'{}'", p)) - .collect::>() - .join(", ") - ); let mut profile: Option = None; let biscuit = token.map(|t| t.biscuit); for allowed_profile in allowed_profiles { @@ -116,119 +105,10 @@ async fn handle_request( return Ok(None); } - Ok(Some( - object_reader.read_object(object_key, &website_id).await, - )) -} + let page_abspath = website_dir.join(page_metadata.path); -fn allowed_profiles<'r>(path: &String) -> Option> { - let path = path.rsplit_once("@").unwrap_or((path, "")).0; - let data_file = data_file(&Path::new(path).to_path_buf()); - read_allowed(&data_file) -} - -fn matching_files<'a>( - query: &str, - stored_objects: &'a Vec, -) -> Vec<&'a String> { - stored_objects - .into_iter() - .filter(|p| { - let query = query.strip_suffix("index.html").unwrap_or(query); - let Some(mut p) = p.strip_prefix(query) else { - return false; - }; - p = p.trim_start_matches('/'); - p = p.strip_prefix("index.html").unwrap_or(p); - return p.is_empty() || p.starts_with('@'); - }) - .collect() -} - -#[derive(Debug, thiserror::Error)] -pub enum Error { - #[error("Error when listing objects matching '{path}': {err}")] - CannotListObjects { - path: String, - err: object_reader::Error, - }, -} + ServeFile::new(page_abspath); + ServeFile::new(website_dir.join(NOT_FOUND_FILE)); -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_index_html() { - let stored_objects = vec![ - "/index.html@_default", - "/whatever/index.html@friends", - "/whatever/index.html@family", - "/whatever/p.html@_default", - "/whatever/index.htmlindex.html@_default", - "/whatever/other-page/index.html@_default", - "/whatever/a/b.html@_default", - ] - .into_iter() - .map(|p| p.to_string()) - .collect::>(); - - assert_eq!(matching_files("", &stored_objects), vec![ - "/index.html@_default", - ]); - assert_eq!(matching_files("/", &stored_objects), vec![ - "/index.html@_default", - ]); - assert_eq!(matching_files("/index.html", &stored_objects), vec![ - "/index.html@_default", - ]); - - assert_eq!(matching_files("/whatever", &stored_objects), vec![ - "/whatever/index.html@friends", - "/whatever/index.html@family", - ]); - assert_eq!(matching_files("/whatever/", &stored_objects), vec![ - "/whatever/index.html@friends", - "/whatever/index.html@family", - ]); - assert_eq!( - matching_files("/whatever/index.html", &stored_objects), - vec![ - "/whatever/index.html@friends", - "/whatever/index.html@family", - ] - ); - - assert_eq!( - matching_files("/whatever/a", &stored_objects), - Vec::<&str>::new() - ); - assert_eq!( - matching_files("/whatever/a/b", &stored_objects), - Vec::<&str>::new() - ); - assert_eq!(matching_files("/whatever/a/b.html", &stored_objects), vec![ - "/whatever/a/b.html@_default", - ]); - } - - #[test] - fn test_other_extensions() { - let stored_objects = vec![ - "/style.css@_default", - "/anything.custom@friends", - "/anything.custom@family", - ] - .into_iter() - .map(|p| p.to_string()) - .collect::>(); - - assert_eq!(matching_files("/style.css", &stored_objects), vec![ - "/style.css@_default", - ]); - assert_eq!(matching_files("/anything.custom", &stored_objects), vec![ - "/anything.custom@friends", - "/anything.custom@family", - ]); - } + panic!() } diff --git a/src/orangutan-server/src/routes/mod.rs b/src/orangutan-server/src/routes/mod.rs index e59d060..e3609ce 100644 --- a/src/orangutan-server/src/routes/mod.rs +++ b/src/orangutan-server/src/routes/mod.rs @@ -3,21 +3,17 @@ // Copyright: 2023–2024, Rémi Bardon // License: Mozilla Public License v2.0 (MPL v2.0) -pub mod auth_routes; pub mod debug_routes; pub mod main_route; pub mod update_content_routes; -use rocket::Route; +use axum::Router; -pub(super) fn routes() -> Vec { - vec![ - main_route::routes(), - auth_routes::routes(), - update_content_routes::routes(), - debug_routes::routes(), - ] - .concat() +pub(super) fn router() -> Router { + Router::new() + .merge(main_route::router()) + .merge(update_content_routes::router()) + .merge(debug_routes::router()) } #[cfg(feature = "templating")] diff --git a/src/orangutan-server/src/routes/update_content_routes.rs b/src/orangutan-server/src/routes/update_content_routes.rs index 276485e..6091db5 100644 --- a/src/orangutan-server/src/routes/update_content_routes.rs +++ b/src/orangutan-server/src/routes/update_content_routes.rs @@ -1,16 +1,16 @@ +use axum::{extract::Path, http::StatusCode, routing::post, Router}; use orangutan_helpers::generate::{self, *}; -use rocket::{post, response::status::BadRequest, routes, Route}; -use super::auth_routes::REVOKED_TOKENS; -use crate::error; +use crate::{error, request_guards::REVOKED_TOKENS}; -pub(super) fn routes() -> Vec { - routes![update_content_github, update_content_other] +pub(super) fn router() -> Router { + Router::new() + .route("/update-content/github", post(update_content_github)) + .route("/update-content/:source", post(update_content_other)) } /// TODO: [Validate webhook deliveries](https://docs.github.com/en/webhooks/using-webhooks/validating-webhook-deliveries#validating-webhook-deliveries) -#[post("/update-content/github")] -fn update_content_github() -> Result<(), crate::Error> { +async fn update_content_github() -> Result<(), crate::Error> { // Update repository pull_repository().map_err(Error::CannotPullOutdatedRepository)?; @@ -34,9 +34,11 @@ fn update_content_github() -> Result<(), crate::Error> { Ok(()) } -#[post("/update-content/")] -fn update_content_other(source: &str) -> BadRequest { - BadRequest(format!("Source '{source}' is not supported.")) +async fn update_content_other(Path(source): Path<&str>) -> (StatusCode, String) { + ( + StatusCode::BAD_REQUEST, + format!("Source '{source}' is not supported."), + ) } #[derive(Debug, thiserror::Error)] diff --git a/src/orangutan-server/src/util/mod.rs b/src/orangutan-server/src/util/mod.rs index 478e7e2..aac9227 100644 --- a/src/orangutan-server/src/util/mod.rs +++ b/src/orangutan-server/src/util/mod.rs @@ -3,12 +3,15 @@ pub mod templating; #[cfg(feature = "token-generator")] mod website_root; +use axum_extra::extract::{ + cookie::{Cookie, SameSite}, + PrivateCookieJar, +}; use biscuit_auth::{ builder::{Fact, Term}, Biscuit, }; use chrono::Utc; -use rocket::http::{Cookie, CookieJar, SameSite}; use time::Duration; use tracing::error; @@ -58,7 +61,7 @@ pub fn add_padding(base64_string: &str) -> String { pub fn add_cookie( biscuit: &Biscuit, - cookies: &CookieJar<'_>, + cookies: PrivateCookieJar, ) { match biscuit.to_base64() { Ok(base64) => { @@ -76,3 +79,40 @@ pub fn add_cookie( }, } } + +#[cfg(test)] +mod tests { + use super::add_padding; + + #[test] + fn test_base64_padding() { + assert_eq!(add_padding("a"), "a===".to_string()); + assert_eq!(add_padding("ab"), "ab==".to_string()); + assert_eq!(add_padding("abc"), "abc=".to_string()); + assert_eq!(add_padding("abcd"), "abcd".to_string()); + + assert_eq!(add_padding("a==="), "a===".to_string()); + assert_eq!(add_padding("ab=="), "ab==".to_string()); + assert_eq!(add_padding("abc="), "abc=".to_string()); + assert_eq!(add_padding("abcd"), "abcd".to_string()); + } + + // #[test] + // fn test_should_force_token_refresh() { + // assert_eq!(should_force_token_refresh(None), false); + // assert_eq!(should_force_token_refresh(Some(Ok(true))), true); + // assert_eq!(should_force_token_refresh(Some(Ok(false))), false); + // assert_eq!( + // should_force_token_refresh(Some(Err(Errors::new().with_name("yes")))), + // true + // ); + // assert_eq!( + // should_force_token_refresh(Some(Err(Errors::new().with_name("no")))), + // true + // ); + // assert_eq!( + // should_force_token_refresh(Some(Err(Errors::new().with_name("")))), + // true + // ); + // } +} diff --git a/src/orangutan-server/src/util/templating.rs b/src/orangutan-server/src/util/templating.rs index c4258d7..40ddda8 100644 --- a/src/orangutan-server/src/util/templating.rs +++ b/src/orangutan-server/src/util/templating.rs @@ -1,6 +1,3 @@ -use rocket::serde::Serialize; -use tera::Context; - use super::error; #[derive(Debug, thiserror::Error)] @@ -11,12 +8,12 @@ pub enum Error { RenderError(tera::Error), } -pub fn render( +pub fn render( tera: &tera::Tera, template: &str, context: C, ) -> Result { - let tera_ctx = Context::from_serialize(context).map_err(Error::ContextError)?; + let tera_ctx = tera::Context::from_serialize(context).map_err(Error::ContextError)?; tera.render(template, &tera_ctx).map_err(Error::RenderError) } @@ -90,7 +87,7 @@ pub fn render( #[macro_export] macro_rules! context { ($($key:ident $(: $value:expr)?),*$(,)?) => {{ - use rocket::serde::ser::{Serialize, Serializer, SerializeMap}; + use serde::ser::{Serialize, Serializer, SerializeMap}; use ::std::fmt::{Debug, Formatter}; use ::std::result::Result; diff --git a/src/orangutan-server/src/util/website_root.rs b/src/orangutan-server/src/util/website_root.rs index a0b41be..19a1c34 100644 --- a/src/orangutan-server/src/util/website_root.rs +++ b/src/orangutan-server/src/util/website_root.rs @@ -1,10 +1,6 @@ use std::ops::Deref; use lazy_static::lazy_static; -use rocket::{ - request::{FromRequest, Outcome}, - Ignite, Request, Rocket, -}; use tracing::error; lazy_static! { @@ -21,21 +17,21 @@ impl Deref for WebsiteRoot { } } -#[rocket::async_trait] -impl<'r> FromRequest<'r> for WebsiteRoot { - type Error = &'static str; - - async fn from_request(_req: &'r Request<'_>) -> Outcome { - Outcome::Success(Self(WEBSITE_ROOT.to_owned())) - } -} - -impl rocket::Sentinel for WebsiteRoot { - fn abort(_rocket: &Rocket) -> bool { - if WEBSITE_ROOT.is_empty() { - error!("Environment variable `WEBSITE_ROOT` not found."); - return true; - } - false - } -} +warn!("TODO"); +// impl<'r> FromRequest<'r> for WebsiteRoot { +// type Error = &'static str; + +// async fn from_request(_req: &'r Request<'_>) -> Outcome { +// Outcome::Success(Self(WEBSITE_ROOT.to_owned())) +// } +// } + +// impl rocket::Sentinel for WebsiteRoot { +// fn abort(_rocket: &Rocket) -> bool { +// if WEBSITE_ROOT.is_empty() { +// error!("Environment variable `WEBSITE_ROOT` not found."); +// return true; +// } +// false +// } +// }