1use crate::{
2 body::{Body, HttpBody},
3 response::Response,
4 util::MapIntoResponse,
5};
6use axum_core::{extract::Request, response::IntoResponse};
7use bytes::Bytes;
8use http::{
9 header::{self, CONTENT_LENGTH},
10 HeaderMap, HeaderValue, Method,
11};
12use pin_project_lite::pin_project;
13use std::{
14 convert::Infallible,
15 fmt,
16 future::Future,
17 pin::Pin,
18 task::{ready, Context, Poll},
19};
20use tower::{
21 util::{BoxCloneSyncService, MapErrLayer, Oneshot},
22 ServiceExt,
23};
24use tower_layer::Layer;
25use tower_service::Service;
26
27pub struct Route<E = Infallible>(BoxCloneSyncService<Request, Response, E>);
32
33impl<E> Route<E> {
34 pub(crate) fn new<T>(svc: T) -> Self
35 where
36 T: Service<Request, Error = E> + Clone + Send + Sync + 'static,
37 T::Response: IntoResponse + 'static,
38 T::Future: Send + 'static,
39 {
40 Self(BoxCloneSyncService::new(MapIntoResponse::new(svc)))
41 }
42
43 pub(crate) fn call_owned(self, req: Request<Body>) -> RouteFuture<E> {
45 let req = req.map(Body::new);
46 self.oneshot_inner_owned(req).not_top_level()
47 }
48
49 pub(crate) fn oneshot_inner(&mut self, req: Request) -> RouteFuture<E> {
50 let method = req.method().clone();
51 RouteFuture::new(method, self.0.clone().oneshot(req))
52 }
53
54 pub(crate) fn oneshot_inner_owned(self, req: Request) -> RouteFuture<E> {
56 let method = req.method().clone();
57 RouteFuture::new(method, self.0.oneshot(req))
58 }
59
60 pub(crate) fn layer<L, NewError>(self, layer: L) -> Route<NewError>
61 where
62 L: Layer<Route<E>> + Clone + Send + 'static,
63 L::Service: Service<Request> + Clone + Send + Sync + 'static,
64 <L::Service as Service<Request>>::Response: IntoResponse + 'static,
65 <L::Service as Service<Request>>::Error: Into<NewError> + 'static,
66 <L::Service as Service<Request>>::Future: Send + 'static,
67 NewError: 'static,
68 {
69 let layer = (MapErrLayer::new(Into::into), layer);
70
71 Route::new(layer.layer(self))
72 }
73}
74
75impl<E> Clone for Route<E> {
76 #[track_caller]
77 fn clone(&self) -> Self {
78 Self(self.0.clone())
79 }
80}
81
82impl<E> fmt::Debug for Route<E> {
83 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
84 f.debug_struct("Route").finish()
85 }
86}
87
88impl<B, E> Service<Request<B>> for Route<E>
89where
90 B: HttpBody<Data = bytes::Bytes> + Send + 'static,
91 B::Error: Into<axum_core::BoxError>,
92{
93 type Response = Response;
94 type Error = E;
95 type Future = RouteFuture<E>;
96
97 #[inline]
98 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
99 Poll::Ready(Ok(()))
100 }
101
102 #[inline]
103 fn call(&mut self, req: Request<B>) -> Self::Future {
104 self.oneshot_inner(req.map(Body::new)).not_top_level()
105 }
106}
107
108pin_project! {
109 pub struct RouteFuture<E> {
111 #[pin]
112 inner: Oneshot<BoxCloneSyncService<Request, Response, E>, Request>,
113 method: Method,
114 allow_header: Option<Bytes>,
115 top_level: bool,
116 }
117}
118
119impl<E> RouteFuture<E> {
120 fn new(
121 method: Method,
122 inner: Oneshot<BoxCloneSyncService<Request, Response, E>, Request>,
123 ) -> Self {
124 Self {
125 inner,
126 method,
127 allow_header: None,
128 top_level: true,
129 }
130 }
131
132 pub(crate) fn allow_header(mut self, allow_header: Bytes) -> Self {
133 self.allow_header = Some(allow_header);
134 self
135 }
136
137 pub(crate) fn not_top_level(mut self) -> Self {
138 self.top_level = false;
139 self
140 }
141}
142
143impl<E> Future for RouteFuture<E> {
144 type Output = Result<Response, E>;
145
146 #[inline]
147 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
148 let this = self.project();
149 let mut res = ready!(this.inner.poll(cx))?;
150
151 if *this.method == Method::CONNECT && res.status().is_success() {
152 if res.headers().contains_key(&CONTENT_LENGTH)
157 || res.headers().contains_key(&header::TRANSFER_ENCODING)
158 || res.size_hint().lower() != 0
159 {
160 error!("response to CONNECT with nonempty body");
161 res = res.map(|_| Body::empty());
162 }
163 } else if *this.top_level {
164 set_allow_header(res.headers_mut(), this.allow_header);
165
166 set_content_length(res.size_hint(), res.headers_mut());
168
169 if *this.method == Method::HEAD {
170 *res.body_mut() = Body::empty();
171 }
172 }
173
174 Poll::Ready(Ok(res))
175 }
176}
177
178fn set_allow_header(headers: &mut HeaderMap, allow_header: &mut Option<Bytes>) {
179 match allow_header.take() {
180 Some(allow_header) if !headers.contains_key(header::ALLOW) => {
181 headers.insert(
182 header::ALLOW,
183 HeaderValue::from_maybe_shared(allow_header).expect("invalid `Allow` header"),
184 );
185 }
186 _ => {}
187 }
188}
189
190fn set_content_length(size_hint: http_body::SizeHint, headers: &mut HeaderMap) {
191 if headers.contains_key(CONTENT_LENGTH) {
192 return;
193 }
194
195 if let Some(size) = size_hint.exact() {
196 let header_value = if size == 0 {
197 #[allow(clippy::declare_interior_mutable_const)]
198 const ZERO: HeaderValue = HeaderValue::from_static("0");
199
200 ZERO
201 } else {
202 let mut buffer = itoa::Buffer::new();
203 HeaderValue::from_str(buffer.format(size)).unwrap()
204 };
205
206 headers.insert(CONTENT_LENGTH, header_value);
207 }
208}
209
210pin_project! {
211 pub struct InfallibleRouteFuture {
213 #[pin]
214 future: RouteFuture<Infallible>,
215 }
216}
217
218impl InfallibleRouteFuture {
219 pub(crate) fn new(future: RouteFuture<Infallible>) -> Self {
220 Self { future }
221 }
222}
223
224impl Future for InfallibleRouteFuture {
225 type Output = Response;
226
227 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
228 match ready!(self.project().future.poll(cx)) {
229 Ok(response) => Poll::Ready(response),
230 Err(err) => match err {},
231 }
232 }
233}
234
#[cfg(test)]
mod tests {
    use super::*;

    /// Compile-time checks that routes and their futures can cross threads.
    #[test]
    fn traits() {
        fn assert_send<T: Send>() {}
        fn assert_sync<T: Sync>() {}

        assert_send::<Route<()>>();
        // `Route` is built on `BoxCloneSyncService` (and `Route::new` requires
        // `Sync`), so it must stay `Sync` as well.
        assert_sync::<Route<()>>();
        assert_send::<RouteFuture<()>>();
    }
}