axum/routing/
route.rs

1use crate::{
2    body::{Body, HttpBody},
3    response::Response,
4    util::MapIntoResponse,
5};
6use axum_core::{extract::Request, response::IntoResponse};
7use bytes::Bytes;
8use http::{
9    header::{self, CONTENT_LENGTH},
10    HeaderMap, HeaderValue, Method,
11};
12use pin_project_lite::pin_project;
13use std::{
14    convert::Infallible,
15    fmt,
16    future::Future,
17    pin::Pin,
18    task::{ready, Context, Poll},
19};
20use tower::{
21    util::{BoxCloneSyncService, MapErrLayer, Oneshot},
22    ServiceExt,
23};
24use tower_layer::Layer;
25use tower_service::Service;
26
27/// How routes are stored inside a [`Router`](super::Router).
28///
29/// You normally shouldn't need to care about this type. It's used in
30/// [`Router::layer`](super::Router::layer).
31pub struct Route<E = Infallible>(BoxCloneSyncService<Request, Response, E>);
32
33impl<E> Route<E> {
34    pub(crate) fn new<T>(svc: T) -> Self
35    where
36        T: Service<Request, Error = E> + Clone + Send + Sync + 'static,
37        T::Response: IntoResponse + 'static,
38        T::Future: Send + 'static,
39    {
40        Self(BoxCloneSyncService::new(MapIntoResponse::new(svc)))
41    }
42
43    /// Variant of [`Route::call`] that takes ownership of the route to avoid cloning.
44    pub(crate) fn call_owned(self, req: Request<Body>) -> RouteFuture<E> {
45        let req = req.map(Body::new);
46        self.oneshot_inner_owned(req).not_top_level()
47    }
48
49    pub(crate) fn oneshot_inner(&mut self, req: Request) -> RouteFuture<E> {
50        let method = req.method().clone();
51        RouteFuture::new(method, self.0.clone().oneshot(req))
52    }
53
54    /// Variant of [`Route::oneshot_inner`] that takes ownership of the route to avoid cloning.
55    pub(crate) fn oneshot_inner_owned(self, req: Request) -> RouteFuture<E> {
56        let method = req.method().clone();
57        RouteFuture::new(method, self.0.oneshot(req))
58    }
59
60    pub(crate) fn layer<L, NewError>(self, layer: L) -> Route<NewError>
61    where
62        L: Layer<Route<E>> + Clone + Send + 'static,
63        L::Service: Service<Request> + Clone + Send + Sync + 'static,
64        <L::Service as Service<Request>>::Response: IntoResponse + 'static,
65        <L::Service as Service<Request>>::Error: Into<NewError> + 'static,
66        <L::Service as Service<Request>>::Future: Send + 'static,
67        NewError: 'static,
68    {
69        let layer = (MapErrLayer::new(Into::into), layer);
70
71        Route::new(layer.layer(self))
72    }
73}
74
75impl<E> Clone for Route<E> {
76    #[track_caller]
77    fn clone(&self) -> Self {
78        Self(self.0.clone())
79    }
80}
81
82impl<E> fmt::Debug for Route<E> {
83    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
84        f.debug_struct("Route").finish()
85    }
86}
87
88impl<B, E> Service<Request<B>> for Route<E>
89where
90    B: HttpBody<Data = bytes::Bytes> + Send + 'static,
91    B::Error: Into<axum_core::BoxError>,
92{
93    type Response = Response;
94    type Error = E;
95    type Future = RouteFuture<E>;
96
97    #[inline]
98    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
99        Poll::Ready(Ok(()))
100    }
101
102    #[inline]
103    fn call(&mut self, req: Request<B>) -> Self::Future {
104        self.oneshot_inner(req.map(Body::new)).not_top_level()
105    }
106}
107
108pin_project! {
109    /// Response future for [`Route`].
110    pub struct RouteFuture<E> {
111        #[pin]
112        inner: Oneshot<BoxCloneSyncService<Request, Response, E>, Request>,
113        method: Method,
114        allow_header: Option<Bytes>,
115        top_level: bool,
116    }
117}
118
119impl<E> RouteFuture<E> {
120    fn new(
121        method: Method,
122        inner: Oneshot<BoxCloneSyncService<Request, Response, E>, Request>,
123    ) -> Self {
124        Self {
125            inner,
126            method,
127            allow_header: None,
128            top_level: true,
129        }
130    }
131
132    pub(crate) fn allow_header(mut self, allow_header: Bytes) -> Self {
133        self.allow_header = Some(allow_header);
134        self
135    }
136
137    pub(crate) fn not_top_level(mut self) -> Self {
138        self.top_level = false;
139        self
140    }
141}
142
143impl<E> Future for RouteFuture<E> {
144    type Output = Result<Response, E>;
145
146    #[inline]
147    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
148        let this = self.project();
149        let mut res = ready!(this.inner.poll(cx))?;
150
151        if *this.method == Method::CONNECT && res.status().is_success() {
152            // From https://httpwg.org/specs/rfc9110.html#CONNECT:
153            // > A server MUST NOT send any Transfer-Encoding or
154            // > Content-Length header fields in a 2xx (Successful)
155            // > response to CONNECT.
156            if res.headers().contains_key(&CONTENT_LENGTH)
157                || res.headers().contains_key(&header::TRANSFER_ENCODING)
158                || res.size_hint().lower() != 0
159            {
160                error!("response to CONNECT with nonempty body");
161                res = res.map(|_| Body::empty());
162            }
163        } else if *this.top_level {
164            set_allow_header(res.headers_mut(), this.allow_header);
165
166            // make sure to set content-length before removing the body
167            set_content_length(res.size_hint(), res.headers_mut());
168
169            if *this.method == Method::HEAD {
170                *res.body_mut() = Body::empty();
171            }
172        }
173
174        Poll::Ready(Ok(res))
175    }
176}
177
178fn set_allow_header(headers: &mut HeaderMap, allow_header: &mut Option<Bytes>) {
179    match allow_header.take() {
180        Some(allow_header) if !headers.contains_key(header::ALLOW) => {
181            headers.insert(
182                header::ALLOW,
183                HeaderValue::from_maybe_shared(allow_header).expect("invalid `Allow` header"),
184            );
185        }
186        _ => {}
187    }
188}
189
190fn set_content_length(size_hint: http_body::SizeHint, headers: &mut HeaderMap) {
191    if headers.contains_key(CONTENT_LENGTH) {
192        return;
193    }
194
195    if let Some(size) = size_hint.exact() {
196        let header_value = if size == 0 {
197            #[allow(clippy::declare_interior_mutable_const)]
198            const ZERO: HeaderValue = HeaderValue::from_static("0");
199
200            ZERO
201        } else {
202            let mut buffer = itoa::Buffer::new();
203            HeaderValue::from_str(buffer.format(size)).unwrap()
204        };
205
206        headers.insert(CONTENT_LENGTH, header_value);
207    }
208}
209
210pin_project! {
211    /// A [`RouteFuture`] that always yields a [`Response`].
212    pub struct InfallibleRouteFuture {
213        #[pin]
214        future: RouteFuture<Infallible>,
215    }
216}
217
218impl InfallibleRouteFuture {
219    pub(crate) fn new(future: RouteFuture<Infallible>) -> Self {
220        Self { future }
221    }
222}
223
224impl Future for InfallibleRouteFuture {
225    type Output = Response;
226
227    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
228        match ready!(self.project().future.poll(cx)) {
229            Ok(response) => Poll::Ready(response),
230            Err(err) => match err {},
231        }
232    }
233}
234
#[cfg(test)]
mod tests {
    use super::*;

    /// Compile-time check that `Route` is `Send`.
    #[test]
    fn traits() {
        crate::test_helpers::assert_send::<Route<()>>();
    }
}