// axum/middleware/from_fn.rs

1use axum_core::extract::{FromRequest, FromRequestParts, Request};
2use futures_util::future::BoxFuture;
3use std::{
4    any::type_name,
5    convert::Infallible,
6    fmt,
7    future::Future,
8    marker::PhantomData,
9    pin::Pin,
10    task::{Context, Poll},
11};
12use tower::util::BoxCloneSyncService;
13use tower_layer::Layer;
14use tower_service::Service;
15
16use crate::{
17    response::{IntoResponse, Response},
18    util::MapIntoResponse,
19};
20
21/// Create a middleware from an async function.
22///
23/// `from_fn` requires the function given to
24///
25/// 1. Be an `async fn`.
26/// 2. Take zero or more [`FromRequestParts`] extractors.
27/// 3. Take exactly one [`FromRequest`] extractor as the second to last argument.
28/// 4. Take [`Next`](Next) as the last argument.
29/// 5. Return something that implements [`IntoResponse`].
30///
31/// Note that this function doesn't support extracting [`State`]. For that, use [`from_fn_with_state`].
32///
33/// # Example
34///
35/// ```rust
36/// use axum::{
37///     Router,
38///     http,
39///     routing::get,
40///     response::Response,
41///     middleware::{self, Next},
42///     extract::Request,
43/// };
44///
45/// async fn my_middleware(
46///     request: Request,
47///     next: Next,
48/// ) -> Response {
49///     // do something with `request`...
50///
51///     let response = next.run(request).await;
52///
53///     // do something with `response`...
54///
55///     response
56/// }
57///
58/// let app = Router::new()
59///     .route("/", get(|| async { /* ... */ }))
60///     .layer(middleware::from_fn(my_middleware));
61/// # let app: Router = app;
62/// ```
63///
64/// # Running extractors
65///
66/// ```rust
67/// use axum::{
68///     Router,
69///     extract::Request,
70///     http::{StatusCode, HeaderMap},
71///     middleware::{self, Next},
72///     response::Response,
73///     routing::get,
74/// };
75///
76/// async fn auth(
77///     // run the `HeaderMap` extractor
78///     headers: HeaderMap,
79///     // you can also add more extractors here but the last
80///     // extractor must implement `FromRequest` which
81///     // `Request` does
82///     request: Request,
83///     next: Next,
84/// ) -> Result<Response, StatusCode> {
85///     match get_token(&headers) {
86///         Some(token) if token_is_valid(token) => {
87///             let response = next.run(request).await;
88///             Ok(response)
89///         }
90///         _ => {
91///             Err(StatusCode::UNAUTHORIZED)
92///         }
93///     }
94/// }
95///
96/// fn get_token(headers: &HeaderMap) -> Option<&str> {
97///     // ...
98///     # None
99/// }
100///
101/// fn token_is_valid(token: &str) -> bool {
102///     // ...
103///     # false
104/// }
105///
106/// let app = Router::new()
107///     .route("/", get(|| async { /* ... */ }))
108///     .route_layer(middleware::from_fn(auth));
109/// # let app: Router = app;
110/// ```
111///
112/// [extractors]: crate::extract::FromRequest
113/// [`State`]: crate::extract::State
114pub fn from_fn<F, T>(f: F) -> FromFnLayer<F, (), T> {
115    from_fn_with_state((), f)
116}
117
118/// Create a middleware from an async function with the given state.
119///
120/// For the requirements for the function supplied see [`from_fn`].
121///
122/// See [`State`](crate::extract::State) for more details about accessing state.
123///
124/// # Example
125///
126/// ```rust
127/// use axum::{
128///     Router,
129///     http::StatusCode,
130///     routing::get,
131///     response::{IntoResponse, Response},
132///     middleware::{self, Next},
133///     extract::{Request, State},
134/// };
135///
136/// #[derive(Clone)]
137/// struct AppState { /* ... */ }
138///
139/// async fn my_middleware(
140///     State(state): State<AppState>,
141///     // you can add more extractors here but the last
142///     // extractor must implement `FromRequest` which
143///     // `Request` does
144///     request: Request,
145///     next: Next,
146/// ) -> Response {
147///     // do something with `request`...
148///
149///     let response = next.run(request).await;
150///
151///     // do something with `response`...
152///
153///     response
154/// }
155///
156/// let state = AppState { /* ... */ };
157///
158/// let app = Router::new()
159///     .route("/", get(|| async { /* ... */ }))
160///     .route_layer(middleware::from_fn_with_state(state.clone(), my_middleware))
161///     .with_state(state);
162/// # let _: axum::Router = app;
163/// ```
164pub fn from_fn_with_state<F, S, T>(state: S, f: F) -> FromFnLayer<F, S, T> {
165    FromFnLayer {
166        f,
167        state,
168        _extractor: PhantomData,
169    }
170}
171
/// A [`tower::Layer`] that turns an async function into middleware.
///
/// Apply it to a [`Router`](crate::Router) to run the function for every request that
/// reaches the wrapped service.
///
/// Created with [`from_fn`] or [`from_fn_with_state`]. See those functions for more details.
#[must_use]
pub struct FromFnLayer<F, S, T> {
    /// The middleware function; cloned into every service this layer produces.
    f: F,
    /// State handed to the extractors; cloned into every service this layer produces.
    state: S,
    /// Records the extractor tuple `T` without storing a value of it. The `fn() -> T`
    /// form keeps the layer `Send`/`Sync` regardless of `T`.
    _extractor: PhantomData<fn() -> T>,
}
183
184impl<F, S, T> Clone for FromFnLayer<F, S, T>
185where
186    F: Clone,
187    S: Clone,
188{
189    fn clone(&self) -> Self {
190        Self {
191            f: self.f.clone(),
192            state: self.state.clone(),
193            _extractor: self._extractor,
194        }
195    }
196}
197
198impl<S, I, F, T> Layer<I> for FromFnLayer<F, S, T>
199where
200    F: Clone,
201    S: Clone,
202{
203    type Service = FromFn<F, S, I, T>;
204
205    fn layer(&self, inner: I) -> Self::Service {
206        FromFn {
207            f: self.f.clone(),
208            state: self.state.clone(),
209            inner,
210            _extractor: PhantomData,
211        }
212    }
213}
214
215impl<F, S, T> fmt::Debug for FromFnLayer<F, S, T>
216where
217    S: fmt::Debug,
218{
219    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
220        f.debug_struct("FromFnLayer")
221            // Write out the type name, without quoting it as `&type_name::<F>()` would
222            .field("f", &format_args!("{}", type_name::<F>()))
223            .field("state", &self.state)
224            .finish()
225    }
226}
227
/// A middleware [`Service`] created from an async function.
///
/// For each request, the function's extractors run first. The function is then called
/// with a [`Next`] that forwards to the wrapped service.
///
/// Created with [`from_fn`] or [`from_fn_with_state`]. See those functions for more details.
pub struct FromFn<F, S, I, T> {
    /// The middleware function; cloned for every request.
    f: F,
    /// The wrapped service, i.e. the rest of the stack.
    inner: I,
    /// State passed to the extractors; cloned for every request.
    state: S,
    /// Marker for the extractor tuple `T`.
    _extractor: PhantomData<fn() -> T>,
}
237
238impl<F, S, I, T> Clone for FromFn<F, S, I, T>
239where
240    F: Clone,
241    I: Clone,
242    S: Clone,
243{
244    fn clone(&self) -> Self {
245        Self {
246            f: self.f.clone(),
247            inner: self.inner.clone(),
248            state: self.state.clone(),
249            _extractor: self._extractor,
250        }
251    }
252}
253
254macro_rules! impl_service {
255    (
256        [$($ty:ident),*], $last:ident
257    ) => {
258        #[allow(non_snake_case, unused_mut)]
259        impl<F, Fut, Out, S, I, $($ty,)* $last> Service<Request> for FromFn<F, S, I, ($($ty,)* $last,)>
260        where
261            F: FnMut($($ty,)* $last, Next) -> Fut + Clone + Send + 'static,
262            $( $ty: FromRequestParts<S> + Send, )*
263            $last: FromRequest<S> + Send,
264            Fut: Future<Output = Out> + Send + 'static,
265            Out: IntoResponse + 'static,
266            I: Service<Request, Error = Infallible>
267                + Clone
268                + Send
269                + Sync
270                + 'static,
271            I::Response: IntoResponse,
272            I::Future: Send + 'static,
273            S: Clone + Send + Sync + 'static,
274        {
275            type Response = Response;
276            type Error = Infallible;
277            type Future = ResponseFuture;
278
279            fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
280                self.inner.poll_ready(cx)
281            }
282
283            fn call(&mut self, req: Request) -> Self::Future {
284                let not_ready_inner = self.inner.clone();
285                let ready_inner = std::mem::replace(&mut self.inner, not_ready_inner);
286
287                let mut f = self.f.clone();
288                let state = self.state.clone();
289                let (mut parts, body) = req.into_parts();
290
291                let future = Box::pin(async move {
292                    $(
293                        let $ty = match $ty::from_request_parts(&mut parts, &state).await {
294                            Ok(value) => value,
295                            Err(rejection) => return rejection.into_response(),
296                        };
297                    )*
298
299                    let req = Request::from_parts(parts, body);
300
301                    let $last = match $last::from_request(req, &state).await {
302                        Ok(value) => value,
303                        Err(rejection) => return rejection.into_response(),
304                    };
305
306                    let inner = BoxCloneSyncService::new(MapIntoResponse::new(ready_inner));
307                    let next = Next { inner };
308
309                    f($($ty,)* $last, next).await.into_response()
310                });
311
312                ResponseFuture {
313                    inner: future
314                }
315            }
316        }
317    };
318}
319
320all_the_tuples!(impl_service);
321
322impl<F, S, I, T> fmt::Debug for FromFn<F, S, I, T>
323where
324    S: fmt::Debug,
325    I: fmt::Debug,
326{
327    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
328        f.debug_struct("FromFnLayer")
329            .field("f", &format_args!("{}", type_name::<F>()))
330            .field("inner", &self.inner)
331            .field("state", &self.state)
332            .finish()
333    }
334}
335
336/// The remainder of a middleware stack, including the handler.
337#[derive(Debug, Clone)]
338pub struct Next {
339    inner: BoxCloneSyncService<Request, Response, Infallible>,
340}
341
342impl Next {
343    /// Execute the remaining middleware stack.
344    pub async fn run(mut self, req: Request) -> Response {
345        match self.inner.call(req).await {
346            Ok(res) => res,
347            Err(err) => match err {},
348        }
349    }
350}
351
352impl Service<Request> for Next {
353    type Response = Response;
354    type Error = Infallible;
355    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
356
357    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
358        self.inner.poll_ready(cx)
359    }
360
361    fn call(&mut self, req: Request) -> Self::Future {
362        self.inner.call(req)
363    }
364}
365
366/// Response future for [`FromFn`].
367pub struct ResponseFuture {
368    inner: BoxFuture<'static, Response>,
369}
370
371impl Future for ResponseFuture {
372    type Output = Result<Response, Infallible>;
373
374    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
375        self.inner.as_mut().poll(cx).map(Ok)
376    }
377}
378
379impl fmt::Debug for ResponseFuture {
380    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
381        f.debug_struct("ResponseFuture").finish()
382    }
383}
384
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{body::Body, routing::get, Router};
    use http::{HeaderMap, StatusCode};
    use http_body_util::BodyExt;
    use tower::ServiceExt;

    /// A middleware built with `from_fn` can modify the request before the handler sees it.
    #[crate::test]
    async fn basic() {
        // Middleware: stamp a header onto the request, then forward it.
        async fn insert_header(mut req: Request, next: Next) -> impl IntoResponse {
            req.headers_mut()
                .insert("x-axum-test", "ok".parse().unwrap());

            next.run(req).await
        }

        // Handler: echo back the header the middleware inserted.
        async fn handle(headers: HeaderMap) -> String {
            headers["x-axum-test"].to_str().unwrap().to_owned()
        }

        let app = Router::new()
            .route("/", get(handle))
            .layer(from_fn(insert_header));

        let request = Request::builder().uri("/").body(Body::empty()).unwrap();
        let response = app.oneshot(request).await.unwrap();
        assert_eq!(response.status(), StatusCode::OK);

        let body = response.collect().await.unwrap().to_bytes();
        assert_eq!(&body[..], b"ok");
    }
}