diff --git a/crates/aide/src/axum/mod.rs b/crates/aide/src/axum/mod.rs index ba27822..7a576ee 100644 --- a/crates/aide/src/axum/mod.rs +++ b/crates/aide/src/axum/mod.rs @@ -192,6 +192,9 @@ use indexmap::IndexMap; use tower_layer::Layer; use tower_service::Service; +#[cfg(feature = "axum-extra")] +use axum_extra::routing::RouterExt as AxumExtraRouterExt; + use crate::{ transform::{TransformOpenApi, TransformPathItem}, util::path_colon_params, @@ -307,6 +310,29 @@ where self } + #[cfg(feature = "axum-extra")] + /// Create a route to the given method router with trailing slash removal and include it in + /// the API documentation. + /// + /// As opposed to [`route_with_tsr`](crate::axum::ApiRouter::route_with_tsr), this method only accepts an [`ApiMethodRouter`]. + /// + /// See [`axum_extra::routing::RouterExt::route_with_tsr`] for details. + #[tracing::instrument(skip_all, fields(% path))] + pub fn api_route_with_tsr(mut self, path: &str, mut method_router: ApiMethodRouter) -> Self { + in_context(|ctx| { + let new_path_item = method_router.take_path_item(); + + if let Some(path_item) = self.paths.get_mut(path) { + merge_paths(ctx, path, path_item, new_path_item); + } else { + self.paths.insert(path.into(), new_path_item); + } + }); + + self.router = self.router.route_with_tsr(path, method_router.router); + self + } + /// Create a route to the given method router and include it in /// the API documentation. /// @@ -338,6 +364,38 @@ where self } + #[cfg(feature = "axum-extra")] + /// Create a route to the given method router with trailing slash removal and include it in + /// the API documentation. + /// + /// This method accepts a transform function to edit + /// the generated API documentation with. + /// + /// See [`axum_extra::routing::RouterExt::route_with_tsr`] for details. + #[tracing::instrument(skip_all, fields(%path))] + pub fn api_route_with_tsr_with( + mut self, + path: &str, + mut method_router: ApiMethodRouter, + transform: impl FnOnce(TransformPathItem) -> TransformPathItem, + ) -> Self { + in_context(|ctx| { + let mut p = method_router.take_path_item(); + let t = transform(TransformPathItem::new(&mut p)); + + if !t.hidden { + if let Some(path_item) = self.paths.get_mut(path) { + merge_paths(ctx, path, path_item, p); + } else { + self.paths.insert(path.into(), p); + } + } + }); + + self.router = self.router.route_with_tsr(path, method_router.router); + self + } + /// Turn this router into an [`axum::Router`] while merging /// generated documentation into the provided [`OpenApi`]. #[tracing::instrument(skip_all)] @@ -430,6 +488,16 @@ where self.router = self.router.route(path, method_router.into().router); self } + + /// See [`axum_extra::routing::RouterExt::route_with_tsr`] for details. + /// + /// This method accepts [`ApiMethodRouter`] but does not generate API documentation. + #[cfg(feature = "axum-extra")] + #[tracing::instrument(skip_all)] + pub fn route_with_tsr(mut self, path: &str, method_router: impl Into>) -> Self { + self.router = self.router.route(path, method_router.into().router); + self + } /// See [`axum::Router::route_service`] for details. #[tracing::instrument(skip_all)] @@ -442,6 +510,19 @@ where self.router = self.router.route_service(path, service); self } + + /// See [`axum_extra::routing::RouterExt::route_service_with_tsr`] for details. + #[cfg(feature = "axum-extra")] + #[tracing::instrument(skip_all)] + pub fn route_service_with_tsr(mut self, path: &str, service: T) -> Self + where + T: Service + Clone + Send + 'static, + T::Response: IntoResponse, + T::Future: Send + 'static, + Self: Sized { + self.router = self.router.route_service_with_tsr(path, service); + self + } /// See [`axum::Router::nest`] for details. /// @@ -644,6 +725,9 @@ pub trait RouterExt: private::Sealed + Sized { /// /// This method additionally turns the router into an [`ApiRouter`]. fn api_route(self, path: &str, method_router: ApiMethodRouter) -> ApiRouter; + #[cfg(feature = "axum-extra")] + /// Add an API route, see [`ApiRouter::api_route_with_tsr`](crate::axum::ApiRouter::api_route_with_tsr) + fn api_route_with_tsr(self, path: &str, method_router: ApiMethodRouter) -> ApiRouter; } impl RouterExt for Router @@ -659,6 +743,12 @@ where fn api_route(self, path: &str, method_router: ApiMethodRouter) -> ApiRouter { ApiRouter::from(self).api_route(path, method_router) } + + #[cfg(feature = "axum-extra")] + #[tracing::instrument(skip_all)] + fn api_route_with_tsr(self, path: &str, method_router: ApiMethodRouter) -> ApiRouter { + ApiRouter::from(self).api_route_with_tsr(path, method_router) + } } impl private::Sealed for Router {}