axum/routing/
method_routing.rs

1//! Route to services and handlers based on HTTP methods.
2
3use super::{future::InfallibleRouteFuture, IntoMakeService};
4#[cfg(feature = "tokio")]
5use crate::extract::connect_info::IntoMakeServiceWithConnectInfo;
6use crate::{
7    body::{Body, Bytes, HttpBody},
8    boxed::BoxedIntoRoute,
9    error_handling::{HandleError, HandleErrorLayer},
10    handler::Handler,
11    http::{Method, StatusCode},
12    response::Response,
13    routing::{future::RouteFuture, Fallback, MethodFilter, Route},
14};
15use axum_core::{extract::Request, response::IntoResponse, BoxError};
16use bytes::BytesMut;
17use std::{
18    borrow::Cow,
19    convert::Infallible,
20    fmt,
21    task::{Context, Poll},
22};
23use tower::service_fn;
24use tower_layer::Layer;
25use tower_service::Service;
26
// Generates a free function `$name(svc)` returning a `MethodRouter` that routes
// `$method` requests to `svc` by forwarding to
// `on_service(MethodFilter::$method, svc)`.
//
// `macro_rules!` tries arms top to bottom. The `GET` and `CONNECT` arms must
// stay before the generic `$method:ident` arm, otherwise those identifiers
// would match the generic arm and lose their dedicated docs. Every
// documentation arm re-invokes the macro with doc attributes prepended; the
// final `$(#[$m:meta])+` arm matches that shape and emits the function.
macro_rules! top_level_service_fn {
    // `GET`: full example plus the note about `HEAD` requests.
    (
        $name:ident, GET
    ) => {
        top_level_service_fn!(
            /// Route `GET` requests to the given service.
            ///
            /// # Example
            ///
            /// ```rust
            /// use axum::{
            ///     extract::Request,
            ///     Router,
            ///     routing::get_service,
            ///     body::Body,
            /// };
            /// use http::Response;
            /// use std::convert::Infallible;
            ///
            /// let service = tower::service_fn(|request: Request| async {
            ///     Ok::<_, Infallible>(Response::new(Body::empty()))
            /// });
            ///
            /// // Requests to `GET /` will go to `service`.
            /// let app = Router::new().route("/", get_service(service));
            /// # let _: Router = app;
            /// ```
            ///
            /// Note that `get` routes will also be called for `HEAD` requests but will have
            /// the response body removed. Make sure to add explicit `HEAD` routes
            /// afterwards.
            $name,
            GET
        );
    };

    // `CONNECT`: refers readers to `MethodFilter::CONNECT`.
    (
        $name:ident, CONNECT
    ) => {
        top_level_service_fn!(
            /// Route `CONNECT` requests to the given service.
            ///
            /// See [`MethodFilter::CONNECT`] for when you'd want to use this,
            /// and [`get_service`] for an example.
            $name,
            CONNECT
        );
    };

    // Any other method: one-line doc built with `concat!`/`stringify!`.
    (
        $name:ident, $method:ident
    ) => {
        top_level_service_fn!(
            #[doc = concat!("Route `", stringify!($method) ,"` requests to the given service.")]
            ///
            /// See [`get_service`] for an example.
            $name,
            $method
        );
    };

    // Terminal arm: attach the collected doc attributes and emit the function.
    (
        $(#[$m:meta])+
        $name:ident, $method:ident
    ) => {
        $(#[$m])+
        pub fn $name<T, S>(svc: T) -> MethodRouter<S, T::Error>
        where
            T: Service<Request> + Clone + Send + Sync + 'static,
            T::Response: IntoResponse + 'static,
            T::Future: Send + 'static,
            S: Clone,
        {
            on_service(MethodFilter::$method, svc)
        }
    };
}
104
// Generates a free function `$name(handler)` returning a `MethodRouter` that
// routes `$method` requests to `handler` by forwarding to
// `on(MethodFilter::$method, handler)`.
//
// Arm order matters: `GET` and `CONNECT` must precede the generic
// `$method:ident` arm so they receive their dedicated docs. Each documentation
// arm re-invokes the macro with doc attributes prepended; the final arm matches
// that shape and emits the function.
macro_rules! top_level_handler_fn {
    // `GET`: full example plus the note about `HEAD` requests.
    (
        $name:ident, GET
    ) => {
        top_level_handler_fn!(
            /// Route `GET` requests to the given handler.
            ///
            /// # Example
            ///
            /// ```rust
            /// use axum::{
            ///     routing::get,
            ///     Router,
            /// };
            ///
            /// async fn handler() {}
            ///
            /// // Requests to `GET /` will go to `handler`.
            /// let app = Router::new().route("/", get(handler));
            /// # let _: Router = app;
            /// ```
            ///
            /// Note that `get` routes will also be called for `HEAD` requests but will have
            /// the response body removed. Make sure to add explicit `HEAD` routes
            /// afterwards.
            $name,
            GET
        );
    };

    // `CONNECT`: refers readers to `MethodFilter::CONNECT`.
    (
        $name:ident, CONNECT
    ) => {
        top_level_handler_fn!(
            /// Route `CONNECT` requests to the given handler.
            ///
            /// See [`MethodFilter::CONNECT`] for when you'd want to use this,
            /// and [`get`] for an example.
            $name,
            CONNECT
        );
    };

    // Any other method: one-line doc built with `concat!`/`stringify!`.
    (
        $name:ident, $method:ident
    ) => {
        top_level_handler_fn!(
            #[doc = concat!("Route `", stringify!($method) ,"` requests to the given handler.")]
            ///
            /// See [`get`] for an example.
            $name,
            $method
        );
    };

    // Terminal arm: attach the collected doc attributes and emit the function.
    (
        $(#[$m:meta])+
        $name:ident, $method:ident
    ) => {
        $(#[$m])+
        pub fn $name<H, T, S>(handler: H) -> MethodRouter<S, Infallible>
        where
            H: Handler<T, S>,
            T: 'static,
            S: Clone + Send + Sync + 'static,
        {
            on(MethodFilter::$method, handler)
        }
    };
}
175
// Generates a `MethodRouter` method `$name(self, svc)` that chains `svc` for
// `$method` requests by forwarding to `self.on_service(MethodFilter::$method, svc)`.
//
// The generated method names the router's error type `E` without declaring it,
// so the macro must be invoked inside an `impl` block that has `E` in scope.
// Arm order matters: `GET` and `CONNECT` must precede the generic
// `$method:ident` arm; the final arm emits the method.
macro_rules! chained_service_fn {
    // `GET`: full example plus the note about `HEAD` requests.
    (
        $name:ident, GET
    ) => {
        chained_service_fn!(
            /// Chain an additional service that will only accept `GET` requests.
            ///
            /// # Example
            ///
            /// ```rust
            /// use axum::{
            ///     extract::Request,
            ///     Router,
            ///     routing::post_service,
            ///     body::Body,
            /// };
            /// use http::Response;
            /// use std::convert::Infallible;
            ///
            /// let service = tower::service_fn(|request: Request| async {
            ///     Ok::<_, Infallible>(Response::new(Body::empty()))
            /// });
            ///
            /// let other_service = tower::service_fn(|request: Request| async {
            ///     Ok::<_, Infallible>(Response::new(Body::empty()))
            /// });
            ///
            /// // Requests to `POST /` will go to `service` and `GET /` will go to
            /// // `other_service`.
            /// let app = Router::new().route("/", post_service(service).get_service(other_service));
            /// # let _: Router = app;
            /// ```
            ///
            /// Note that `get` routes will also be called for `HEAD` requests but will have
            /// the response body removed. Make sure to add explicit `HEAD` routes
            /// afterwards.
            $name,
            GET
        );
    };

    // `CONNECT`: refers readers to `MethodFilter::CONNECT`.
    (
        $name:ident, CONNECT
    ) => {
        chained_service_fn!(
            /// Chain an additional service that will only accept `CONNECT` requests.
            ///
            /// See [`MethodFilter::CONNECT`] for when you'd want to use this,
            /// and [`MethodRouter::get_service`] for an example.
            $name,
            CONNECT
        );
    };

    // Any other method: one-line doc built with `concat!`/`stringify!`.
    (
        $name:ident, $method:ident
    ) => {
        chained_service_fn!(
            #[doc = concat!("Chain an additional service that will only accept `", stringify!($method),"` requests.")]
            ///
            /// See [`MethodRouter::get_service`] for an example.
            $name,
            $method
        );
    };

    // Terminal arm: attach the collected doc attributes and emit the method.
    // `#[track_caller]` makes an overlap panic in `on_endpoint` point at the
    // user's call site.
    (
        $(#[$m:meta])+
        $name:ident, $method:ident
    ) => {
        $(#[$m])+
        #[track_caller]
        pub fn $name<T>(self, svc: T) -> Self
        where
            T: Service<Request, Error = E>
                + Clone
                + Send
                + Sync
                + 'static,
            T::Response: IntoResponse + 'static,
            T::Future: Send + 'static,
        {
            self.on_service(MethodFilter::$method, svc)
        }
    };
}
262
// Generates a `MethodRouter` method `$name(self, handler)` that chains
// `handler` for `$method` requests by forwarding to
// `self.on(MethodFilter::$method, handler)`.
//
// The generated method names the state type `S` without declaring it, so the
// macro must be invoked inside an `impl` block that has `S` in scope.
// Arm order matters: `GET` and `CONNECT` must precede the generic
// `$method:ident` arm; the final arm emits the method.
macro_rules! chained_handler_fn {
    // `GET`: full example plus the note about `HEAD` requests.
    (
        $name:ident, GET
    ) => {
        chained_handler_fn!(
            /// Chain an additional handler that will only accept `GET` requests.
            ///
            /// # Example
            ///
            /// ```rust
            /// use axum::{routing::post, Router};
            ///
            /// async fn handler() {}
            ///
            /// async fn other_handler() {}
            ///
            /// // Requests to `POST /` will go to `handler` and `GET /` will go to
            /// // `other_handler`.
            /// let app = Router::new().route("/", post(handler).get(other_handler));
            /// # let _: Router = app;
            /// ```
            ///
            /// Note that `get` routes will also be called for `HEAD` requests but will have
            /// the response body removed. Make sure to add explicit `HEAD` routes
            /// afterwards.
            $name,
            GET
        );
    };

    // `CONNECT`: refers readers to `MethodFilter::CONNECT`.
    (
        $name:ident, CONNECT
    ) => {
        chained_handler_fn!(
            /// Chain an additional handler that will only accept `CONNECT` requests.
            ///
            /// See [`MethodFilter::CONNECT`] for when you'd want to use this,
            /// and [`MethodRouter::get`] for an example.
            $name,
            CONNECT
        );
    };

    // Any other method: one-line doc built with `concat!`/`stringify!`.
    (
        $name:ident, $method:ident
    ) => {
        chained_handler_fn!(
            #[doc = concat!("Chain an additional handler that will only accept `", stringify!($method),"` requests.")]
            ///
            /// See [`MethodRouter::get`] for an example.
            $name,
            $method
        );
    };

    // Terminal arm: attach the collected doc attributes and emit the method.
    (
        $(#[$m:meta])+
        $name:ident, $method:ident
    ) => {
        $(#[$m])+
        #[track_caller]
        pub fn $name<H, T>(self, handler: H) -> Self
        where
            H: Handler<T, S>,
            T: 'static,
            S: Send + Sync + 'static,
        {
            self.on(MethodFilter::$method, handler)
        }
    };
}
334
// Free `*_service` constructors, one per HTTP method. Each starts a
// `MethodRouter` that routes that single method to a `Service`.
top_level_service_fn!(connect_service, CONNECT);
top_level_service_fn!(delete_service, DELETE);
top_level_service_fn!(get_service, GET);
top_level_service_fn!(head_service, HEAD);
top_level_service_fn!(options_service, OPTIONS);
top_level_service_fn!(patch_service, PATCH);
top_level_service_fn!(post_service, POST);
top_level_service_fn!(put_service, PUT);
top_level_service_fn!(trace_service, TRACE);
344
345/// Route requests with the given method to the service.
346///
347/// # Example
348///
349/// ```rust
350/// use axum::{
351///     extract::Request,
352///     routing::on,
353///     Router,
354///     body::Body,
355///     routing::{MethodFilter, on_service},
356/// };
357/// use http::Response;
358/// use std::convert::Infallible;
359///
360/// let service = tower::service_fn(|request: Request| async {
361///     Ok::<_, Infallible>(Response::new(Body::empty()))
362/// });
363///
364/// // Requests to `POST /` will go to `service`.
365/// let app = Router::new().route("/", on_service(MethodFilter::POST, service));
366/// # let _: Router = app;
367/// ```
368pub fn on_service<T, S>(filter: MethodFilter, svc: T) -> MethodRouter<S, T::Error>
369where
370    T: Service<Request> + Clone + Send + Sync + 'static,
371    T::Response: IntoResponse + 'static,
372    T::Future: Send + 'static,
373    S: Clone,
374{
375    MethodRouter::new().on_service(filter, svc)
376}
377
378/// Route requests to the given service regardless of its method.
379///
380/// # Example
381///
382/// ```rust
383/// use axum::{
384///     extract::Request,
385///     Router,
386///     routing::any_service,
387///     body::Body,
388/// };
389/// use http::Response;
390/// use std::convert::Infallible;
391///
392/// let service = tower::service_fn(|request: Request| async {
393///     Ok::<_, Infallible>(Response::new(Body::empty()))
394/// });
395///
396/// // All requests to `/` will go to `service`.
397/// let app = Router::new().route("/", any_service(service));
398/// # let _: Router = app;
399/// ```
400///
401/// Additional methods can still be chained:
402///
403/// ```rust
404/// use axum::{
405///     extract::Request,
406///     Router,
407///     routing::any_service,
408///     body::Body,
409/// };
410/// use http::Response;
411/// use std::convert::Infallible;
412///
413/// let service = tower::service_fn(|request: Request| async {
414///     # Ok::<_, Infallible>(Response::new(Body::empty()))
415///     // ...
416/// });
417///
418/// let other_service = tower::service_fn(|request: Request| async {
419///     # Ok::<_, Infallible>(Response::new(Body::empty()))
420///     // ...
421/// });
422///
423/// // `POST /` goes to `other_service`. All other requests go to `service`
424/// let app = Router::new().route("/", any_service(service).post_service(other_service));
425/// # let _: Router = app;
426/// ```
427pub fn any_service<T, S>(svc: T) -> MethodRouter<S, T::Error>
428where
429    T: Service<Request> + Clone + Send + Sync + 'static,
430    T::Response: IntoResponse + 'static,
431    T::Future: Send + 'static,
432    S: Clone,
433{
434    MethodRouter::new()
435        .fallback_service(svc)
436        .skip_allow_header()
437}
438
// Free handler constructors (`get`, `post`, ...), one per HTTP method. Each
// starts a `MethodRouter` that routes that single method to a `Handler`.
top_level_handler_fn!(connect, CONNECT);
top_level_handler_fn!(delete, DELETE);
top_level_handler_fn!(get, GET);
top_level_handler_fn!(head, HEAD);
top_level_handler_fn!(options, OPTIONS);
top_level_handler_fn!(patch, PATCH);
top_level_handler_fn!(post, POST);
top_level_handler_fn!(put, PUT);
top_level_handler_fn!(trace, TRACE);
448
449/// Route requests with the given method to the handler.
450///
451/// # Example
452///
453/// ```rust
454/// use axum::{
455///     routing::on,
456///     Router,
457///     routing::MethodFilter,
458/// };
459///
460/// async fn handler() {}
461///
462/// // Requests to `POST /` will go to `handler`.
463/// let app = Router::new().route("/", on(MethodFilter::POST, handler));
464/// # let _: Router = app;
465/// ```
466pub fn on<H, T, S>(filter: MethodFilter, handler: H) -> MethodRouter<S, Infallible>
467where
468    H: Handler<T, S>,
469    T: 'static,
470    S: Clone + Send + Sync + 'static,
471{
472    MethodRouter::new().on(filter, handler)
473}
474
475/// Route requests with the given handler regardless of the method.
476///
477/// # Example
478///
479/// ```rust
480/// use axum::{
481///     routing::any,
482///     Router,
483/// };
484///
485/// async fn handler() {}
486///
487/// // All requests to `/` will go to `handler`.
488/// let app = Router::new().route("/", any(handler));
489/// # let _: Router = app;
490/// ```
491///
492/// Additional methods can still be chained:
493///
494/// ```rust
495/// use axum::{
496///     routing::any,
497///     Router,
498/// };
499///
500/// async fn handler() {}
501///
502/// async fn other_handler() {}
503///
504/// // `POST /` goes to `other_handler`. All other requests go to `handler`
505/// let app = Router::new().route("/", any(handler).post(other_handler));
506/// # let _: Router = app;
507/// ```
508pub fn any<H, T, S>(handler: H) -> MethodRouter<S, Infallible>
509where
510    H: Handler<T, S>,
511    T: 'static,
512    S: Clone + Send + Sync + 'static,
513{
514    MethodRouter::new().fallback(handler).skip_allow_header()
515}
516
/// A [`Service`] that accepts requests based on a [`MethodFilter`] and
/// allows chaining additional handlers and services.
///
/// # When does `MethodRouter` implement [`Service`]?
///
/// Whether or not `MethodRouter` implements [`Service`] depends on the state type it requires.
///
/// ```
/// use tower::Service;
/// use axum::{routing::get, extract::{State, Request}, body::Body};
///
/// // this `MethodRouter` doesn't require any state, i.e. the state is `()`,
/// let method_router = get(|| async {});
/// // and thus it implements `Service`
/// assert_service(method_router);
///
/// // this requires a `String` and doesn't implement `Service`
/// let method_router = get(|_: State<String>| async {});
/// // until you provide the `String` with `.with_state(...)`
/// let method_router_with_state = method_router.with_state(String::new());
/// // and then it implements `Service`
/// assert_service(method_router_with_state);
///
/// // helper to check that a value implements `Service`
/// fn assert_service<S>(service: S)
/// where
///     S: Service<Request>,
/// {}
/// ```
#[must_use]
pub struct MethodRouter<S = (), E = Infallible> {
    // One endpoint slot per HTTP method. All start as `MethodEndpoint::None`
    // in `new`, and `on_endpoint` writes into them as routes are chained.
    get: MethodEndpoint<S, E>,
    head: MethodEndpoint<S, E>,
    delete: MethodEndpoint<S, E>,
    options: MethodEndpoint<S, E>,
    patch: MethodEndpoint<S, E>,
    post: MethodEndpoint<S, E>,
    put: MethodEndpoint<S, E>,
    trace: MethodEndpoint<S, E>,
    connect: MethodEndpoint<S, E>,
    /// Handles requests not routed to one of the method slots above. `new`
    /// installs a `405 Method Not Allowed` responder here; `fallback`,
    /// `fallback_service`, `any` and `any_service` replace it.
    fallback: Fallback<S, E>,
    /// Accumulated `Allow` header state, updated as method routes are added.
    allow_header: AllowHeader,
}
560
561#[derive(Clone, Debug)]
562enum AllowHeader {
563    /// No `Allow` header value has been built-up yet. This is the default state
564    None,
565    /// Don't set an `Allow` header. This is used when `any` or `any_service` are called.
566    Skip,
567    /// The current value of the `Allow` header.
568    Bytes(BytesMut),
569}
570
571impl AllowHeader {
572    fn merge(self, other: Self) -> Self {
573        match (self, other) {
574            (AllowHeader::Skip, _) | (_, AllowHeader::Skip) => AllowHeader::Skip,
575            (AllowHeader::None, AllowHeader::None) => AllowHeader::None,
576            (AllowHeader::None, AllowHeader::Bytes(pick)) => AllowHeader::Bytes(pick),
577            (AllowHeader::Bytes(pick), AllowHeader::None) => AllowHeader::Bytes(pick),
578            (AllowHeader::Bytes(mut a), AllowHeader::Bytes(b)) => {
579                a.extend_from_slice(b",");
580                a.extend_from_slice(&b);
581                AllowHeader::Bytes(a)
582            }
583        }
584    }
585}
586
587impl<S, E> fmt::Debug for MethodRouter<S, E> {
588    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
589        f.debug_struct("MethodRouter")
590            .field("get", &self.get)
591            .field("head", &self.head)
592            .field("delete", &self.delete)
593            .field("options", &self.options)
594            .field("patch", &self.patch)
595            .field("post", &self.post)
596            .field("put", &self.put)
597            .field("trace", &self.trace)
598            .field("connect", &self.connect)
599            .field("fallback", &self.fallback)
600            .field("allow_header", &self.allow_header)
601            .finish()
602    }
603}
604
605impl<S> MethodRouter<S, Infallible>
606where
607    S: Clone,
608{
609    /// Chain an additional handler that will accept requests matching the given
610    /// `MethodFilter`.
611    ///
612    /// # Example
613    ///
614    /// ```rust
615    /// use axum::{
616    ///     routing::get,
617    ///     Router,
618    ///     routing::MethodFilter
619    /// };
620    ///
621    /// async fn handler() {}
622    ///
623    /// async fn other_handler() {}
624    ///
625    /// // Requests to `GET /` will go to `handler` and `DELETE /` will go to
626    /// // `other_handler`
627    /// let app = Router::new().route("/", get(handler).on(MethodFilter::DELETE, other_handler));
628    /// # let _: Router = app;
629    /// ```
630    #[track_caller]
631    pub fn on<H, T>(self, filter: MethodFilter, handler: H) -> Self
632    where
633        H: Handler<T, S>,
634        T: 'static,
635        S: Send + Sync + 'static,
636    {
637        self.on_endpoint(
638            filter,
639            MethodEndpoint::BoxedHandler(BoxedIntoRoute::from_handler(handler)),
640        )
641    }
642
643    chained_handler_fn!(connect, CONNECT);
644    chained_handler_fn!(delete, DELETE);
645    chained_handler_fn!(get, GET);
646    chained_handler_fn!(head, HEAD);
647    chained_handler_fn!(options, OPTIONS);
648    chained_handler_fn!(patch, PATCH);
649    chained_handler_fn!(post, POST);
650    chained_handler_fn!(put, PUT);
651    chained_handler_fn!(trace, TRACE);
652
653    /// Add a fallback [`Handler`] to the router.
654    pub fn fallback<H, T>(mut self, handler: H) -> Self
655    where
656        H: Handler<T, S>,
657        T: 'static,
658        S: Send + Sync + 'static,
659    {
660        self.fallback = Fallback::BoxedHandler(BoxedIntoRoute::from_handler(handler));
661        self
662    }
663
664    /// Add a fallback [`Handler`] if no custom one has been provided.
665    pub(crate) fn default_fallback<H, T>(self, handler: H) -> Self
666    where
667        H: Handler<T, S>,
668        T: 'static,
669        S: Send + Sync + 'static,
670    {
671        match self.fallback {
672            Fallback::Default(_) => self.fallback(handler),
673            _ => self,
674        }
675    }
676}
677
678impl MethodRouter<(), Infallible> {
679    /// Convert the router into a [`MakeService`].
680    ///
681    /// This allows you to serve a single `MethodRouter` if you don't need any
682    /// routing based on the path:
683    ///
684    /// ```rust
685    /// use axum::{
686    ///     handler::Handler,
687    ///     http::{Uri, Method},
688    ///     response::IntoResponse,
689    ///     routing::get,
690    /// };
691    /// use std::net::SocketAddr;
692    ///
693    /// async fn handler(method: Method, uri: Uri, body: String) -> String {
694    ///     format!("received `{method} {uri}` with body `{body:?}`")
695    /// }
696    ///
697    /// let router = get(handler).post(handler);
698    ///
699    /// # async {
700    /// let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
701    /// axum::serve(listener, router.into_make_service()).await.unwrap();
702    /// # };
703    /// ```
704    ///
705    /// [`MakeService`]: tower::make::MakeService
706    pub fn into_make_service(self) -> IntoMakeService<Self> {
707        IntoMakeService::new(self.with_state(()))
708    }
709
710    /// Convert the router into a [`MakeService`] which stores information
711    /// about the incoming connection.
712    ///
713    /// See [`Router::into_make_service_with_connect_info`] for more details.
714    ///
715    /// ```rust
716    /// use axum::{
717    ///     handler::Handler,
718    ///     response::IntoResponse,
719    ///     extract::ConnectInfo,
720    ///     routing::get,
721    /// };
722    /// use std::net::SocketAddr;
723    ///
724    /// async fn handler(ConnectInfo(addr): ConnectInfo<SocketAddr>) -> String {
725    ///     format!("Hello {addr}")
726    /// }
727    ///
728    /// let router = get(handler).post(handler);
729    ///
730    /// # async {
731    /// let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
732    /// axum::serve(listener, router.into_make_service()).await.unwrap();
733    /// # };
734    /// ```
735    ///
736    /// [`MakeService`]: tower::make::MakeService
737    /// [`Router::into_make_service_with_connect_info`]: crate::routing::Router::into_make_service_with_connect_info
738    #[cfg(feature = "tokio")]
739    pub fn into_make_service_with_connect_info<C>(self) -> IntoMakeServiceWithConnectInfo<Self, C> {
740        IntoMakeServiceWithConnectInfo::new(self.with_state(()))
741    }
742}
743
744impl<S, E> MethodRouter<S, E>
745where
746    S: Clone,
747{
748    /// Create a default `MethodRouter` that will respond with `405 Method Not Allowed` to all
749    /// requests.
750    pub fn new() -> Self {
751        let fallback = Route::new(service_fn(|_: Request| async {
752            Ok(StatusCode::METHOD_NOT_ALLOWED)
753        }));
754
755        Self {
756            get: MethodEndpoint::None,
757            head: MethodEndpoint::None,
758            delete: MethodEndpoint::None,
759            options: MethodEndpoint::None,
760            patch: MethodEndpoint::None,
761            post: MethodEndpoint::None,
762            put: MethodEndpoint::None,
763            trace: MethodEndpoint::None,
764            connect: MethodEndpoint::None,
765            allow_header: AllowHeader::None,
766            fallback: Fallback::Default(fallback),
767        }
768    }
769
770    /// Provide the state for the router.
771    pub fn with_state<S2>(self, state: S) -> MethodRouter<S2, E> {
772        MethodRouter {
773            get: self.get.with_state(&state),
774            head: self.head.with_state(&state),
775            delete: self.delete.with_state(&state),
776            options: self.options.with_state(&state),
777            patch: self.patch.with_state(&state),
778            post: self.post.with_state(&state),
779            put: self.put.with_state(&state),
780            trace: self.trace.with_state(&state),
781            connect: self.connect.with_state(&state),
782            allow_header: self.allow_header,
783            fallback: self.fallback.with_state(state),
784        }
785    }
786
787    /// Chain an additional service that will accept requests matching the given
788    /// `MethodFilter`.
789    ///
790    /// # Example
791    ///
792    /// ```rust
793    /// use axum::{
794    ///     extract::Request,
795    ///     Router,
796    ///     routing::{MethodFilter, on_service},
797    ///     body::Body,
798    /// };
799    /// use http::Response;
800    /// use std::convert::Infallible;
801    ///
802    /// let service = tower::service_fn(|request: Request| async {
803    ///     Ok::<_, Infallible>(Response::new(Body::empty()))
804    /// });
805    ///
806    /// // Requests to `DELETE /` will go to `service`
807    /// let app = Router::new().route("/", on_service(MethodFilter::DELETE, service));
808    /// # let _: Router = app;
809    /// ```
810    #[track_caller]
811    pub fn on_service<T>(self, filter: MethodFilter, svc: T) -> Self
812    where
813        T: Service<Request, Error = E> + Clone + Send + Sync + 'static,
814        T::Response: IntoResponse + 'static,
815        T::Future: Send + 'static,
816    {
817        self.on_endpoint(filter, MethodEndpoint::Route(Route::new(svc)))
818    }
819
820    #[track_caller]
821    fn on_endpoint(mut self, filter: MethodFilter, endpoint: MethodEndpoint<S, E>) -> Self {
822        // written as a separate function to generate less IR
823        #[track_caller]
824        fn set_endpoint<S, E>(
825            method_name: &str,
826            out: &mut MethodEndpoint<S, E>,
827            endpoint: &MethodEndpoint<S, E>,
828            endpoint_filter: MethodFilter,
829            filter: MethodFilter,
830            allow_header: &mut AllowHeader,
831            methods: &[&'static str],
832        ) where
833            MethodEndpoint<S, E>: Clone,
834            S: Clone,
835        {
836            if endpoint_filter.contains(filter) {
837                if out.is_some() {
838                    panic!(
839                        "Overlapping method route. Cannot add two method routes that both handle \
840                         `{method_name}`",
841                    )
842                }
843                *out = endpoint.clone();
844                for method in methods {
845                    append_allow_header(allow_header, method);
846                }
847            }
848        }
849
850        set_endpoint(
851            "GET",
852            &mut self.get,
853            &endpoint,
854            filter,
855            MethodFilter::GET,
856            &mut self.allow_header,
857            &["GET", "HEAD"],
858        );
859
860        set_endpoint(
861            "HEAD",
862            &mut self.head,
863            &endpoint,
864            filter,
865            MethodFilter::HEAD,
866            &mut self.allow_header,
867            &["HEAD"],
868        );
869
870        set_endpoint(
871            "TRACE",
872            &mut self.trace,
873            &endpoint,
874            filter,
875            MethodFilter::TRACE,
876            &mut self.allow_header,
877            &["TRACE"],
878        );
879
880        set_endpoint(
881            "PUT",
882            &mut self.put,
883            &endpoint,
884            filter,
885            MethodFilter::PUT,
886            &mut self.allow_header,
887            &["PUT"],
888        );
889
890        set_endpoint(
891            "POST",
892            &mut self.post,
893            &endpoint,
894            filter,
895            MethodFilter::POST,
896            &mut self.allow_header,
897            &["POST"],
898        );
899
900        set_endpoint(
901            "PATCH",
902            &mut self.patch,
903            &endpoint,
904            filter,
905            MethodFilter::PATCH,
906            &mut self.allow_header,
907            &["PATCH"],
908        );
909
910        set_endpoint(
911            "OPTIONS",
912            &mut self.options,
913            &endpoint,
914            filter,
915            MethodFilter::OPTIONS,
916            &mut self.allow_header,
917            &["OPTIONS"],
918        );
919
920        set_endpoint(
921            "DELETE",
922            &mut self.delete,
923            &endpoint,
924            filter,
925            MethodFilter::DELETE,
926            &mut self.allow_header,
927            &["DELETE"],
928        );
929
930        set_endpoint(
931            "CONNECT",
932            &mut self.options,
933            &endpoint,
934            filter,
935            MethodFilter::CONNECT,
936            &mut self.allow_header,
937            &["CONNECT"],
938        );
939
940        self
941    }
942
    // Chained `*_service` methods, one per HTTP method. Each forwards to
    // `on_service` with the matching `MethodFilter`.
    chained_service_fn!(connect_service, CONNECT);
    chained_service_fn!(delete_service, DELETE);
    chained_service_fn!(get_service, GET);
    chained_service_fn!(head_service, HEAD);
    chained_service_fn!(options_service, OPTIONS);
    chained_service_fn!(patch_service, PATCH);
    chained_service_fn!(post_service, POST);
    chained_service_fn!(put_service, PUT);
    chained_service_fn!(trace_service, TRACE);
952
953    #[doc = include_str!("../docs/method_routing/fallback.md")]
954    pub fn fallback_service<T>(mut self, svc: T) -> Self
955    where
956        T: Service<Request, Error = E> + Clone + Send + Sync + 'static,
957        T::Response: IntoResponse + 'static,
958        T::Future: Send + 'static,
959    {
960        self.fallback = Fallback::Service(Route::new(svc));
961        self
962    }
963
964    #[doc = include_str!("../docs/method_routing/layer.md")]
965    pub fn layer<L, NewError>(self, layer: L) -> MethodRouter<S, NewError>
966    where
967        L: Layer<Route<E>> + Clone + Send + Sync + 'static,
968        L::Service: Service<Request> + Clone + Send + Sync + 'static,
969        <L::Service as Service<Request>>::Response: IntoResponse + 'static,
970        <L::Service as Service<Request>>::Error: Into<NewError> + 'static,
971        <L::Service as Service<Request>>::Future: Send + 'static,
972        E: 'static,
973        S: 'static,
974        NewError: 'static,
975    {
976        let layer_fn = move |route: Route<E>| route.layer(layer.clone());
977
978        MethodRouter {
979            get: self.get.map(layer_fn.clone()),
980            head: self.head.map(layer_fn.clone()),
981            delete: self.delete.map(layer_fn.clone()),
982            options: self.options.map(layer_fn.clone()),
983            patch: self.patch.map(layer_fn.clone()),
984            post: self.post.map(layer_fn.clone()),
985            put: self.put.map(layer_fn.clone()),
986            trace: self.trace.map(layer_fn.clone()),
987            connect: self.connect.map(layer_fn.clone()),
988            fallback: self.fallback.map(layer_fn),
989            allow_header: self.allow_header,
990        }
991    }
992
993    #[doc = include_str!("../docs/method_routing/route_layer.md")]
994    #[track_caller]
995    pub fn route_layer<L>(mut self, layer: L) -> MethodRouter<S, E>
996    where
997        L: Layer<Route<E>> + Clone + Send + Sync + 'static,
998        L::Service: Service<Request, Error = E> + Clone + Send + Sync + 'static,
999        <L::Service as Service<Request>>::Response: IntoResponse + 'static,
1000        <L::Service as Service<Request>>::Future: Send + 'static,
1001        E: 'static,
1002        S: 'static,
1003    {
1004        if self.get.is_none()
1005            && self.head.is_none()
1006            && self.delete.is_none()
1007            && self.options.is_none()
1008            && self.patch.is_none()
1009            && self.post.is_none()
1010            && self.put.is_none()
1011            && self.trace.is_none()
1012            && self.connect.is_none()
1013        {
1014            panic!(
1015                "Adding a route_layer before any routes is a no-op. \
1016                 Add the routes you want the layer to apply to first."
1017            );
1018        }
1019
1020        let layer_fn = move |svc| Route::new(layer.layer(svc));
1021
1022        self.get = self.get.map(layer_fn.clone());
1023        self.head = self.head.map(layer_fn.clone());
1024        self.delete = self.delete.map(layer_fn.clone());
1025        self.options = self.options.map(layer_fn.clone());
1026        self.patch = self.patch.map(layer_fn.clone());
1027        self.post = self.post.map(layer_fn.clone());
1028        self.put = self.put.map(layer_fn.clone());
1029        self.trace = self.trace.map(layer_fn.clone());
1030        self.connect = self.connect.map(layer_fn);
1031
1032        self
1033    }
1034
1035    pub(crate) fn merge_for_path(
1036        mut self,
1037        path: Option<&str>,
1038        other: MethodRouter<S, E>,
1039    ) -> Result<Self, Cow<'static, str>> {
1040        // written using inner functions to generate less IR
1041        fn merge_inner<S, E>(
1042            path: Option<&str>,
1043            name: &str,
1044            first: MethodEndpoint<S, E>,
1045            second: MethodEndpoint<S, E>,
1046        ) -> Result<MethodEndpoint<S, E>, Cow<'static, str>> {
1047            match (first, second) {
1048                (MethodEndpoint::None, MethodEndpoint::None) => Ok(MethodEndpoint::None),
1049                (pick, MethodEndpoint::None) | (MethodEndpoint::None, pick) => Ok(pick),
1050                _ => {
1051                    if let Some(path) = path {
1052                        Err(format!(
1053                            "Overlapping method route. Handler for `{name} {path}` already exists"
1054                        )
1055                        .into())
1056                    } else {
1057                        Err(format!(
1058                            "Overlapping method route. Cannot merge two method routes that both \
1059                             define `{name}`"
1060                        )
1061                        .into())
1062                    }
1063                }
1064            }
1065        }
1066
1067        self.get = merge_inner(path, "GET", self.get, other.get)?;
1068        self.head = merge_inner(path, "HEAD", self.head, other.head)?;
1069        self.delete = merge_inner(path, "DELETE", self.delete, other.delete)?;
1070        self.options = merge_inner(path, "OPTIONS", self.options, other.options)?;
1071        self.patch = merge_inner(path, "PATCH", self.patch, other.patch)?;
1072        self.post = merge_inner(path, "POST", self.post, other.post)?;
1073        self.put = merge_inner(path, "PUT", self.put, other.put)?;
1074        self.trace = merge_inner(path, "TRACE", self.trace, other.trace)?;
1075        self.connect = merge_inner(path, "CONNECT", self.connect, other.connect)?;
1076
1077        self.fallback = self
1078            .fallback
1079            .merge(other.fallback)
1080            .ok_or("Cannot merge two `MethodRouter`s that both have a fallback")?;
1081
1082        self.allow_header = self.allow_header.merge(other.allow_header);
1083
1084        Ok(self)
1085    }
1086
1087    #[doc = include_str!("../docs/method_routing/merge.md")]
1088    #[track_caller]
1089    pub fn merge(self, other: MethodRouter<S, E>) -> Self {
1090        match self.merge_for_path(None, other) {
1091            Ok(t) => t,
1092            // not using unwrap or unwrap_or_else to get a clean panic message + the right location
1093            Err(e) => panic!("{e}"),
1094        }
1095    }
1096
1097    /// Apply a [`HandleErrorLayer`].
1098    ///
1099    /// This is a convenience method for doing `self.layer(HandleErrorLayer::new(f))`.
1100    pub fn handle_error<F, T>(self, f: F) -> MethodRouter<S, Infallible>
1101    where
1102        F: Clone + Send + Sync + 'static,
1103        HandleError<Route<E>, F, T>: Service<Request, Error = Infallible>,
1104        <HandleError<Route<E>, F, T> as Service<Request>>::Future: Send,
1105        <HandleError<Route<E>, F, T> as Service<Request>>::Response: IntoResponse + Send,
1106        T: 'static,
1107        E: 'static,
1108        S: 'static,
1109    {
1110        self.layer(HandleErrorLayer::new(f))
1111    }
1112
1113    fn skip_allow_header(mut self) -> Self {
1114        self.allow_header = AllowHeader::Skip;
1115        self
1116    }
1117
    /// Dispatches `req` to the endpoint registered for its HTTP method.
    ///
    /// Endpoints are tried in the order of the `call!` invocations below and the first match
    /// wins. A `HEAD` request is served by the `head` endpoint if one is set, and otherwise by
    /// the `get` endpoint. Boxed handlers are turned into a [`Route`] on every call, using a
    /// clone of the handler and `state`. Plain routes are cloned and called directly.
    ///
    /// If no endpoint handles the method, `req` goes to the fallback. The resulting future is
    /// then given this router's `Allow` value:
    /// - an empty value for `AllowHeader::None`;
    /// - the accumulated bytes for `AllowHeader::Bytes`;
    /// - nothing for `AllowHeader::Skip`.
    pub(crate) fn call_with_state(&self, req: Request, state: S) -> RouteFuture<E> {
        // Tries one method slot. If the request method is `Method::$method_variant` and `$svc`
        // holds an endpoint, this `return`s that endpoint's future from `call_with_state`.
        // Note: the method check reads `req` directly; `$req` is only the value that is sent.
        macro_rules! call {
            (
                $req:expr,
                $method_variant:ident,
                $svc:expr
            ) => {
                if *req.method() == Method::$method_variant {
                    match $svc {
                        MethodEndpoint::None => {}
                        MethodEndpoint::Route(route) => {
                            return route.clone().oneshot_inner_owned($req);
                        }
                        MethodEndpoint::BoxedHandler(handler) => {
                            let route = handler.clone().into_route(state);
                            return route.oneshot_inner_owned($req);
                        }
                    }
                }
            };
        }

        // written with a pattern match like this to ensure we call all routes
        // (there is no `..` in the pattern, so a new `MethodRouter` field fails to compile
        // here until it is handled)
        let Self {
            get,
            head,
            delete,
            options,
            patch,
            post,
            put,
            trace,
            connect,
            fallback,
            allow_header,
        } = self;

        // An explicit HEAD endpoint takes precedence; otherwise HEAD falls through to GET.
        call!(req, HEAD, head);
        call!(req, HEAD, get);
        call!(req, GET, get);
        call!(req, POST, post);
        call!(req, OPTIONS, options);
        call!(req, PATCH, patch);
        call!(req, PUT, put);
        call!(req, DELETE, delete);
        call!(req, TRACE, trace);
        call!(req, CONNECT, connect);

        // No endpoint matched the request method: defer to the fallback.
        let future = fallback.clone().call_with_state(req, state);

        match allow_header {
            AllowHeader::None => future.allow_header(Bytes::new()),
            AllowHeader::Skip => future,
            AllowHeader::Bytes(allow_header) => future.allow_header(allow_header.clone().freeze()),
        }
    }
1174}
1175
1176fn append_allow_header(allow_header: &mut AllowHeader, method: &'static str) {
1177    match allow_header {
1178        AllowHeader::None => {
1179            *allow_header = AllowHeader::Bytes(BytesMut::from(method));
1180        }
1181        AllowHeader::Skip => {}
1182        AllowHeader::Bytes(allow_header) => {
1183            if let Ok(s) = std::str::from_utf8(allow_header) {
1184                if !s.contains(method) {
1185                    allow_header.extend_from_slice(b",");
1186                    allow_header.extend_from_slice(method.as_bytes());
1187                }
1188            } else {
1189                #[cfg(debug_assertions)]
1190                panic!("`allow_header` contained invalid uft-8. This should never happen")
1191            }
1192        }
1193    }
1194}
1195
1196impl<S, E> Clone for MethodRouter<S, E> {
1197    fn clone(&self) -> Self {
1198        Self {
1199            get: self.get.clone(),
1200            head: self.head.clone(),
1201            delete: self.delete.clone(),
1202            options: self.options.clone(),
1203            patch: self.patch.clone(),
1204            post: self.post.clone(),
1205            put: self.put.clone(),
1206            trace: self.trace.clone(),
1207            connect: self.connect.clone(),
1208            fallback: self.fallback.clone(),
1209            allow_header: self.allow_header.clone(),
1210        }
1211    }
1212}
1213
1214impl<S, E> Default for MethodRouter<S, E>
1215where
1216    S: Clone,
1217{
1218    fn default() -> Self {
1219        Self::new()
1220    }
1221}
1222
/// What a single HTTP-method slot of a [`MethodRouter`] is bound to.
enum MethodEndpoint<S, E> {
    /// Nothing is registered for this method.
    None,
    /// A route that needs no state; `with_state` passes it through unchanged.
    Route(Route<E>),
    /// A boxed handler that still needs the router state `S`. It is turned into a [`Route`]
    /// via `into_route` (see `with_state` and `MethodRouter::call_with_state`).
    BoxedHandler(BoxedIntoRoute<S, E>),
}
1228
1229impl<S, E> MethodEndpoint<S, E>
1230where
1231    S: Clone,
1232{
1233    fn is_some(&self) -> bool {
1234        matches!(self, Self::Route(_) | Self::BoxedHandler(_))
1235    }
1236
1237    fn is_none(&self) -> bool {
1238        matches!(self, Self::None)
1239    }
1240
1241    fn map<F, E2>(self, f: F) -> MethodEndpoint<S, E2>
1242    where
1243        S: 'static,
1244        E: 'static,
1245        F: FnOnce(Route<E>) -> Route<E2> + Clone + Send + Sync + 'static,
1246        E2: 'static,
1247    {
1248        match self {
1249            Self::None => MethodEndpoint::None,
1250            Self::Route(route) => MethodEndpoint::Route(f(route)),
1251            Self::BoxedHandler(handler) => MethodEndpoint::BoxedHandler(handler.map(f)),
1252        }
1253    }
1254
1255    fn with_state<S2>(self, state: &S) -> MethodEndpoint<S2, E> {
1256        match self {
1257            MethodEndpoint::None => MethodEndpoint::None,
1258            MethodEndpoint::Route(route) => MethodEndpoint::Route(route),
1259            MethodEndpoint::BoxedHandler(handler) => {
1260                MethodEndpoint::Route(handler.into_route(state.clone()))
1261            }
1262        }
1263    }
1264}
1265
1266impl<S, E> Clone for MethodEndpoint<S, E> {
1267    fn clone(&self) -> Self {
1268        match self {
1269            Self::None => Self::None,
1270            Self::Route(inner) => Self::Route(inner.clone()),
1271            Self::BoxedHandler(inner) => Self::BoxedHandler(inner.clone()),
1272        }
1273    }
1274}
1275
1276impl<S, E> fmt::Debug for MethodEndpoint<S, E> {
1277    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1278        match self {
1279            Self::None => f.debug_tuple("None").finish(),
1280            Self::Route(inner) => inner.fmt(f),
1281            Self::BoxedHandler(_) => f.debug_tuple("BoxedHandler").finish(),
1282        }
1283    }
1284}
1285
1286impl<B, E> Service<Request<B>> for MethodRouter<(), E>
1287where
1288    B: HttpBody<Data = Bytes> + Send + 'static,
1289    B::Error: Into<BoxError>,
1290{
1291    type Response = Response;
1292    type Error = E;
1293    type Future = RouteFuture<E>;
1294
1295    #[inline]
1296    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
1297        Poll::Ready(Ok(()))
1298    }
1299
1300    #[inline]
1301    fn call(&mut self, req: Request<B>) -> Self::Future {
1302        let req = req.map(Body::new);
1303        self.call_with_state(req, ())
1304    }
1305}
1306
1307impl<S> Handler<(), S> for MethodRouter<S>
1308where
1309    S: Clone + 'static,
1310{
1311    type Future = InfallibleRouteFuture;
1312
1313    fn call(self, req: Request, state: S) -> Self::Future {
1314        InfallibleRouteFuture::new(self.call_with_state(req, state))
1315    }
1316}
1317
// for `axum::serve(listener, router)`
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
const _: () = {
    use crate::serve;

    // A stateless `MethodRouter` acts as its own make-service: for every incoming stream
    // (presumably one per accepted connection) it hands out a clone of itself.
    impl<L> Service<serve::IncomingStream<'_, L>> for MethodRouter<()>
    where
        L: serve::Listener,
    {
        type Response = Self;
        type Error = Infallible;
        type Future = std::future::Ready<Result<Self::Response, Self::Error>>;

        fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, _stream: serve::IncomingStream<'_, L>) -> Self::Future {
            let router: Self = self.clone().with_state(());
            std::future::ready(Ok(router))
        }
    }
};
1340
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{extract::State, handler::HandlerWithoutStateExt};
    use http::{header::ALLOW, HeaderMap};
    use http_body_util::BodyExt;
    use std::time::Duration;
    use tower::ServiceExt;
    use tower_http::{
        services::fs::ServeDir, timeout::TimeoutLayer, validate_request::ValidateRequestHeaderLayer,
    };

    #[crate::test]
    async fn method_not_allowed_by_default() {
        let mut svc = MethodRouter::new();
        let (status, _, body) = call(Method::GET, &mut svc).await;
        assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
        assert!(body.is_empty());
    }

    #[crate::test]
    async fn get_service_fn() {
        async fn handle(_req: Request) -> Result<Response<Body>, Infallible> {
            Ok(Response::new(Body::from("ok")))
        }

        let mut svc = get_service(service_fn(handle));

        let (status, _, body) = call(Method::GET, &mut svc).await;
        assert_eq!(status, StatusCode::OK);
        assert_eq!(body, "ok");
    }

    #[crate::test]
    async fn get_handler() {
        let mut svc = MethodRouter::new().get(ok);
        let (status, _, body) = call(Method::GET, &mut svc).await;
        assert_eq!(status, StatusCode::OK);
        assert_eq!(body, "ok");
    }

    #[crate::test]
    async fn get_accepts_head() {
        let mut svc = MethodRouter::new().get(ok);
        let (status, _, body) = call(Method::HEAD, &mut svc).await;
        assert_eq!(status, StatusCode::OK);
        assert!(body.is_empty());
    }

    #[crate::test]
    async fn head_takes_precedence_over_get() {
        let mut svc = MethodRouter::new().head(created).get(ok);
        let (status, _, body) = call(Method::HEAD, &mut svc).await;
        assert_eq!(status, StatusCode::CREATED);
        assert!(body.is_empty());
    }

    #[crate::test]
    async fn merge() {
        let mut svc = get(ok).merge(post(ok));

        let (status, _, _) = call(Method::GET, &mut svc).await;
        assert_eq!(status, StatusCode::OK);

        let (status, _, _) = call(Method::POST, &mut svc).await;
        assert_eq!(status, StatusCode::OK);
    }

    #[crate::test]
    async fn layer() {
        let mut svc = MethodRouter::new()
            .get(|| async { std::future::pending::<()>().await })
            .layer(ValidateRequestHeaderLayer::bearer("password"));

        // method with route
        let (status, _, _) = call(Method::GET, &mut svc).await;
        assert_eq!(status, StatusCode::UNAUTHORIZED);

        // method without route
        let (status, _, _) = call(Method::DELETE, &mut svc).await;
        assert_eq!(status, StatusCode::UNAUTHORIZED);
    }

    #[crate::test]
    async fn route_layer() {
        let mut svc = MethodRouter::new()
            .get(|| async { std::future::pending::<()>().await })
            .route_layer(ValidateRequestHeaderLayer::bearer("password"));

        // method with route
        let (status, _, _) = call(Method::GET, &mut svc).await;
        assert_eq!(status, StatusCode::UNAUTHORIZED);

        // method without route
        let (status, _, _) = call(Method::DELETE, &mut svc).await;
        assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
    }

    // Compile-only check that a router combining many features type-checks. It is not a
    // `#[test]` because running it would bind a real socket and start serving.
    #[allow(dead_code)]
    async fn building_complex_router() {
        let app = crate::Router::new().route(
            "/",
            // use all the things 💣️
            get(ok)
                .post(ok)
                .route_layer(ValidateRequestHeaderLayer::bearer("password"))
                .merge(delete_service(ServeDir::new(".")))
                .fallback(|| async { StatusCode::NOT_FOUND })
                .put(ok)
                .layer(TimeoutLayer::new(Duration::from_secs(10))),
        );

        let listener = tokio::net::TcpListener::bind("0.0.0.0:0").await.unwrap();
        crate::serve(listener, app).await.unwrap();
    }

    #[crate::test]
    async fn sets_allow_header() {
        let mut svc = MethodRouter::new().put(ok).patch(ok);
        let (status, headers, _) = call(Method::GET, &mut svc).await;
        assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
        assert_eq!(headers[ALLOW], "PUT,PATCH");
    }

    #[crate::test]
    async fn sets_allow_header_get_head() {
        let mut svc = MethodRouter::new().get(ok).head(ok);
        let (status, headers, _) = call(Method::PUT, &mut svc).await;
        assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
        assert_eq!(headers[ALLOW], "GET,HEAD");
    }

    #[crate::test]
    async fn empty_allow_header_by_default() {
        let mut svc = MethodRouter::new();
        let (status, headers, _) = call(Method::PATCH, &mut svc).await;
        assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
        assert_eq!(headers[ALLOW], "");
    }

    #[crate::test]
    async fn allow_header_when_merging() {
        let a = put(ok).patch(ok);
        let b = get(ok).head(ok);
        let mut svc = a.merge(b);

        let (status, headers, _) = call(Method::DELETE, &mut svc).await;
        assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
        assert_eq!(headers[ALLOW], "PUT,PATCH,GET,HEAD");
    }

    #[crate::test]
    async fn allow_header_any() {
        let mut svc = any(ok);

        let (status, headers, _) = call(Method::GET, &mut svc).await;
        assert_eq!(status, StatusCode::OK);
        assert!(!headers.contains_key(ALLOW));
    }

    #[crate::test]
    async fn allow_header_with_fallback() {
        let mut svc = MethodRouter::new()
            .get(ok)
            .fallback(|| async { (StatusCode::METHOD_NOT_ALLOWED, "Method not allowed") });

        let (status, headers, _) = call(Method::DELETE, &mut svc).await;
        assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
        assert_eq!(headers[ALLOW], "GET,HEAD");
    }

    #[crate::test]
    async fn allow_header_with_fallback_that_sets_allow() {
        async fn fallback(method: Method) -> Response {
            if method == Method::POST {
                "OK".into_response()
            } else {
                (
                    StatusCode::METHOD_NOT_ALLOWED,
                    [(ALLOW, "GET,POST")],
                    "Method not allowed",
                )
                    .into_response()
            }
        }

        let mut svc = MethodRouter::new().get(ok).fallback(fallback);

        let (status, _, _) = call(Method::GET, &mut svc).await;
        assert_eq!(status, StatusCode::OK);

        let (status, _, _) = call(Method::POST, &mut svc).await;
        assert_eq!(status, StatusCode::OK);

        let (status, headers, _) = call(Method::DELETE, &mut svc).await;
        assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
        assert_eq!(headers[ALLOW], "GET,POST");
    }

    #[crate::test]
    async fn allow_header_noop_middleware() {
        let mut svc = MethodRouter::new()
            .get(ok)
            .layer(tower::layer::util::Identity::new());

        let (status, headers, _) = call(Method::DELETE, &mut svc).await;
        assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
        assert_eq!(headers[ALLOW], "GET,HEAD");
    }

    #[test]
    fn append_allow_header_joins_and_dedups() {
        let mut allow_header = AllowHeader::None;
        append_allow_header(&mut allow_header, "GET");
        append_allow_header(&mut allow_header, "POST");
        append_allow_header(&mut allow_header, "GET");
        match allow_header {
            AllowHeader::Bytes(bytes) => assert_eq!(&bytes[..], b"GET,POST"),
            _ => panic!("expected `AllowHeader::Bytes`"),
        }

        // `Skip` must stay `Skip` no matter what is appended.
        let mut skipped = AllowHeader::Skip;
        append_allow_header(&mut skipped, "GET");
        assert!(matches!(skipped, AllowHeader::Skip));
    }

    #[crate::test]
    #[should_panic(
        expected = "Overlapping method route. Cannot add two method routes that both handle `GET`"
    )]
    async fn handler_overlaps() {
        let _: MethodRouter<()> = get(ok).get(ok);
    }

    #[crate::test]
    #[should_panic(
        expected = "Overlapping method route. Cannot add two method routes that both handle `POST`"
    )]
    async fn service_overlaps() {
        let _: MethodRouter<()> = post_service(ok.into_service()).post_service(ok.into_service());
    }

    #[crate::test]
    async fn get_head_does_not_overlap() {
        let _: MethodRouter<()> = get(ok).head(ok);
    }

    #[crate::test]
    async fn head_get_does_not_overlap() {
        let _: MethodRouter<()> = head(ok).get(ok);
    }

    #[crate::test]
    async fn accessing_state() {
        let mut svc = MethodRouter::new()
            .get(|State(state): State<&'static str>| async move { state })
            .with_state("state");

        let (status, _, text) = call(Method::GET, &mut svc).await;

        assert_eq!(status, StatusCode::OK);
        assert_eq!(text, "state");
    }

    #[crate::test]
    async fn fallback_accessing_state() {
        let mut svc = MethodRouter::new()
            .fallback(|State(state): State<&'static str>| async move { state })
            .with_state("state");

        let (status, _, text) = call(Method::GET, &mut svc).await;

        assert_eq!(status, StatusCode::OK);
        assert_eq!(text, "state");
    }

    #[crate::test]
    async fn merge_accessing_state() {
        let one = get(|State(state): State<&'static str>| async move { state });
        let two = post(|State(state): State<&'static str>| async move { state });

        let mut svc = one.merge(two).with_state("state");

        let (status, _, text) = call(Method::GET, &mut svc).await;
        assert_eq!(status, StatusCode::OK);
        assert_eq!(text, "state");

        // Bind the POST body too; previously the stale GET `text` was re-asserted here, so the
        // POST handler's state access was never actually checked.
        let (status, _, text) = call(Method::POST, &mut svc).await;
        assert_eq!(status, StatusCode::OK);
        assert_eq!(text, "state");
    }

    /// Sends an empty-bodied request to `/` with `method` through `svc` and returns the
    /// response status, headers, and body decoded as UTF-8.
    async fn call<S>(method: Method, svc: &mut S) -> (StatusCode, HeaderMap, String)
    where
        S: Service<Request, Error = Infallible>,
        S::Response: IntoResponse,
    {
        let request = Request::builder()
            .uri("/")
            .method(method)
            .body(Body::empty())
            .unwrap();
        let response = svc
            .ready()
            .await
            .unwrap()
            .call(request)
            .await
            .unwrap()
            .into_response();
        let (parts, body) = response.into_parts();
        let body =
            String::from_utf8(BodyExt::collect(body).await.unwrap().to_bytes().to_vec()).unwrap();
        (parts.status, parts.headers, body)
    }

    /// Handler that responds `200 OK` with body `"ok"`.
    async fn ok() -> (StatusCode, &'static str) {
        (StatusCode::OK, "ok")
    }

    /// Handler that responds `201 Created` with body `"created"`.
    async fn created() -> (StatusCode, &'static str) {
        (StatusCode::CREATED, "created")
    }
}