1use axum_core::extract::{FromRequest, FromRequestParts, Request};
2use futures_util::future::BoxFuture;
3use std::{
4 any::type_name,
5 convert::Infallible,
6 fmt,
7 future::Future,
8 marker::PhantomData,
9 pin::Pin,
10 task::{Context, Poll},
11};
12use tower::util::BoxCloneSyncService;
13use tower_layer::Layer;
14use tower_service::Service;
15
16use crate::{
17 response::{IntoResponse, Response},
18 util::MapIntoResponse,
19};
20
21pub fn from_fn<F, T>(f: F) -> FromFnLayer<F, (), T> {
115 from_fn_with_state((), f)
116}
117
118pub fn from_fn_with_state<F, S, T>(state: S, f: F) -> FromFnLayer<F, S, T> {
165 FromFnLayer {
166 f,
167 state,
168 _extractor: PhantomData,
169 }
170}
171
/// A [`Layer`] that wraps services in [`FromFn`], running the stored async
/// function as middleware.
///
/// Created with [`from_fn`] or [`from_fn_with_state`].
#[must_use]
pub struct FromFnLayer<F, S, T> {
    // The middleware function; cloned into every `FromFn` this layer builds.
    f: F,
    // State cloned into every `FromFn` this layer builds.
    state: S,
    // Tracks the extractor tuple type `T` without storing a value. Using
    // `fn() -> T` keeps this field `Send`/`Sync` regardless of `T`.
    _extractor: PhantomData<fn() -> T>,
}
183
184impl<F, S, T> Clone for FromFnLayer<F, S, T>
185where
186 F: Clone,
187 S: Clone,
188{
189 fn clone(&self) -> Self {
190 Self {
191 f: self.f.clone(),
192 state: self.state.clone(),
193 _extractor: self._extractor,
194 }
195 }
196}
197
198impl<S, I, F, T> Layer<I> for FromFnLayer<F, S, T>
199where
200 F: Clone,
201 S: Clone,
202{
203 type Service = FromFn<F, S, I, T>;
204
205 fn layer(&self, inner: I) -> Self::Service {
206 FromFn {
207 f: self.f.clone(),
208 state: self.state.clone(),
209 inner,
210 _extractor: PhantomData,
211 }
212 }
213}
214
215impl<F, S, T> fmt::Debug for FromFnLayer<F, S, T>
216where
217 S: fmt::Debug,
218{
219 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
220 f.debug_struct("FromFnLayer")
221 .field("f", &format_args!("{}", type_name::<F>()))
223 .field("state", &self.state)
224 .finish()
225 }
226}
227
/// A middleware service created from an async function.
///
/// Built by [`FromFnLayer`]. On each request, the extractors run first. Then
/// `f` is called with the extracted values and a [`Next`] that wraps `inner`.
pub struct FromFn<F, S, I, T> {
    // The middleware function; cloned for every request.
    f: F,
    // The wrapped service that `Next` forwards to.
    inner: I,
    // State handed by reference to every extractor.
    state: S,
    // Tracks the extractor tuple type `T` without storing a value.
    _extractor: PhantomData<fn() -> T>,
}
237
238impl<F, S, I, T> Clone for FromFn<F, S, I, T>
239where
240 F: Clone,
241 I: Clone,
242 S: Clone,
243{
244 fn clone(&self) -> Self {
245 Self {
246 f: self.f.clone(),
247 inner: self.inner.clone(),
248 state: self.state.clone(),
249 _extractor: self._extractor,
250 }
251 }
252}
253
// Generates `impl Service<Request> for FromFn<F, S, I, ($($ty,)* $last,)>` for
// one extractor arity.
//
// `$ty` are the leading extractors. They run via `FromRequestParts` against the
// request parts. `$last` is the final extractor. It runs via `FromRequest` on
// the reassembled request, so it may consume the body. The middleware function
// receives all extracted values followed by a `Next`.
macro_rules! impl_service {
    (
        [$($ty:ident),*], $last:ident
    ) => {
        // `non_snake_case`: the extractor type idents are reused as local
        // variable names. `unused_mut`: with zero leading extractors, `parts`
        // is never borrowed mutably.
        #[allow(non_snake_case, unused_mut)]
        impl<F, Fut, Out, S, I, $($ty,)* $last> Service<Request> for FromFn<F, S, I, ($($ty,)* $last,)>
        where
            F: FnMut($($ty,)* $last, Next) -> Fut + Clone + Send + 'static,
            $( $ty: FromRequestParts<S> + Send, )*
            $last: FromRequest<S> + Send,
            Fut: Future<Output = Out> + Send + 'static,
            Out: IntoResponse + 'static,
            I: Service<Request, Error = Infallible>
                + Clone
                + Send
                + Sync
                + 'static,
            I::Response: IntoResponse,
            I::Future: Send + 'static,
            S: Clone + Send + Sync + 'static,
        {
            type Response = Response;
            type Error = Infallible;
            type Future = ResponseFuture;

            // Readiness is exactly the readiness of the wrapped service.
            fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
                self.inner.poll_ready(cx)
            }

            fn call(&mut self, req: Request) -> Self::Future {
                // `poll_ready` (above) drove `self.inner` to readiness, assuming the
                // caller followed the tower contract. Move that instance into this
                // request and leave a fresh clone behind for the next `poll_ready`.
                let not_ready_inner = self.inner.clone();
                let ready_inner = std::mem::replace(&mut self.inner, not_ready_inner);

                let mut f = self.f.clone();
                let state = self.state.clone();
                let (mut parts, body) = req.into_parts();

                let future = Box::pin(async move {
                    // Run each leading extractor, in order, against the parts.
                    // The first rejection becomes the response.
                    $(
                        let $ty = match $ty::from_request_parts(&mut parts, &state).await {
                            Ok(value) => value,
                            Err(rejection) => return rejection.into_response(),
                        };
                    )*

                    let req = Request::from_parts(parts, body);

                    // The last extractor takes ownership of the whole request.
                    let $last = match $last::from_request(req, &state).await {
                        Ok(value) => value,
                        Err(rejection) => return rejection.into_response(),
                    };

                    // Type-erase the ready inner service. Its responses are
                    // mapped into `Response` so that `Next` has a single type.
                    let inner = BoxCloneSyncService::new(MapIntoResponse::new(ready_inner));
                    let next = Next { inner };

                    f($($ty,)* $last, next).await.into_response()
                });

                ResponseFuture {
                    inner: future
                }
            }
        }
    };
}

// Instantiate `impl_service!` for each arity produced by `all_the_tuples!`
// (a crate-level macro defined elsewhere).
all_the_tuples!(impl_service);
321
322impl<F, S, I, T> fmt::Debug for FromFn<F, S, I, T>
323where
324 S: fmt::Debug,
325 I: fmt::Debug,
326{
327 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
328 f.debug_struct("FromFnLayer")
329 .field("f", &format_args!("{}", type_name::<F>()))
330 .field("inner", &self.inner)
331 .field("state", &self.state)
332 .finish()
333 }
334}
335
/// The remainder of the middleware stack, passed as the final argument to
/// the function given to [`from_fn`] / [`from_fn_with_state`].
#[derive(Debug, Clone)]
pub struct Next {
    // Type-erased inner service. It was built from the wrapped service via
    // `MapIntoResponse`, so its responses are already `Response`.
    inner: BoxCloneSyncService<Request, Response, Infallible>,
}
341
342impl Next {
343 pub async fn run(mut self, req: Request) -> Response {
345 match self.inner.call(req).await {
346 Ok(res) => res,
347 Err(err) => match err {},
348 }
349 }
350}
351
352impl Service<Request> for Next {
353 type Response = Response;
354 type Error = Infallible;
355 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
356
357 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
358 self.inner.poll_ready(cx)
359 }
360
361 fn call(&mut self, req: Request) -> Self::Future {
362 self.inner.call(req)
363 }
364}
365
/// Response future returned by the [`FromFn`] service.
pub struct ResponseFuture {
    // Boxed future built in `FromFn::call`. It yields a plain `Response`,
    // which `poll` wraps in `Ok`.
    inner: BoxFuture<'static, Response>,
}
370
371impl Future for ResponseFuture {
372 type Output = Result<Response, Infallible>;
373
374 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
375 self.inner.as_mut().poll(cx).map(Ok)
376 }
377}
378
379impl fmt::Debug for ResponseFuture {
380 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
381 f.debug_struct("ResponseFuture").finish()
382 }
383}
384
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{body::Body, routing::get, Router};
    use http::{HeaderMap, StatusCode};
    use http_body_util::BodyExt;
    use tower::ServiceExt;

    /// A `from_fn` middleware can modify the request before the handler runs.
    #[crate::test]
    async fn basic() {
        // Middleware: tag the request with a header, then continue the stack.
        async fn insert_header(mut req: Request, next: Next) -> impl IntoResponse {
            let headers = req.headers_mut();
            headers.insert("x-axum-test", "ok".parse().unwrap());
            next.run(req).await
        }

        // Handler: echo the header value back in the body.
        async fn handle(headers: HeaderMap) -> String {
            let value = headers["x-axum-test"].to_str().unwrap();
            value.to_owned()
        }

        let request = Request::builder().uri("/").body(Body::empty()).unwrap();
        let response = Router::new()
            .route("/", get(handle))
            .layer(from_fn(insert_header))
            .oneshot(request)
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::OK);
        let bytes = response.collect().await.unwrap().to_bytes();
        assert_eq!(&bytes[..], b"ok");
    }
}