// axum/serve/mod.rs

1//! Serve services.
2
3use std::{
4    convert::Infallible,
5    fmt::Debug,
6    future::{poll_fn, Future, IntoFuture},
7    io,
8    marker::PhantomData,
9    sync::Arc,
10};
11
12use axum_core::{body::Body, extract::Request, response::Response};
13use futures_util::{pin_mut, FutureExt};
14use hyper::body::Incoming;
15use hyper_util::rt::{TokioExecutor, TokioIo};
16#[cfg(any(feature = "http1", feature = "http2"))]
17use hyper_util::{server::conn::auto::Builder, service::TowerToHyperService};
18use tokio::sync::watch;
19use tower::ServiceExt as _;
20use tower_service::Service;
21
22mod listener;
23
24pub use self::listener::{Listener, ListenerExt, TapIo};
25
/// Serve the service with the supplied listener.
///
/// This method of running a service is intentionally simple and doesn't support any configuration.
/// Use hyper or hyper-util if you need configuration.
///
/// It supports both HTTP/1 as well as HTTP/2.
///
/// The returned [`Serve`] does nothing until it is awaited. Call
/// [`Serve::with_graceful_shutdown`] on it first if the server should be able to stop.
///
/// # Examples
///
/// Serving a [`Router`]:
///
/// ```
/// use axum::{Router, routing::get};
///
/// # async {
/// let router = Router::new().route("/", get(|| async { "Hello, World!" }));
///
/// let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
/// axum::serve(listener, router).await.unwrap();
/// # };
/// ```
///
/// See also [`Router::into_make_service_with_connect_info`].
///
/// Serving a [`MethodRouter`]:
///
/// ```
/// use axum::routing::get;
///
/// # async {
/// let router = get(|| async { "Hello, World!" });
///
/// let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
/// axum::serve(listener, router).await.unwrap();
/// # };
/// ```
///
/// See also [`MethodRouter::into_make_service_with_connect_info`].
///
/// Serving a [`Handler`]:
///
/// ```
/// use axum::handler::HandlerWithoutStateExt;
///
/// # async {
/// async fn handler() -> &'static str {
///     "Hello, World!"
/// }
///
/// let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
/// axum::serve(listener, handler.into_make_service()).await.unwrap();
/// # };
/// ```
///
/// See also [`HandlerWithoutStateExt::into_make_service_with_connect_info`] and
/// [`HandlerService::into_make_service_with_connect_info`].
///
/// # Return Value
///
/// Although this future resolves to `io::Result<()>`, it will never actually complete or return an
/// error. Errors that occur while accepting connections are not reported here:
/// [`Listener::accept`] always yields a connection, so each [`Listener`] implementation is
/// responsible for recovering from accept failures itself. The implementation for tokio's
/// `TcpListener` handles them by sleeping for a short while (currently, one second) before
/// trying again.
///
/// [`Router`]: crate::Router
/// [`Router::into_make_service_with_connect_info`]: crate::Router::into_make_service_with_connect_info
/// [`MethodRouter`]: crate::routing::MethodRouter
/// [`MethodRouter::into_make_service_with_connect_info`]: crate::routing::MethodRouter::into_make_service_with_connect_info
/// [`Handler`]: crate::handler::Handler
/// [`HandlerWithoutStateExt::into_make_service_with_connect_info`]: crate::handler::HandlerWithoutStateExt::into_make_service_with_connect_info
/// [`HandlerService::into_make_service_with_connect_info`]: crate::handler::HandlerService::into_make_service_with_connect_info
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
pub fn serve<L, M, S>(listener: L, make_service: M) -> Serve<L, M, S>
where
    L: Listener,
    M: for<'a> Service<IncomingStream<'a, L>, Error = Infallible, Response = S>,
    S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
    S::Future: Send,
{
    // Only stores the inputs; all work happens when the returned value is awaited.
    Serve {
        listener,
        make_service,
        _marker: PhantomData,
    }
}
110
/// Future returned by [`serve`].
///
/// Await it (it implements [`IntoFuture`]) to run the server, or call
/// [`Serve::with_graceful_shutdown`] first to obtain a server that can be stopped.
#[must_use = "futures must be awaited or polled"]
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
pub struct Serve<L, M, S> {
    /// Source of incoming connections.
    listener: L,
    /// Called once per accepted connection to produce that connection's service.
    make_service: M,
    /// Records the per-connection service type `S` without storing a value of it.
    _marker: PhantomData<S>,
}
119
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<L, M, S> Serve<L, M, S>
where
    L: Listener,
{
    /// Prepares a server to handle graceful shutdown when the provided future completes.
    ///
    /// # Example
    ///
    /// ```
    /// use axum::{Router, routing::get};
    ///
    /// # async {
    /// let router = Router::new().route("/", get(|| async { "Hello, World!" }));
    ///
    /// let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
    /// axum::serve(listener, router)
    ///     .with_graceful_shutdown(shutdown_signal())
    ///     .await
    ///     .unwrap();
    /// # };
    ///
    /// async fn shutdown_signal() {
    ///     // ...
    /// }
    /// ```
    ///
    /// # Shutdown behavior
    ///
    /// `signal` is driven on its own spawned task, which is why it must be `Send + 'static`.
    /// Once it completes, the server stops accepting new connections and asks every
    /// connection that is still open to shut down gracefully.
    ///
    /// # Return Value
    ///
    /// Similarly to [`serve`], although this future resolves to `io::Result<()>`, it will never
    /// error. It returns `Ok(())` only after the `signal` future has completed *and* every
    /// in-flight connection has finished.
    pub fn with_graceful_shutdown<F>(self, signal: F) -> WithGracefulShutdown<L, M, S, F>
    where
        F: Future<Output = ()> + Send + 'static,
    {
        WithGracefulShutdown {
            listener: self.listener,
            make_service: self.make_service,
            signal,
            _marker: PhantomData,
        }
    }

    /// Returns the local address this server is bound to, as reported by
    /// [`Listener::local_addr`].
    pub fn local_addr(&self) -> io::Result<L::Addr> {
        self.listener.local_addr()
    }
}
168
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<L, M, S> Debug for Serve<L, M, S>
where
    L: Debug + 'static,
    M: Debug,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Destructure exhaustively so that adding a field forces this impl to be revisited.
        // `_marker` carries no data and is deliberately left out of the output.
        let Self {
            listener: inner_listener,
            make_service: inner_make_service,
            _marker: _,
        } = self;

        f.debug_struct("Serve")
            .field("listener", inner_listener)
            .field("make_service", inner_make_service)
            .finish()
    }
}
189
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<L, M, S> IntoFuture for Serve<L, M, S>
where
    L: Listener,
    L::Addr: Debug,
    M: for<'a> Service<IncomingStream<'a, L>, Error = Infallible, Response = S> + Send + 'static,
    for<'a> <M as Service<IncomingStream<'a, L>>>::Future: Send,
    S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
    S::Future: Send,
{
    type Output = io::Result<()>;
    type IntoFuture = private::ServeFuture;

    /// Runs the server without any shutdown trigger.
    ///
    /// This delegates to the graceful-shutdown path with a signal that never completes,
    /// so the resulting future never resolves on its own.
    fn into_future(self) -> Self::IntoFuture {
        let never_fires = std::future::pending::<()>();
        IntoFuture::into_future(self.with_graceful_shutdown(never_fires))
    }
}
208
/// Serve future with graceful shutdown enabled.
///
/// Created by [`Serve::with_graceful_shutdown`]. Await it to run the server; it stops
/// accepting connections once its shutdown signal completes.
#[must_use = "futures must be awaited or polled"]
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
pub struct WithGracefulShutdown<L, M, S, F> {
    /// Source of incoming connections.
    listener: L,
    /// Called once per accepted connection to produce that connection's service.
    make_service: M,
    /// Future whose completion starts the graceful shutdown.
    signal: F,
    /// Records the per-connection service type `S` without storing a value of it.
    _marker: PhantomData<S>,
}
218
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<L: Listener, M, S, F> WithGracefulShutdown<L, M, S, F> {
    /// Returns the address the underlying listener is bound to.
    ///
    /// Useful when binding to port `0` and letting the OS choose the port.
    pub fn local_addr(&self) -> io::Result<L::Addr> {
        Listener::local_addr(&self.listener)
    }
}
229
/// Debug output shows the listener, make-service and shutdown signal.
///
/// The per-connection service type `S` only appears in `PhantomData` and is never printed,
/// so, like the `Debug` impl for `Serve`, no `S: Debug` bound is required.
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<L, M, S, F> Debug for WithGracefulShutdown<L, M, S, F>
where
    L: Debug + 'static,
    M: Debug,
    F: Debug,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Exhaustive destructuring: adding a field forces this impl to be revisited.
        let Self {
            listener,
            make_service,
            signal,
            _marker: _,
        } = self;

        f.debug_struct("WithGracefulShutdown")
            .field("listener", listener)
            .field("make_service", make_service)
            .field("signal", signal)
            .finish()
    }
}
253
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<L, M, S, F> IntoFuture for WithGracefulShutdown<L, M, S, F>
where
    L: Listener,
    L::Addr: Debug,
    M: for<'a> Service<IncomingStream<'a, L>, Error = Infallible, Response = S> + Send + 'static,
    for<'a> <M as Service<IncomingStream<'a, L>>>::Future: Send,
    S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
    S::Future: Send,
    F: Future<Output = ()> + Send + 'static,
{
    type Output = io::Result<()>;
    type IntoFuture = private::ServeFuture;

    /// Boxes the server loop so its anonymous `async` type is hidden behind `ServeFuture`.
    fn into_future(self) -> Self::IntoFuture {
        let serving = async move {
            self.run().await;
            Ok::<(), io::Error>(())
        };
        private::ServeFuture(Box::pin(serving))
    }
}
275
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<L, M, S, F> WithGracefulShutdown<L, M, S, F>
where
    L: Listener,
    L::Addr: Debug,
    M: for<'a> Service<IncomingStream<'a, L>, Error = Infallible, Response = S> + Send + 'static,
    for<'a> <M as Service<IncomingStream<'a, L>>>::Future: Send,
    S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
    S::Future: Send,
    F: Future<Output = ()> + Send + 'static,
{
    /// Runs the accept loop until `signal` completes, then waits for every spawned
    /// connection task to finish before returning.
    ///
    /// Each accepted connection gets its own service (produced by `make_service`) and is
    /// served on a separately spawned task. Once the shutdown signal fires, no further
    /// connections are accepted and every live connection is asked to shut down gracefully.
    async fn run(self) {
        let Self {
            mut listener,
            mut make_service,
            signal,
            _marker: _,
        } = self;

        // Shutdown broadcast: the channel is used only for its "all receivers dropped" state.
        // `signal_rx` is moved into the task below and dropped once `signal` completes,
        // which makes every `signal_tx.closed()` future resolve.
        let (signal_tx, signal_rx) = watch::channel(());
        // Shared (via `Arc`) with every connection task so they can observe shutdown too.
        let signal_tx = Arc::new(signal_tx);
        // NOTE(review): this task is detached and not tied to the lifetime of `run`. If the
        // returned future is dropped before `signal` completes, the task keeps waiting on
        // `signal`; for plain `serve` the signal is `std::future::pending()`, so it never ends.
        tokio::spawn(async move {
            signal.await;
            trace!("received graceful shutdown signal. Telling tasks to shutdown");
            drop(signal_rx);
        });

        // Connection tracking: every connection task holds a clone of `close_rx` and drops it
        // when done; `close_tx.closed()` at the end resolves once no receivers remain.
        let (close_tx, close_rx) = watch::channel(());

        loop {
            // Accept the next connection, or stop accepting once the shutdown signal fires.
            let (io, remote_addr) = tokio::select! {
                conn = listener.accept() => conn,
                _ = signal_tx.closed() => {
                    trace!("signal received, not accepting new connections");
                    break;
                }
            };

            let io = TokioIo::new(io);

            trace!("connection {remote_addr:?} accepted");

            // Wait for `make_service` to be ready; its error type is `Infallible`, so the
            // empty `match` can never execute.
            poll_fn(|cx| make_service.poll_ready(cx))
                .await
                .unwrap_or_else(|err| match err {});

            // Build the per-connection service. `IncomingStream` only borrows `io`, so the
            // borrow ends here and `io` can be moved into the connection task below.
            let tower_service = make_service
                .call(IncomingStream {
                    io: &io,
                    remote_addr,
                })
                .await
                .unwrap_or_else(|err| match err {})
                // hyper delivers `Request<Incoming>`; convert the body into axum's `Body`.
                .map_request(|req: Request<Incoming>| req.map(Body::new));

            // Adapt the tower service to hyper's `Service` trait.
            let hyper_service = TowerToHyperService::new(tower_service);

            let signal_tx = Arc::clone(&signal_tx);

            // Held by the connection task for its whole lifetime so `run` can wait for it.
            let close_rx = close_rx.clone();

            tokio::spawn(async move {
                // `builder` is only mutated when the `http2` feature is enabled.
                #[allow(unused_mut)]
                let mut builder = Builder::new(TokioExecutor::new());
                // CONNECT protocol needed for HTTP/2 websockets
                #[cfg(feature = "http2")]
                builder.http2().enable_connect_protocol();
                let conn = builder.serve_connection_with_upgrades(io, hyper_service);
                pin_mut!(conn);

                // Fused so that polling it again on later loop iterations, after it has
                // already completed, just returns `Pending`.
                let signal_closed = signal_tx.closed().fuse();
                pin_mut!(signal_closed);

                loop {
                    tokio::select! {
                        result = conn.as_mut() => {
                            if let Err(_err) = result {
                                trace!("failed to serve connection: {_err:#}");
                            }
                            break;
                        }
                        _ = &mut signal_closed => {
                            trace!("signal received in task, starting graceful shutdown");
                            // Keep looping: the branch above still drives `conn` until the
                            // graceful shutdown finishes.
                            conn.as_mut().graceful_shutdown();
                        }
                    }
                }

                // Tell `run` this connection is done.
                drop(close_rx);
            });
        }

        // Drop our own receiver so that only connection tasks keep the channel open, and
        // release the listener now instead of when `run` returns.
        drop(close_rx);
        drop(listener);

        trace!(
            "waiting for {} task(s) to finish",
            close_tx.receiver_count()
        );
        // Resolves once every connection task has dropped its `close_rx` clone.
        close_tx.closed().await;
    }
}
378
379/// An incoming stream.
380///
381/// Used with [`serve`] and [`IntoMakeServiceWithConnectInfo`].
382///
383/// [`IntoMakeServiceWithConnectInfo`]: crate::extract::connect_info::IntoMakeServiceWithConnectInfo
384#[derive(Debug)]
385pub struct IncomingStream<'a, L>
386where
387    L: Listener,
388{
389    io: &'a TokioIo<L::Io>,
390    remote_addr: L::Addr,
391}
392
393impl<L> IncomingStream<'_, L>
394where
395    L: Listener,
396{
397    /// Get a reference to the inner IO type.
398    pub fn io(&self) -> &L::Io {
399        self.io.inner()
400    }
401
402    /// Returns the remote address that this stream is bound to.
403    pub fn remote_addr(&self) -> &L::Addr {
404        &self.remote_addr
405    }
406}
407
408mod private {
409    use std::{
410        future::Future,
411        io,
412        pin::Pin,
413        task::{Context, Poll},
414    };
415
416    pub struct ServeFuture(pub(super) futures_util::future::BoxFuture<'static, io::Result<()>>);
417
418    impl Future for ServeFuture {
419        type Output = io::Result<()>;
420
421        #[inline]
422        fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
423            self.0.as_mut().poll(cx)
424        }
425    }
426
427    impl std::fmt::Debug for ServeFuture {
428        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
429            f.debug_struct("ServeFuture").finish_non_exhaustive()
430        }
431    }
432}
433
#[cfg(test)]
mod tests {
    use std::{
        future::{pending, IntoFuture as _},
        net::{IpAddr, Ipv4Addr},
    };

    use axum_core::{body::Body, extract::Request};
    use http::StatusCode;
    use hyper_util::rt::TokioIo;
    #[cfg(unix)]
    use tokio::net::UnixListener;
    use tokio::{
        io::{self, AsyncRead, AsyncWrite},
        net::TcpListener,
    };

    #[cfg(unix)]
    use super::IncomingStream;
    use super::{serve, Listener};
    #[cfg(unix)]
    use crate::extract::connect_info::Connected;
    use crate::{
        body::to_bytes,
        handler::{Handler, HandlerWithoutStateExt},
        routing::get,
        serve::ListenerExt,
        Router,
    };

    // Compile-only check: every combination of listener kind (TCP, TCP with a `tap_io`
    // hook, Unix socket) and service kind (router, method router, handler, and their
    // make-service / connect-info variants) must be accepted by `serve`. The function is
    // never called (hence `dead_code`), and most `Serve` values are intentionally not
    // awaited (hence `unused_must_use`).
    #[allow(dead_code, unused_must_use)]
    async fn if_it_compiles_it_works() {
        // Dummy connect-info type for exercising `into_make_service_with_connect_info`
        // over Unix sockets.
        #[derive(Clone, Debug)]
        struct UdsConnectInfo;

        #[cfg(unix)]
        impl Connected<IncomingStream<'_, UnixListener>> for UdsConnectInfo {
            fn connect_info(_stream: IncomingStream<'_, UnixListener>) -> Self {
                Self
            }
        }

        let router: Router = Router::new();

        let addr = "0.0.0.0:0";

        // TCP listener whose accepted streams get TCP_NODELAY applied via `tap_io`.
        let tcp_nodelay_listener = || async {
            TcpListener::bind(addr).await.unwrap().tap_io(|tcp_stream| {
                if let Err(err) = tcp_stream.set_nodelay(true) {
                    eprintln!("failed to set TCP_NODELAY on incoming connection: {err:#}");
                }
            })
        };

        // router
        serve(TcpListener::bind(addr).await.unwrap(), router.clone());
        // Awaited once so the `tap_io` listener is also checked against the `IntoFuture` impl.
        serve(tcp_nodelay_listener().await, router.clone())
            .await
            .unwrap();
        #[cfg(unix)]
        serve(UnixListener::bind("").unwrap(), router.clone());

        serve(
            TcpListener::bind(addr).await.unwrap(),
            router.clone().into_make_service(),
        );
        serve(
            tcp_nodelay_listener().await,
            router.clone().into_make_service(),
        );
        #[cfg(unix)]
        serve(
            UnixListener::bind("").unwrap(),
            router.clone().into_make_service(),
        );

        serve(
            TcpListener::bind(addr).await.unwrap(),
            router
                .clone()
                .into_make_service_with_connect_info::<std::net::SocketAddr>(),
        );
        serve(
            tcp_nodelay_listener().await,
            router
                .clone()
                .into_make_service_with_connect_info::<std::net::SocketAddr>(),
        );
        #[cfg(unix)]
        serve(
            UnixListener::bind("").unwrap(),
            router.into_make_service_with_connect_info::<UdsConnectInfo>(),
        );

        // method router
        serve(TcpListener::bind(addr).await.unwrap(), get(handler));
        serve(tcp_nodelay_listener().await, get(handler));
        #[cfg(unix)]
        serve(UnixListener::bind("").unwrap(), get(handler));

        serve(
            TcpListener::bind(addr).await.unwrap(),
            get(handler).into_make_service(),
        );
        serve(
            tcp_nodelay_listener().await,
            get(handler).into_make_service(),
        );
        #[cfg(unix)]
        serve(
            UnixListener::bind("").unwrap(),
            get(handler).into_make_service(),
        );

        serve(
            TcpListener::bind(addr).await.unwrap(),
            get(handler).into_make_service_with_connect_info::<std::net::SocketAddr>(),
        );
        serve(
            tcp_nodelay_listener().await,
            get(handler).into_make_service_with_connect_info::<std::net::SocketAddr>(),
        );
        #[cfg(unix)]
        serve(
            UnixListener::bind("").unwrap(),
            get(handler).into_make_service_with_connect_info::<UdsConnectInfo>(),
        );

        // handler
        serve(
            TcpListener::bind(addr).await.unwrap(),
            handler.into_service(),
        );
        serve(tcp_nodelay_listener().await, handler.into_service());
        #[cfg(unix)]
        serve(UnixListener::bind("").unwrap(), handler.into_service());

        serve(
            TcpListener::bind(addr).await.unwrap(),
            handler.with_state(()),
        );
        serve(tcp_nodelay_listener().await, handler.with_state(()));
        #[cfg(unix)]
        serve(UnixListener::bind("").unwrap(), handler.with_state(()));

        serve(
            TcpListener::bind(addr).await.unwrap(),
            handler.into_make_service(),
        );
        serve(tcp_nodelay_listener().await, handler.into_make_service());
        #[cfg(unix)]
        serve(UnixListener::bind("").unwrap(), handler.into_make_service());

        serve(
            TcpListener::bind(addr).await.unwrap(),
            handler.into_make_service_with_connect_info::<std::net::SocketAddr>(),
        );
        serve(
            tcp_nodelay_listener().await,
            handler.into_make_service_with_connect_info::<std::net::SocketAddr>(),
        );
        #[cfg(unix)]
        serve(
            UnixListener::bind("").unwrap(),
            handler.into_make_service_with_connect_info::<UdsConnectInfo>(),
        );
    }

    // Minimal handler used by the compile-only checks above.
    async fn handler() {}

    // Binding to port 0 lets the OS pick a port; `Serve::local_addr` must report the real one.
    #[crate::test]
    async fn test_serve_local_addr() {
        let router: Router = Router::new();
        let addr = "0.0.0.0:0";

        let server = serve(TcpListener::bind(addr).await.unwrap(), router.clone());
        let address = server.local_addr().unwrap();

        assert_eq!(address.ip(), IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)));
        assert_ne!(address.port(), 0);
    }

    // Same as above, but through `WithGracefulShutdown::local_addr`.
    #[crate::test]
    async fn test_with_graceful_shutdown_local_addr() {
        let router: Router = Router::new();
        let addr = "0.0.0.0:0";

        let server = serve(TcpListener::bind(addr).await.unwrap(), router.clone())
            .with_graceful_shutdown(pending());
        let address = server.local_addr().unwrap();

        assert_eq!(address.ip(), IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)));
        assert_ne!(address.port(), 0);
    }

    // Regression test: converting into a future must not require an active tokio runtime
    // (the runtime is only used to bind the listener, outside of `into_future`).
    #[test]
    fn into_future_outside_tokio() {
        let router: Router = Router::new();
        let addr = "0.0.0.0:0";

        let rt = tokio::runtime::Builder::new_multi_thread()
            .enable_all()
            .build()
            .unwrap();

        let listener = rt.block_on(tokio::net::TcpListener::bind(addr)).unwrap();

        // Call Serve::into_future outside of a tokio context. This used to panic.
        _ = serve(listener, router).into_future();
    }

    // End-to-end check with a non-TCP `Listener`: one HTTP/1 request goes over an
    // in-memory duplex pipe and the router's response body is verified.
    #[crate::test]
    async fn serving_on_custom_io_type() {
        // Yields its single pre-made connection once, then never accepts again.
        struct ReadyListener<T>(Option<T>);

        impl<T> Listener for ReadyListener<T>
        where
            T: AsyncRead + AsyncWrite + Unpin + Send + 'static,
        {
            type Io = T;
            type Addr = ();

            async fn accept(&mut self) -> (Self::Io, Self::Addr) {
                match self.0.take() {
                    Some(server) => (server, ()),
                    None => std::future::pending().await,
                }
            }

            fn local_addr(&self) -> io::Result<Self::Addr> {
                Ok(())
            }
        }

        let (client, server) = io::duplex(1024);
        let listener = ReadyListener(Some(server));

        let app = Router::new().route("/", get(|| async { "Hello, World!" }));

        tokio::spawn(serve(listener, app).into_future());

        let stream = TokioIo::new(client);
        let (mut sender, conn) = hyper::client::conn::http1::handshake(stream).await.unwrap();
        tokio::spawn(conn);

        let request = Request::builder().body(Body::empty()).unwrap();

        let response = sender.send_request(request).await.unwrap();
        assert_eq!(response.status(), StatusCode::OK);

        let body = Body::new(response.into_body());
        let body = to_bytes(body, usize::MAX).await.unwrap();
        let body = String::from_utf8(body.to_vec()).unwrap();
        assert_eq!(body, "Hello, World!");
    }
}