axum/extract/
nested_path.rs1use std::{
2 sync::Arc,
3 task::{Context, Poll},
4};
5
6use crate::extract::Request;
7use axum_core::extract::FromRequestParts;
8use http::request::Parts;
9use tower_layer::{layer_fn, Layer};
10use tower_service::Service;
11
12use super::rejection::NestedPathRejection;
13
/// The path at which the router handling the current request is nested.
///
/// The value is looked up in the request extensions, where it is expected to
/// have been placed when a router was nested. For example, a handler mounted
/// through `Router::new().nest("/api", api)` observes `"/api"`, and nesting
/// `"/v2"` inside `"/api"` yields `"/api/v2"`.
///
/// The path is stored behind an [`Arc`], so cloning a `NestedPath` does not
/// copy the string.
#[derive(Debug, Clone)]
pub struct NestedPath(Arc<str>);

impl NestedPath {
    /// Returns the nested path as a string slice.
    #[must_use]
    pub fn as_str(&self) -> &str {
        let Self(path) = self;
        path
    }
}
49
50#[diagnostic::do_not_recommend] impl<S> FromRequestParts<S> for NestedPath
52where
53 S: Send + Sync,
54{
55 type Rejection = NestedPathRejection;
56
57 async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
58 match parts.extensions.get::<Self>() {
59 Some(nested_path) => Ok(nested_path.clone()),
60 None => Err(NestedPathRejection),
61 }
62 }
63}
64
65#[derive(Clone)]
66pub(crate) struct SetNestedPath<S> {
67 inner: S,
68 path: Arc<str>,
69}
70
71impl<S> SetNestedPath<S> {
72 pub(crate) fn layer(path: &str) -> impl Layer<S, Service = Self> + Clone {
73 let path = Arc::from(path);
74 layer_fn(move |inner| Self {
75 inner,
76 path: Arc::clone(&path),
77 })
78 }
79}
80
81impl<S, B> Service<Request<B>> for SetNestedPath<S>
82where
83 S: Service<Request<B>>,
84{
85 type Response = S::Response;
86 type Error = S::Error;
87 type Future = S::Future;
88
89 #[inline]
90 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
91 self.inner.poll_ready(cx)
92 }
93
94 fn call(&mut self, mut req: Request<B>) -> Self::Future {
95 if let Some(prev) = req.extensions_mut().get_mut::<NestedPath>() {
96 let new_path = if prev.as_str() == "/" {
97 Arc::clone(&self.path)
98 } else {
99 format!("{}{}", prev.as_str().trim_end_matches('/'), self.path).into()
100 };
101 prev.0 = new_path;
102 } else {
103 req.extensions_mut()
104 .insert(NestedPath(Arc::clone(&self.path)));
105 };
106
107 self.inner.call(req)
108 }
109}
110
#[cfg(test)]
mod tests {
    use axum_core::response::Response;
    use http::StatusCode;

    use crate::{
        extract::{NestedPath, Request},
        middleware::{from_fn, Next},
        routing::get,
        test_helpers::*,
        Router,
    };

    /// Sends a GET request for `uri` to `app` and asserts the response is `200 OK`.
    async fn expect_ok(app: Router, uri: &str) {
        let client = TestClient::new(app);
        let response = client.get(uri).await;
        assert_eq!(response.status(), StatusCode::OK);
    }

    // A single `nest` call exposes its prefix unchanged.
    #[crate::test]
    async fn one_level_of_nesting() {
        let handler = get(|extracted: NestedPath| {
            assert_eq!(extracted.as_str(), "/api");
            async {}
        });
        let inner = Router::new().route("/users", handler);

        expect_ok(Router::new().nest("/api", inner), "/api/users").await;
    }

    // A trailing slash on a single prefix is kept as written.
    #[crate::test]
    async fn one_level_of_nesting_with_trailing_slash() {
        let handler = get(|extracted: NestedPath| {
            assert_eq!(extracted.as_str(), "/api/");
            async {}
        });
        let inner = Router::new().route("/users", handler);

        expect_ok(Router::new().nest("/api/", inner), "/api/users").await;
    }

    // Prefixes of nested routers are joined from the outside in.
    #[crate::test]
    async fn two_levels_of_nesting() {
        let handler = get(|extracted: NestedPath| {
            assert_eq!(extracted.as_str(), "/api/v2");
            async {}
        });
        let inner = Router::new().route("/users", handler);
        let versioned = Router::new().nest("/v2", inner);

        expect_ok(Router::new().nest("/api", versioned), "/api/v2/users").await;
    }

    // The outer prefix's trailing slash is dropped when prefixes are joined.
    #[crate::test]
    async fn two_levels_of_nesting_with_trailing_slash() {
        let handler = get(|extracted: NestedPath| {
            assert_eq!(extracted.as_str(), "/api/v2");
            async {}
        });
        let inner = Router::new().route("/users", handler);
        let versioned = Router::new().nest("/v2", inner);

        expect_ok(Router::new().nest("/api/", versioned), "/api/v2/users").await;
    }

    // Fallback handlers of a nested router can extract the prefix as well.
    #[crate::test]
    async fn in_fallbacks() {
        let handler = get(|extracted: NestedPath| {
            assert_eq!(extracted.as_str(), "/api");
            async {}
        });
        let inner = Router::new().fallback(handler);

        expect_ok(Router::new().nest("/api", inner), "/api/doesnt-exist").await;
    }

    // Middleware layered onto a nested router can extract the prefix as well.
    #[crate::test]
    async fn in_middleware() {
        async fn middleware(nested_path: NestedPath, req: Request, next: Next) -> Response {
            assert_eq!(nested_path.as_str(), "/api");
            next.run(req).await
        }

        let inner = Router::new()
            .route("/users", get(|| async {}))
            .layer(from_fn(middleware));

        expect_ok(Router::new().nest("/api", inner), "/api/users").await;
    }
}