1use super::{future::InfallibleRouteFuture, IntoMakeService};
4#[cfg(feature = "tokio")]
5use crate::extract::connect_info::IntoMakeServiceWithConnectInfo;
6use crate::{
7 body::{Body, Bytes, HttpBody},
8 boxed::BoxedIntoRoute,
9 error_handling::{HandleError, HandleErrorLayer},
10 handler::Handler,
11 http::{Method, StatusCode},
12 response::Response,
13 routing::{future::RouteFuture, Fallback, MethodFilter, Route},
14};
15use axum_core::{extract::Request, response::IntoResponse, BoxError};
16use bytes::BytesMut;
17use std::{
18 borrow::Cow,
19 convert::Infallible,
20 fmt,
21 task::{Context, Poll},
22};
23use tower::service_fn;
24use tower_layer::Layer;
25use tower_service::Service;
26
// Generates a free function such as `get_service(svc)` that builds a `MethodRouter`
// routing a single HTTP method to a `Service`.
//
// The `GET` and `CONNECT` arms attach method-specific documentation and every other
// method gets a generic doc line. Every arm funnels into the last arm, which requires
// at least one attribute (the doc comment) in front of the name. That is why the
// specialised arms must pass doc comments: without them they would match themselves
// again.
macro_rules! top_level_service_fn {
    (
        $name:ident, GET
    ) => {
        top_level_service_fn!(
            /// Route `GET` requests to the given service.
            ///
            /// `GET` endpoints are also used for `HEAD` requests unless a separate `HEAD`
            /// endpoint has been registered.
            $name,
            GET
        );
    };

    (
        $name:ident, CONNECT
    ) => {
        top_level_service_fn!(
            /// Route `CONNECT` requests to the given service.
            $name,
            CONNECT
        );
    };

    (
        $name:ident, $method:ident
    ) => {
        top_level_service_fn!(
            #[doc = concat!("Route `", stringify!($method) ,"` requests to the given service.")]
            $name,
            $method
        );
    };

    // Final arm: emit the actual function, forwarding the collected attributes.
    (
        $(#[$m:meta])+
        $name:ident, $method:ident
    ) => {
        $(#[$m])+
        pub fn $name<T, S>(svc: T) -> MethodRouter<S, T::Error>
        where
            T: Service<Request> + Clone + Send + Sync + 'static,
            T::Response: IntoResponse + 'static,
            T::Future: Send + 'static,
            S: Clone,
        {
            on_service(MethodFilter::$method, svc)
        }
    };
}
104
// Generates a free function such as `get(handler)` that builds a `MethodRouter`
// routing a single HTTP method to a `Handler`.
//
// As with `top_level_service_fn!`, the specialised `GET`/`CONNECT` arms must pass doc
// comments so that the recursive call reaches the attribute-taking final arm.
macro_rules! top_level_handler_fn {
    (
        $name:ident, GET
    ) => {
        top_level_handler_fn!(
            /// Route `GET` requests to the given handler.
            ///
            /// `GET` endpoints are also used for `HEAD` requests unless a separate `HEAD`
            /// endpoint has been registered.
            $name,
            GET
        );
    };

    (
        $name:ident, CONNECT
    ) => {
        top_level_handler_fn!(
            /// Route `CONNECT` requests to the given handler.
            $name,
            CONNECT
        );
    };

    (
        $name:ident, $method:ident
    ) => {
        top_level_handler_fn!(
            #[doc = concat!("Route `", stringify!($method) ,"` requests to the given handler.")]
            $name,
            $method
        );
    };

    // Final arm: emit the actual function, forwarding the collected attributes.
    (
        $(#[$m:meta])+
        $name:ident, $method:ident
    ) => {
        $(#[$m])+
        pub fn $name<H, T, S>(handler: H) -> MethodRouter<S, Infallible>
        where
            H: Handler<T, S>,
            T: 'static,
            S: Clone + Send + Sync + 'static,
        {
            on(MethodFilter::$method, handler)
        }
    };
}
175
// Generates a builder method such as `MethodRouter::get_service(self, svc)` that adds a
// `Service` for one HTTP method to an existing router.
//
// The specialised `GET`/`CONNECT` arms must pass doc comments so that the recursive call
// reaches the attribute-taking final arm instead of matching themselves again.
macro_rules! chained_service_fn {
    (
        $name:ident, GET
    ) => {
        chained_service_fn!(
            /// Chain an additional service that will only accept `GET` requests.
            ///
            /// `GET` endpoints are also used for `HEAD` requests unless a separate `HEAD`
            /// endpoint has been registered.
            ///
            /// # Panics
            ///
            /// Panics if an endpoint is already registered for `GET`.
            $name,
            GET
        );
    };

    (
        $name:ident, CONNECT
    ) => {
        chained_service_fn!(
            /// Chain an additional service that will only accept `CONNECT` requests.
            ///
            /// # Panics
            ///
            /// Panics if an endpoint is already registered for `CONNECT`.
            $name,
            CONNECT
        );
    };

    (
        $name:ident, $method:ident
    ) => {
        chained_service_fn!(
            #[doc = concat!("Chain an additional service that will only accept `", stringify!($method),"` requests.")]
            $name,
            $method
        );
    };

    // Final arm: emit the method. `track_caller` makes overlap panics point at user code.
    (
        $(#[$m:meta])+
        $name:ident, $method:ident
    ) => {
        $(#[$m])+
        #[track_caller]
        pub fn $name<T>(self, svc: T) -> Self
        where
            T: Service<Request, Error = E>
                + Clone
                + Send
                + Sync
                + 'static,
            T::Response: IntoResponse + 'static,
            T::Future: Send + 'static,
        {
            self.on_service(MethodFilter::$method, svc)
        }
    };
}
262
// Generates a builder method such as `MethodRouter::get(self, handler)` that adds a
// `Handler` for one HTTP method to an existing router.
//
// The specialised `GET`/`CONNECT` arms must pass doc comments so that the recursive call
// reaches the attribute-taking final arm instead of matching themselves again.
macro_rules! chained_handler_fn {
    (
        $name:ident, GET
    ) => {
        chained_handler_fn!(
            /// Chain an additional handler that will only accept `GET` requests.
            ///
            /// `GET` endpoints are also used for `HEAD` requests unless a separate `HEAD`
            /// endpoint has been registered.
            ///
            /// # Panics
            ///
            /// Panics if an endpoint is already registered for `GET`.
            $name,
            GET
        );
    };

    (
        $name:ident, CONNECT
    ) => {
        chained_handler_fn!(
            /// Chain an additional handler that will only accept `CONNECT` requests.
            ///
            /// # Panics
            ///
            /// Panics if an endpoint is already registered for `CONNECT`.
            $name,
            CONNECT
        );
    };

    (
        $name:ident, $method:ident
    ) => {
        chained_handler_fn!(
            #[doc = concat!("Chain an additional handler that will only accept `", stringify!($method),"` requests.")]
            $name,
            $method
        );
    };

    // Final arm: emit the method. `track_caller` makes overlap panics point at user code.
    (
        $(#[$m:meta])+
        $name:ident, $method:ident
    ) => {
        $(#[$m])+
        #[track_caller]
        pub fn $name<H, T>(self, handler: H) -> Self
        where
            H: Handler<T, S>,
            T: 'static,
            S: Send + Sync + 'static,
        {
            self.on(MethodFilter::$method, handler)
        }
    };
}
334
// Free-function constructors routing one HTTP method to a `Service`, e.g. `get_service(svc)`.
top_level_service_fn!(connect_service, CONNECT);
top_level_service_fn!(delete_service, DELETE);
top_level_service_fn!(get_service, GET);
top_level_service_fn!(head_service, HEAD);
top_level_service_fn!(options_service, OPTIONS);
top_level_service_fn!(patch_service, PATCH);
top_level_service_fn!(post_service, POST);
top_level_service_fn!(put_service, PUT);
top_level_service_fn!(trace_service, TRACE);
344
345pub fn on_service<T, S>(filter: MethodFilter, svc: T) -> MethodRouter<S, T::Error>
369where
370 T: Service<Request> + Clone + Send + Sync + 'static,
371 T::Response: IntoResponse + 'static,
372 T::Future: Send + 'static,
373 S: Clone,
374{
375 MethodRouter::new().on_service(filter, svc)
376}
377
378pub fn any_service<T, S>(svc: T) -> MethodRouter<S, T::Error>
428where
429 T: Service<Request> + Clone + Send + Sync + 'static,
430 T::Response: IntoResponse + 'static,
431 T::Future: Send + 'static,
432 S: Clone,
433{
434 MethodRouter::new()
435 .fallback_service(svc)
436 .skip_allow_header()
437}
438
// Free-function constructors routing one HTTP method to a `Handler`, e.g. `get(handler)`.
top_level_handler_fn!(connect, CONNECT);
top_level_handler_fn!(delete, DELETE);
top_level_handler_fn!(get, GET);
top_level_handler_fn!(head, HEAD);
top_level_handler_fn!(options, OPTIONS);
top_level_handler_fn!(patch, PATCH);
top_level_handler_fn!(post, POST);
top_level_handler_fn!(put, PUT);
top_level_handler_fn!(trace, TRACE);
448
449pub fn on<H, T, S>(filter: MethodFilter, handler: H) -> MethodRouter<S, Infallible>
467where
468 H: Handler<T, S>,
469 T: 'static,
470 S: Clone + Send + Sync + 'static,
471{
472 MethodRouter::new().on(filter, handler)
473}
474
475pub fn any<H, T, S>(handler: H) -> MethodRouter<S, Infallible>
509where
510 H: Handler<T, S>,
511 T: 'static,
512 S: Clone + Send + Sync + 'static,
513{
514 MethodRouter::new().fallback(handler).skip_allow_header()
515}
516
#[must_use]
/// A `Service` that dispatches requests by HTTP method and allows chaining additional
/// handlers and services.
///
/// Each supported method has its own endpoint slot. Requests whose method has no endpoint
/// are sent to `fallback`.
///
/// - `S` is the state type handlers still need (provided via `with_state`).
/// - `E` is the error type of the contained services.
pub struct MethodRouter<S = (), E = Infallible> {
    get: MethodEndpoint<S, E>,
    head: MethodEndpoint<S, E>,
    delete: MethodEndpoint<S, E>,
    options: MethodEndpoint<S, E>,
    patch: MethodEndpoint<S, E>,
    post: MethodEndpoint<S, E>,
    put: MethodEndpoint<S, E>,
    trace: MethodEndpoint<S, E>,
    connect: MethodEndpoint<S, E>,
    // Used when no method slot matches the request.
    fallback: Fallback<S, E>,
    // Comma-separated list of registered methods, attached to the fallback's response.
    allow_header: AllowHeader,
}
560
/// Accumulated value for the `Allow` header that is attached to the fallback's response
/// future in `MethodRouter::call_with_state`.
#[derive(Clone, Debug)]
enum AllowHeader {
    /// No method has been registered yet; an empty `Allow` value is sent.
    None,
    /// Never attach an `Allow` header (used by `any`/`any_service`).
    Skip,
    /// Comma-separated list of method names, e.g. `GET,HEAD`.
    Bytes(BytesMut),
}
570
571impl AllowHeader {
572 fn merge(self, other: Self) -> Self {
573 match (self, other) {
574 (AllowHeader::Skip, _) | (_, AllowHeader::Skip) => AllowHeader::Skip,
575 (AllowHeader::None, AllowHeader::None) => AllowHeader::None,
576 (AllowHeader::None, AllowHeader::Bytes(pick)) => AllowHeader::Bytes(pick),
577 (AllowHeader::Bytes(pick), AllowHeader::None) => AllowHeader::Bytes(pick),
578 (AllowHeader::Bytes(mut a), AllowHeader::Bytes(b)) => {
579 a.extend_from_slice(b",");
580 a.extend_from_slice(&b);
581 AllowHeader::Bytes(a)
582 }
583 }
584 }
585}
586
587impl<S, E> fmt::Debug for MethodRouter<S, E> {
588 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
589 f.debug_struct("MethodRouter")
590 .field("get", &self.get)
591 .field("head", &self.head)
592 .field("delete", &self.delete)
593 .field("options", &self.options)
594 .field("patch", &self.patch)
595 .field("post", &self.post)
596 .field("put", &self.put)
597 .field("trace", &self.trace)
598 .field("connect", &self.connect)
599 .field("fallback", &self.fallback)
600 .field("allow_header", &self.allow_header)
601 .finish()
602 }
603}
604
605impl<S> MethodRouter<S, Infallible>
606where
607 S: Clone,
608{
609 #[track_caller]
631 pub fn on<H, T>(self, filter: MethodFilter, handler: H) -> Self
632 where
633 H: Handler<T, S>,
634 T: 'static,
635 S: Send + Sync + 'static,
636 {
637 self.on_endpoint(
638 filter,
639 MethodEndpoint::BoxedHandler(BoxedIntoRoute::from_handler(handler)),
640 )
641 }
642
643 chained_handler_fn!(connect, CONNECT);
644 chained_handler_fn!(delete, DELETE);
645 chained_handler_fn!(get, GET);
646 chained_handler_fn!(head, HEAD);
647 chained_handler_fn!(options, OPTIONS);
648 chained_handler_fn!(patch, PATCH);
649 chained_handler_fn!(post, POST);
650 chained_handler_fn!(put, PUT);
651 chained_handler_fn!(trace, TRACE);
652
653 pub fn fallback<H, T>(mut self, handler: H) -> Self
655 where
656 H: Handler<T, S>,
657 T: 'static,
658 S: Send + Sync + 'static,
659 {
660 self.fallback = Fallback::BoxedHandler(BoxedIntoRoute::from_handler(handler));
661 self
662 }
663
664 pub(crate) fn default_fallback<H, T>(self, handler: H) -> Self
666 where
667 H: Handler<T, S>,
668 T: 'static,
669 S: Send + Sync + 'static,
670 {
671 match self.fallback {
672 Fallback::Default(_) => self.fallback(handler),
673 _ => self,
674 }
675 }
676}
677
678impl MethodRouter<(), Infallible> {
679 pub fn into_make_service(self) -> IntoMakeService<Self> {
707 IntoMakeService::new(self.with_state(()))
708 }
709
710 #[cfg(feature = "tokio")]
739 pub fn into_make_service_with_connect_info<C>(self) -> IntoMakeServiceWithConnectInfo<Self, C> {
740 IntoMakeServiceWithConnectInfo::new(self.with_state(()))
741 }
742}
743
744impl<S, E> MethodRouter<S, E>
745where
746 S: Clone,
747{
748 pub fn new() -> Self {
751 let fallback = Route::new(service_fn(|_: Request| async {
752 Ok(StatusCode::METHOD_NOT_ALLOWED)
753 }));
754
755 Self {
756 get: MethodEndpoint::None,
757 head: MethodEndpoint::None,
758 delete: MethodEndpoint::None,
759 options: MethodEndpoint::None,
760 patch: MethodEndpoint::None,
761 post: MethodEndpoint::None,
762 put: MethodEndpoint::None,
763 trace: MethodEndpoint::None,
764 connect: MethodEndpoint::None,
765 allow_header: AllowHeader::None,
766 fallback: Fallback::Default(fallback),
767 }
768 }
769
770 pub fn with_state<S2>(self, state: S) -> MethodRouter<S2, E> {
772 MethodRouter {
773 get: self.get.with_state(&state),
774 head: self.head.with_state(&state),
775 delete: self.delete.with_state(&state),
776 options: self.options.with_state(&state),
777 patch: self.patch.with_state(&state),
778 post: self.post.with_state(&state),
779 put: self.put.with_state(&state),
780 trace: self.trace.with_state(&state),
781 connect: self.connect.with_state(&state),
782 allow_header: self.allow_header,
783 fallback: self.fallback.with_state(state),
784 }
785 }
786
787 #[track_caller]
811 pub fn on_service<T>(self, filter: MethodFilter, svc: T) -> Self
812 where
813 T: Service<Request, Error = E> + Clone + Send + Sync + 'static,
814 T::Response: IntoResponse + 'static,
815 T::Future: Send + 'static,
816 {
817 self.on_endpoint(filter, MethodEndpoint::Route(Route::new(svc)))
818 }
819
820 #[track_caller]
821 fn on_endpoint(mut self, filter: MethodFilter, endpoint: MethodEndpoint<S, E>) -> Self {
822 #[track_caller]
824 fn set_endpoint<S, E>(
825 method_name: &str,
826 out: &mut MethodEndpoint<S, E>,
827 endpoint: &MethodEndpoint<S, E>,
828 endpoint_filter: MethodFilter,
829 filter: MethodFilter,
830 allow_header: &mut AllowHeader,
831 methods: &[&'static str],
832 ) where
833 MethodEndpoint<S, E>: Clone,
834 S: Clone,
835 {
836 if endpoint_filter.contains(filter) {
837 if out.is_some() {
838 panic!(
839 "Overlapping method route. Cannot add two method routes that both handle \
840 `{method_name}`",
841 )
842 }
843 *out = endpoint.clone();
844 for method in methods {
845 append_allow_header(allow_header, method);
846 }
847 }
848 }
849
850 set_endpoint(
851 "GET",
852 &mut self.get,
853 &endpoint,
854 filter,
855 MethodFilter::GET,
856 &mut self.allow_header,
857 &["GET", "HEAD"],
858 );
859
860 set_endpoint(
861 "HEAD",
862 &mut self.head,
863 &endpoint,
864 filter,
865 MethodFilter::HEAD,
866 &mut self.allow_header,
867 &["HEAD"],
868 );
869
870 set_endpoint(
871 "TRACE",
872 &mut self.trace,
873 &endpoint,
874 filter,
875 MethodFilter::TRACE,
876 &mut self.allow_header,
877 &["TRACE"],
878 );
879
880 set_endpoint(
881 "PUT",
882 &mut self.put,
883 &endpoint,
884 filter,
885 MethodFilter::PUT,
886 &mut self.allow_header,
887 &["PUT"],
888 );
889
890 set_endpoint(
891 "POST",
892 &mut self.post,
893 &endpoint,
894 filter,
895 MethodFilter::POST,
896 &mut self.allow_header,
897 &["POST"],
898 );
899
900 set_endpoint(
901 "PATCH",
902 &mut self.patch,
903 &endpoint,
904 filter,
905 MethodFilter::PATCH,
906 &mut self.allow_header,
907 &["PATCH"],
908 );
909
910 set_endpoint(
911 "OPTIONS",
912 &mut self.options,
913 &endpoint,
914 filter,
915 MethodFilter::OPTIONS,
916 &mut self.allow_header,
917 &["OPTIONS"],
918 );
919
920 set_endpoint(
921 "DELETE",
922 &mut self.delete,
923 &endpoint,
924 filter,
925 MethodFilter::DELETE,
926 &mut self.allow_header,
927 &["DELETE"],
928 );
929
930 set_endpoint(
931 "CONNECT",
932 &mut self.options,
933 &endpoint,
934 filter,
935 MethodFilter::CONNECT,
936 &mut self.allow_header,
937 &["CONNECT"],
938 );
939
940 self
941 }
942
943 chained_service_fn!(connect_service, CONNECT);
944 chained_service_fn!(delete_service, DELETE);
945 chained_service_fn!(get_service, GET);
946 chained_service_fn!(head_service, HEAD);
947 chained_service_fn!(options_service, OPTIONS);
948 chained_service_fn!(patch_service, PATCH);
949 chained_service_fn!(post_service, POST);
950 chained_service_fn!(put_service, PUT);
951 chained_service_fn!(trace_service, TRACE);
952
953 #[doc = include_str!("../docs/method_routing/fallback.md")]
954 pub fn fallback_service<T>(mut self, svc: T) -> Self
955 where
956 T: Service<Request, Error = E> + Clone + Send + Sync + 'static,
957 T::Response: IntoResponse + 'static,
958 T::Future: Send + 'static,
959 {
960 self.fallback = Fallback::Service(Route::new(svc));
961 self
962 }
963
964 #[doc = include_str!("../docs/method_routing/layer.md")]
965 pub fn layer<L, NewError>(self, layer: L) -> MethodRouter<S, NewError>
966 where
967 L: Layer<Route<E>> + Clone + Send + Sync + 'static,
968 L::Service: Service<Request> + Clone + Send + Sync + 'static,
969 <L::Service as Service<Request>>::Response: IntoResponse + 'static,
970 <L::Service as Service<Request>>::Error: Into<NewError> + 'static,
971 <L::Service as Service<Request>>::Future: Send + 'static,
972 E: 'static,
973 S: 'static,
974 NewError: 'static,
975 {
976 let layer_fn = move |route: Route<E>| route.layer(layer.clone());
977
978 MethodRouter {
979 get: self.get.map(layer_fn.clone()),
980 head: self.head.map(layer_fn.clone()),
981 delete: self.delete.map(layer_fn.clone()),
982 options: self.options.map(layer_fn.clone()),
983 patch: self.patch.map(layer_fn.clone()),
984 post: self.post.map(layer_fn.clone()),
985 put: self.put.map(layer_fn.clone()),
986 trace: self.trace.map(layer_fn.clone()),
987 connect: self.connect.map(layer_fn.clone()),
988 fallback: self.fallback.map(layer_fn),
989 allow_header: self.allow_header,
990 }
991 }
992
993 #[doc = include_str!("../docs/method_routing/route_layer.md")]
994 #[track_caller]
995 pub fn route_layer<L>(mut self, layer: L) -> MethodRouter<S, E>
996 where
997 L: Layer<Route<E>> + Clone + Send + Sync + 'static,
998 L::Service: Service<Request, Error = E> + Clone + Send + Sync + 'static,
999 <L::Service as Service<Request>>::Response: IntoResponse + 'static,
1000 <L::Service as Service<Request>>::Future: Send + 'static,
1001 E: 'static,
1002 S: 'static,
1003 {
1004 if self.get.is_none()
1005 && self.head.is_none()
1006 && self.delete.is_none()
1007 && self.options.is_none()
1008 && self.patch.is_none()
1009 && self.post.is_none()
1010 && self.put.is_none()
1011 && self.trace.is_none()
1012 && self.connect.is_none()
1013 {
1014 panic!(
1015 "Adding a route_layer before any routes is a no-op. \
1016 Add the routes you want the layer to apply to first."
1017 );
1018 }
1019
1020 let layer_fn = move |svc| Route::new(layer.layer(svc));
1021
1022 self.get = self.get.map(layer_fn.clone());
1023 self.head = self.head.map(layer_fn.clone());
1024 self.delete = self.delete.map(layer_fn.clone());
1025 self.options = self.options.map(layer_fn.clone());
1026 self.patch = self.patch.map(layer_fn.clone());
1027 self.post = self.post.map(layer_fn.clone());
1028 self.put = self.put.map(layer_fn.clone());
1029 self.trace = self.trace.map(layer_fn.clone());
1030 self.connect = self.connect.map(layer_fn);
1031
1032 self
1033 }
1034
1035 pub(crate) fn merge_for_path(
1036 mut self,
1037 path: Option<&str>,
1038 other: MethodRouter<S, E>,
1039 ) -> Result<Self, Cow<'static, str>> {
1040 fn merge_inner<S, E>(
1042 path: Option<&str>,
1043 name: &str,
1044 first: MethodEndpoint<S, E>,
1045 second: MethodEndpoint<S, E>,
1046 ) -> Result<MethodEndpoint<S, E>, Cow<'static, str>> {
1047 match (first, second) {
1048 (MethodEndpoint::None, MethodEndpoint::None) => Ok(MethodEndpoint::None),
1049 (pick, MethodEndpoint::None) | (MethodEndpoint::None, pick) => Ok(pick),
1050 _ => {
1051 if let Some(path) = path {
1052 Err(format!(
1053 "Overlapping method route. Handler for `{name} {path}` already exists"
1054 )
1055 .into())
1056 } else {
1057 Err(format!(
1058 "Overlapping method route. Cannot merge two method routes that both \
1059 define `{name}`"
1060 )
1061 .into())
1062 }
1063 }
1064 }
1065 }
1066
1067 self.get = merge_inner(path, "GET", self.get, other.get)?;
1068 self.head = merge_inner(path, "HEAD", self.head, other.head)?;
1069 self.delete = merge_inner(path, "DELETE", self.delete, other.delete)?;
1070 self.options = merge_inner(path, "OPTIONS", self.options, other.options)?;
1071 self.patch = merge_inner(path, "PATCH", self.patch, other.patch)?;
1072 self.post = merge_inner(path, "POST", self.post, other.post)?;
1073 self.put = merge_inner(path, "PUT", self.put, other.put)?;
1074 self.trace = merge_inner(path, "TRACE", self.trace, other.trace)?;
1075 self.connect = merge_inner(path, "CONNECT", self.connect, other.connect)?;
1076
1077 self.fallback = self
1078 .fallback
1079 .merge(other.fallback)
1080 .ok_or("Cannot merge two `MethodRouter`s that both have a fallback")?;
1081
1082 self.allow_header = self.allow_header.merge(other.allow_header);
1083
1084 Ok(self)
1085 }
1086
1087 #[doc = include_str!("../docs/method_routing/merge.md")]
1088 #[track_caller]
1089 pub fn merge(self, other: MethodRouter<S, E>) -> Self {
1090 match self.merge_for_path(None, other) {
1091 Ok(t) => t,
1092 Err(e) => panic!("{e}"),
1094 }
1095 }
1096
1097 pub fn handle_error<F, T>(self, f: F) -> MethodRouter<S, Infallible>
1101 where
1102 F: Clone + Send + Sync + 'static,
1103 HandleError<Route<E>, F, T>: Service<Request, Error = Infallible>,
1104 <HandleError<Route<E>, F, T> as Service<Request>>::Future: Send,
1105 <HandleError<Route<E>, F, T> as Service<Request>>::Response: IntoResponse + Send,
1106 T: 'static,
1107 E: 'static,
1108 S: 'static,
1109 {
1110 self.layer(HandleErrorLayer::new(f))
1111 }
1112
1113 fn skip_allow_header(mut self) -> Self {
1114 self.allow_header = AllowHeader::Skip;
1115 self
1116 }
1117
1118 pub(crate) fn call_with_state(&self, req: Request, state: S) -> RouteFuture<E> {
1119 macro_rules! call {
1120 (
1121 $req:expr,
1122 $method_variant:ident,
1123 $svc:expr
1124 ) => {
1125 if *req.method() == Method::$method_variant {
1126 match $svc {
1127 MethodEndpoint::None => {}
1128 MethodEndpoint::Route(route) => {
1129 return route.clone().oneshot_inner_owned($req);
1130 }
1131 MethodEndpoint::BoxedHandler(handler) => {
1132 let route = handler.clone().into_route(state);
1133 return route.oneshot_inner_owned($req);
1134 }
1135 }
1136 }
1137 };
1138 }
1139
1140 let Self {
1142 get,
1143 head,
1144 delete,
1145 options,
1146 patch,
1147 post,
1148 put,
1149 trace,
1150 connect,
1151 fallback,
1152 allow_header,
1153 } = self;
1154
1155 call!(req, HEAD, head);
1156 call!(req, HEAD, get);
1157 call!(req, GET, get);
1158 call!(req, POST, post);
1159 call!(req, OPTIONS, options);
1160 call!(req, PATCH, patch);
1161 call!(req, PUT, put);
1162 call!(req, DELETE, delete);
1163 call!(req, TRACE, trace);
1164 call!(req, CONNECT, connect);
1165
1166 let future = fallback.clone().call_with_state(req, state);
1167
1168 match allow_header {
1169 AllowHeader::None => future.allow_header(Bytes::new()),
1170 AllowHeader::Skip => future,
1171 AllowHeader::Bytes(allow_header) => future.allow_header(allow_header.clone().freeze()),
1172 }
1173 }
1174}
1175
1176fn append_allow_header(allow_header: &mut AllowHeader, method: &'static str) {
1177 match allow_header {
1178 AllowHeader::None => {
1179 *allow_header = AllowHeader::Bytes(BytesMut::from(method));
1180 }
1181 AllowHeader::Skip => {}
1182 AllowHeader::Bytes(allow_header) => {
1183 if let Ok(s) = std::str::from_utf8(allow_header) {
1184 if !s.contains(method) {
1185 allow_header.extend_from_slice(b",");
1186 allow_header.extend_from_slice(method.as_bytes());
1187 }
1188 } else {
1189 #[cfg(debug_assertions)]
1190 panic!("`allow_header` contained invalid uft-8. This should never happen")
1191 }
1192 }
1193 }
1194}
1195
1196impl<S, E> Clone for MethodRouter<S, E> {
1197 fn clone(&self) -> Self {
1198 Self {
1199 get: self.get.clone(),
1200 head: self.head.clone(),
1201 delete: self.delete.clone(),
1202 options: self.options.clone(),
1203 patch: self.patch.clone(),
1204 post: self.post.clone(),
1205 put: self.put.clone(),
1206 trace: self.trace.clone(),
1207 connect: self.connect.clone(),
1208 fallback: self.fallback.clone(),
1209 allow_header: self.allow_header.clone(),
1210 }
1211 }
1212}
1213
1214impl<S, E> Default for MethodRouter<S, E>
1215where
1216 S: Clone,
1217{
1218 fn default() -> Self {
1219 Self::new()
1220 }
1221}
1222
/// What is registered for a single HTTP method slot of a `MethodRouter`.
enum MethodEndpoint<S, E> {
    /// Nothing registered; the request falls through to the next candidate or the fallback.
    None,
    /// A ready-to-call route (a service, or a handler whose state was already provided).
    Route(Route<E>),
    /// A handler still waiting for state `S`; converted via `into_route` when called.
    BoxedHandler(BoxedIntoRoute<S, E>),
}
1228
1229impl<S, E> MethodEndpoint<S, E>
1230where
1231 S: Clone,
1232{
1233 fn is_some(&self) -> bool {
1234 matches!(self, Self::Route(_) | Self::BoxedHandler(_))
1235 }
1236
1237 fn is_none(&self) -> bool {
1238 matches!(self, Self::None)
1239 }
1240
1241 fn map<F, E2>(self, f: F) -> MethodEndpoint<S, E2>
1242 where
1243 S: 'static,
1244 E: 'static,
1245 F: FnOnce(Route<E>) -> Route<E2> + Clone + Send + Sync + 'static,
1246 E2: 'static,
1247 {
1248 match self {
1249 Self::None => MethodEndpoint::None,
1250 Self::Route(route) => MethodEndpoint::Route(f(route)),
1251 Self::BoxedHandler(handler) => MethodEndpoint::BoxedHandler(handler.map(f)),
1252 }
1253 }
1254
1255 fn with_state<S2>(self, state: &S) -> MethodEndpoint<S2, E> {
1256 match self {
1257 MethodEndpoint::None => MethodEndpoint::None,
1258 MethodEndpoint::Route(route) => MethodEndpoint::Route(route),
1259 MethodEndpoint::BoxedHandler(handler) => {
1260 MethodEndpoint::Route(handler.into_route(state.clone()))
1261 }
1262 }
1263 }
1264}
1265
1266impl<S, E> Clone for MethodEndpoint<S, E> {
1267 fn clone(&self) -> Self {
1268 match self {
1269 Self::None => Self::None,
1270 Self::Route(inner) => Self::Route(inner.clone()),
1271 Self::BoxedHandler(inner) => Self::BoxedHandler(inner.clone()),
1272 }
1273 }
1274}
1275
1276impl<S, E> fmt::Debug for MethodEndpoint<S, E> {
1277 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1278 match self {
1279 Self::None => f.debug_tuple("None").finish(),
1280 Self::Route(inner) => inner.fmt(f),
1281 Self::BoxedHandler(_) => f.debug_tuple("BoxedHandler").finish(),
1282 }
1283 }
1284}
1285
1286impl<B, E> Service<Request<B>> for MethodRouter<(), E>
1287where
1288 B: HttpBody<Data = Bytes> + Send + 'static,
1289 B::Error: Into<BoxError>,
1290{
1291 type Response = Response;
1292 type Error = E;
1293 type Future = RouteFuture<E>;
1294
1295 #[inline]
1296 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
1297 Poll::Ready(Ok(()))
1298 }
1299
1300 #[inline]
1301 fn call(&mut self, req: Request<B>) -> Self::Future {
1302 let req = req.map(Body::new);
1303 self.call_with_state(req, ())
1304 }
1305}
1306
1307impl<S> Handler<(), S> for MethodRouter<S>
1308where
1309 S: Clone + 'static,
1310{
1311 type Future = InfallibleRouteFuture;
1312
1313 fn call(self, req: Request, state: S) -> Self::Future {
1314 InfallibleRouteFuture::new(self.call_with_state(req, state))
1315 }
1316}
1317
// Implements `Service<serve::IncomingStream>` for `MethodRouter<()>`, so the router acts as
// its own make-service: every call yields a fresh clone of the router. The anonymous
// `const` scopes the `serve` import to this impl.
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
const _: () = {
    use crate::serve;

    impl<L> Service<serve::IncomingStream<'_, L>> for MethodRouter<()>
    where
        L: serve::Listener,
    {
        type Response = Self;
        type Error = Infallible;
        type Future = std::future::Ready<Result<Self::Response, Self::Error>>;

        fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, _req: serve::IncomingStream<'_, L>) -> Self::Future {
            let per_connection: Self = self.clone().with_state(());
            std::future::ready(Ok(per_connection))
        }
    }
};
1340
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{extract::State, handler::HandlerWithoutStateExt};
    use http::{header::ALLOW, HeaderMap};
    use http_body_util::BodyExt;
    use std::time::Duration;
    use tower::ServiceExt;
    use tower_http::{
        services::fs::ServeDir, timeout::TimeoutLayer, validate_request::ValidateRequestHeaderLayer,
    };

    #[crate::test]
    async fn method_not_allowed_by_default() {
        let mut svc = MethodRouter::new();
        let (status, _, body) = call(Method::GET, &mut svc).await;
        assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
        assert!(body.is_empty());
    }

    #[crate::test]
    async fn get_service_fn() {
        async fn handle(_req: Request) -> Result<Response<Body>, Infallible> {
            Ok(Response::new(Body::from("ok")))
        }

        let mut svc = get_service(service_fn(handle));

        let (status, _, body) = call(Method::GET, &mut svc).await;
        assert_eq!(status, StatusCode::OK);
        assert_eq!(body, "ok");
    }

    #[crate::test]
    async fn get_handler() {
        let mut svc = MethodRouter::new().get(ok);
        let (status, _, body) = call(Method::GET, &mut svc).await;
        assert_eq!(status, StatusCode::OK);
        assert_eq!(body, "ok");
    }

    #[crate::test]
    async fn get_accepts_head() {
        let mut svc = MethodRouter::new().get(ok);
        let (status, _, body) = call(Method::HEAD, &mut svc).await;
        assert_eq!(status, StatusCode::OK);
        assert!(body.is_empty());
    }

    #[crate::test]
    async fn head_takes_precedence_over_get() {
        let mut svc = MethodRouter::new().head(created).get(ok);
        let (status, _, body) = call(Method::HEAD, &mut svc).await;
        assert_eq!(status, StatusCode::CREATED);
        assert!(body.is_empty());
    }

    #[crate::test]
    async fn merge() {
        let mut svc = get(ok).merge(post(ok));

        let (status, _, _) = call(Method::GET, &mut svc).await;
        assert_eq!(status, StatusCode::OK);

        let (status, _, _) = call(Method::POST, &mut svc).await;
        assert_eq!(status, StatusCode::OK);
    }

    #[crate::test]
    async fn layer() {
        let mut svc = MethodRouter::new()
            .get(|| async { std::future::pending::<()>().await })
            .layer(ValidateRequestHeaderLayer::bearer("password"));

        // the layer rejects the request before the pending handler is reached
        let (status, _, _) = call(Method::GET, &mut svc).await;
        assert_eq!(status, StatusCode::UNAUTHORIZED);

        // `layer` also wraps the fallback, so unmatched methods are rejected too
        let (status, _, _) = call(Method::DELETE, &mut svc).await;
        assert_eq!(status, StatusCode::UNAUTHORIZED);
    }

    #[crate::test]
    async fn route_layer() {
        let mut svc = MethodRouter::new()
            .get(|| async { std::future::pending::<()>().await })
            .route_layer(ValidateRequestHeaderLayer::bearer("password"));

        // the layer rejects the request before the pending handler is reached
        let (status, _, _) = call(Method::GET, &mut svc).await;
        assert_eq!(status, StatusCode::UNAUTHORIZED);

        // `route_layer` leaves the fallback untouched, so it still answers 405
        let (status, _, _) = call(Method::DELETE, &mut svc).await;
        assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
    }

    #[allow(dead_code)]
    async fn building_complex_router() {
        let app = crate::Router::new().route(
            "/",
            get(ok)
                .post(ok)
                .route_layer(ValidateRequestHeaderLayer::bearer("password"))
                .merge(delete_service(ServeDir::new(".")))
                .fallback(|| async { StatusCode::NOT_FOUND })
                .put(ok)
                .layer(TimeoutLayer::new(Duration::from_secs(10))),
        );

        let listener = tokio::net::TcpListener::bind("0.0.0.0:0").await.unwrap();
        crate::serve(listener, app).await.unwrap();
    }

    #[crate::test]
    async fn sets_allow_header() {
        let mut svc = MethodRouter::new().put(ok).patch(ok);
        let (status, headers, _) = call(Method::GET, &mut svc).await;
        assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
        assert_eq!(headers[ALLOW], "PUT,PATCH");
    }

    #[crate::test]
    async fn sets_allow_header_get_head() {
        let mut svc = MethodRouter::new().get(ok).head(ok);
        let (status, headers, _) = call(Method::PUT, &mut svc).await;
        assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
        assert_eq!(headers[ALLOW], "GET,HEAD");
    }

    #[crate::test]
    async fn empty_allow_header_by_default() {
        let mut svc = MethodRouter::new();
        let (status, headers, _) = call(Method::PATCH, &mut svc).await;
        assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
        assert_eq!(headers[ALLOW], "");
    }

    #[crate::test]
    async fn allow_header_when_merging() {
        let a = put(ok).patch(ok);
        let b = get(ok).head(ok);
        let mut svc = a.merge(b);

        let (status, headers, _) = call(Method::DELETE, &mut svc).await;
        assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
        assert_eq!(headers[ALLOW], "PUT,PATCH,GET,HEAD");
    }

    #[crate::test]
    async fn allow_header_any() {
        let mut svc = any(ok);

        let (status, headers, _) = call(Method::GET, &mut svc).await;
        assert_eq!(status, StatusCode::OK);
        assert!(!headers.contains_key(ALLOW));
    }

    #[crate::test]
    async fn allow_header_with_fallback() {
        let mut svc = MethodRouter::new()
            .get(ok)
            .fallback(|| async { (StatusCode::METHOD_NOT_ALLOWED, "Method not allowed") });

        let (status, headers, _) = call(Method::DELETE, &mut svc).await;
        assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
        assert_eq!(headers[ALLOW], "GET,HEAD");
    }

    #[crate::test]
    async fn allow_header_with_fallback_that_sets_allow() {
        async fn fallback(method: Method) -> Response {
            if method == Method::POST {
                "OK".into_response()
            } else {
                (
                    StatusCode::METHOD_NOT_ALLOWED,
                    [(ALLOW, "GET,POST")],
                    "Method not allowed",
                )
                    .into_response()
            }
        }

        let mut svc = MethodRouter::new().get(ok).fallback(fallback);

        let (status, _, _) = call(Method::GET, &mut svc).await;
        assert_eq!(status, StatusCode::OK);

        let (status, _, _) = call(Method::POST, &mut svc).await;
        assert_eq!(status, StatusCode::OK);

        let (status, headers, _) = call(Method::DELETE, &mut svc).await;
        assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
        assert_eq!(headers[ALLOW], "GET,POST");
    }

    #[crate::test]
    async fn allow_header_noop_middleware() {
        let mut svc = MethodRouter::new()
            .get(ok)
            .layer(tower::layer::util::Identity::new());

        let (status, headers, _) = call(Method::DELETE, &mut svc).await;
        assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
        assert_eq!(headers[ALLOW], "GET,HEAD");
    }

    #[crate::test]
    #[should_panic(
        expected = "Overlapping method route. Cannot add two method routes that both handle `GET`"
    )]
    async fn handler_overlaps() {
        let _: MethodRouter<()> = get(ok).get(ok);
    }

    #[crate::test]
    #[should_panic(
        expected = "Overlapping method route. Cannot add two method routes that both handle `POST`"
    )]
    async fn service_overlaps() {
        let _: MethodRouter<()> = post_service(ok.into_service()).post_service(ok.into_service());
    }

    #[crate::test]
    async fn get_head_does_not_overlap() {
        let _: MethodRouter<()> = get(ok).head(ok);
    }

    #[crate::test]
    async fn head_get_does_not_overlap() {
        let _: MethodRouter<()> = head(ok).get(ok);
    }

    #[crate::test]
    async fn accessing_state() {
        let mut svc = MethodRouter::new()
            .get(|State(state): State<&'static str>| async move { state })
            .with_state("state");

        let (status, _, text) = call(Method::GET, &mut svc).await;

        assert_eq!(status, StatusCode::OK);
        assert_eq!(text, "state");
    }

    #[crate::test]
    async fn fallback_accessing_state() {
        let mut svc = MethodRouter::new()
            .fallback(|State(state): State<&'static str>| async move { state })
            .with_state("state");

        let (status, _, text) = call(Method::GET, &mut svc).await;

        assert_eq!(status, StatusCode::OK);
        assert_eq!(text, "state");
    }

    #[crate::test]
    async fn merge_accessing_state() {
        let one = get(|State(state): State<&'static str>| async move { state });
        let two = post(|State(state): State<&'static str>| async move { state });

        let mut svc = one.merge(two).with_state("state");

        let (status, _, text) = call(Method::GET, &mut svc).await;
        assert_eq!(status, StatusCode::OK);
        assert_eq!(text, "state");

        // bind the POST body too; previously the stale GET body was re-asserted here
        let (status, _, text) = call(Method::POST, &mut svc).await;
        assert_eq!(status, StatusCode::OK);
        assert_eq!(text, "state");
    }

    /// Send an empty-bodied request with `method` to `/` and collect status, headers and body.
    async fn call<S>(method: Method, svc: &mut S) -> (StatusCode, HeaderMap, String)
    where
        S: Service<Request, Error = Infallible>,
        S::Response: IntoResponse,
    {
        let request = Request::builder()
            .uri("/")
            .method(method)
            .body(Body::empty())
            .unwrap();
        let response = svc
            .ready()
            .await
            .unwrap()
            .call(request)
            .await
            .unwrap()
            .into_response();
        let (parts, body) = response.into_parts();
        let body =
            String::from_utf8(BodyExt::collect(body).await.unwrap().to_bytes().to_vec()).unwrap();
        (parts.status, parts.headers, body)
    }

    async fn ok() -> (StatusCode, &'static str) {
        (StatusCode::OK, "ok")
    }

    async fn created() -> (StatusCode, &'static str) {
        (StatusCode::CREATED, "created")
    }
}