axum/serve/mod.rs

1//! Serve services.
2
3use std::{
4    convert::Infallible,
5    fmt::Debug,
6    future::{Future, IntoFuture},
7    io,
8    marker::PhantomData,
9    pin::pin,
10};
11
12use axum_core::{body::Body, extract::Request, response::Response};
13use futures_util::FutureExt;
14use hyper::body::Incoming;
15use hyper_util::rt::{TokioExecutor, TokioIo};
16#[cfg(any(feature = "http1", feature = "http2"))]
17use hyper_util::{server::conn::auto::Builder, service::TowerToHyperService};
18use tokio::sync::watch;
19use tower::ServiceExt as _;
20use tower_service::Service;
21
22mod listener;
23
24pub use self::listener::{Listener, ListenerExt, TapIo};
25
/// Serve the service with the supplied listener.
///
/// This is intentionally the simplest possible way to run a service and offers no configuration
/// knobs. If you need to configure the server, use hyper or hyper-util directly.
///
/// Both HTTP/1 and HTTP/2 are supported (subject to the `http1` / `http2` crate features).
///
/// Use [`Serve::with_graceful_shutdown`] to obtain a server future that can be stopped.
///
/// # Examples
///
/// Serving a [`Router`]:
///
/// ```
/// use axum::{Router, routing::get};
///
/// # async {
/// let router = Router::new().route("/", get(|| async { "Hello, World!" }));
///
/// let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
/// axum::serve(listener, router).await.unwrap();
/// # };
/// ```
///
/// See also [`Router::into_make_service_with_connect_info`].
///
/// Serving a [`MethodRouter`]:
///
/// ```
/// use axum::routing::get;
///
/// # async {
/// let router = get(|| async { "Hello, World!" });
///
/// let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
/// axum::serve(listener, router).await.unwrap();
/// # };
/// ```
///
/// See also [`MethodRouter::into_make_service_with_connect_info`].
///
/// Serving a [`Handler`]:
///
/// ```
/// use axum::handler::HandlerWithoutStateExt;
///
/// # async {
/// async fn handler() -> &'static str {
///     "Hello, World!"
/// }
///
/// let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
/// axum::serve(listener, handler.into_make_service()).await.unwrap();
/// # };
/// ```
///
/// See also [`HandlerWithoutStateExt::into_make_service_with_connect_info`] and
/// [`HandlerService::into_make_service_with_connect_info`].
///
/// # Return Value
///
/// The returned future resolves to `io::Result<()>`, but it never actually completes, so in
/// practice it never returns an error either. Errors while accepting connections are dealt with
/// by the [`Listener`] implementation rather than surfaced here; errors on the TCP socket are
/// handled by sleeping for a short while (currently, one second).
///
/// [`Router`]: crate::Router
/// [`Router::into_make_service_with_connect_info`]: crate::Router::into_make_service_with_connect_info
/// [`MethodRouter`]: crate::routing::MethodRouter
/// [`MethodRouter::into_make_service_with_connect_info`]: crate::routing::MethodRouter::into_make_service_with_connect_info
/// [`Handler`]: crate::handler::Handler
/// [`HandlerWithoutStateExt::into_make_service_with_connect_info`]: crate::handler::HandlerWithoutStateExt::into_make_service_with_connect_info
/// [`HandlerService::into_make_service_with_connect_info`]: crate::handler::HandlerService::into_make_service_with_connect_info
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
pub fn serve<L, M, S>(listener: L, make_service: M) -> Serve<L, M, S>
where
    L: Listener,
    M: for<'a> Service<IncomingStream<'a, L>, Error = Infallible, Response = S>,
    S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
    S::Future: Send,
{
    // Nothing is accepted here; the work starts only when the returned value is awaited.
    Serve {
        make_service,
        listener,
        _marker: PhantomData,
    }
}
110
/// Future returned by [`serve`].
///
/// No connections are accepted until this value is awaited (or converted with
/// [`IntoFuture::into_future`]). It can also be turned into a [`WithGracefulShutdown`] via
/// [`Serve::with_graceful_shutdown`].
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
#[must_use = "futures must be awaited or polled"]
pub struct Serve<L, M, S> {
    /// Source of incoming connections.
    listener: L,
    /// Called once per accepted connection to build that connection's service.
    make_service: M,
    /// Records the per-connection service type `S` without storing a value of it.
    _marker: PhantomData<S>,
}
119
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<L, M, S> Serve<L, M, S>
where
    L: Listener,
{
    /// Prepares a server to handle graceful shutdown when the provided future completes.
    ///
    /// `signal` is driven on its own spawned Tokio task once the returned future starts
    /// running. When `signal` completes, the server:
    ///
    /// 1. stops accepting new connections,
    /// 2. asks every open connection to shut down gracefully, and
    /// 3. waits for those connections to finish.
    ///
    /// # Example
    ///
    /// ```
    /// use axum::{Router, routing::get};
    ///
    /// # async {
    /// let router = Router::new().route("/", get(|| async { "Hello, World!" }));
    ///
    /// let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
    /// axum::serve(listener, router)
    ///     .with_graceful_shutdown(shutdown_signal())
    ///     .await
    ///     .unwrap();
    /// # };
    ///
    /// async fn shutdown_signal() {
    ///     // ...
    /// }
    /// ```
    ///
    /// # Return Value
    ///
    /// Similarly to [`serve`], although this future resolves to `io::Result<()>`, it will never
    /// error. It returns `Ok(())` only after the `signal` future has completed *and* every
    /// connection that was still open at that point has finished.
    pub fn with_graceful_shutdown<F>(self, signal: F) -> WithGracefulShutdown<L, M, S, F>
    where
        F: Future<Output = ()> + Send + 'static,
    {
        WithGracefulShutdown {
            listener: self.listener,
            make_service: self.make_service,
            signal,
            _marker: PhantomData,
        }
    }

    /// Returns the local address this server is bound to.
    pub fn local_addr(&self) -> io::Result<L::Addr> {
        self.listener.local_addr()
    }
}
168
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<L, M, S> Serve<L, M, S>
where
    L: Listener,
    L::Addr: Debug,
    M: for<'a> Service<IncomingStream<'a, L>, Error = Infallible, Response = S> + Send + 'static,
    for<'a> <M as Service<IncomingStream<'a, L>>>::Future: Send,
    S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
    S::Future: Send,
{
    /// Accepts connections forever and hands each one to `handle_connection`.
    async fn run(self) -> ! {
        let Self {
            mut listener,
            mut make_service,
            _marker: _,
        } = self;

        // There is no shutdown signal in this mode. `_shutdown_rx` is held for as long as this
        // future lives, so `shutdown_tx.closed()` in the connection tasks stays pending (dropping
        // this future drops the only receiver, which lets them begin a graceful shutdown).
        // `_close_tx` is kept only because `handle_connection` needs a `close_rx` to clone;
        // nobody waits on it here.
        let (shutdown_tx, _shutdown_rx) = watch::channel(());
        let (_close_tx, close_rx) = watch::channel(());

        loop {
            let (stream, peer_addr) = listener.accept().await;
            handle_connection::<L, M, S>(
                &mut make_service,
                &shutdown_tx,
                &close_rx,
                stream,
                peer_addr,
            )
            .await;
        }
    }
}
195
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<L, M, S> Debug for Serve<L, M, S>
where
    L: Debug + 'static,
    M: Debug,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Exhaustive destructuring: adding a field to `Serve` forces this impl to be revisited.
        // The `PhantomData` marker carries no data and is not shown.
        let Self {
            listener,
            make_service,
            _marker: _,
        } = self;

        f.debug_struct("Serve")
            .field("listener", listener)
            .field("make_service", make_service)
            .finish()
    }
}
216
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<L, M, S> IntoFuture for Serve<L, M, S>
where
    L: Listener,
    L::Addr: Debug,
    M: for<'a> Service<IncomingStream<'a, L>, Error = Infallible, Response = S> + Send + 'static,
    for<'a> <M as Service<IncomingStream<'a, L>>>::Future: Send,
    S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
    S::Future: Send,
{
    type Output = io::Result<()>;
    type IntoFuture = private::ServeFuture;

    fn into_future(self) -> Self::IntoFuture {
        // `run` diverges (`-> !`), so the async block's output coerces to the
        // `io::Result<()>` the boxed future is declared with. Nothing runs until polled.
        let boxed: futures_util::future::BoxFuture<'static, io::Result<()>> =
            Box::pin(async move { self.run().await });
        private::ServeFuture(boxed)
    }
}
234
/// Serve future with graceful shutdown enabled.
///
/// Created by [`Serve::with_graceful_shutdown`]. Like [`Serve`], no connections are accepted
/// until it is awaited.
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
#[must_use = "futures must be awaited or polled"]
pub struct WithGracefulShutdown<L, M, S, F> {
    /// Source of incoming connections.
    listener: L,
    /// Called once per accepted connection to build that connection's service.
    make_service: M,
    /// Future whose completion starts the graceful shutdown.
    signal: F,
    /// Records the per-connection service type `S` without storing a value of it.
    _marker: PhantomData<S>,
}
244
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<L, M, S, F> WithGracefulShutdown<L, M, S, F>
where
    L: Listener,
{
    /// Returns the local address this server is bound to.
    ///
    /// Forwards directly to [`Listener::local_addr`] on the wrapped listener.
    pub fn local_addr(&self) -> io::Result<L::Addr> {
        <L as Listener>::local_addr(&self.listener)
    }
}
255
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<L, M, S, F> WithGracefulShutdown<L, M, S, F>
where
    L: Listener,
    L::Addr: Debug,
    M: for<'a> Service<IncomingStream<'a, L>, Error = Infallible, Response = S> + Send + 'static,
    for<'a> <M as Service<IncomingStream<'a, L>>>::Future: Send,
    S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
    S::Future: Send,
    F: Future<Output = ()> + Send + 'static,
{
    /// Accepts connections until `signal` completes, then waits for open connections to end.
    async fn run(self) {
        let Self {
            mut listener,
            mut make_service,
            signal,
            _marker: _,
        } = self;

        // Shutdown is broadcast by *dropping* the only receiver: `shutdown_tx.closed()` (here and
        // in every connection task, which holds a clone of the sender) resolves once no
        // receivers remain.
        let (shutdown_tx, shutdown_rx) = watch::channel(());
        tokio::spawn(async move {
            signal.await;
            trace!("received graceful shutdown signal. Telling tasks to shutdown");
            drop(shutdown_rx);
        });

        // Completion is tracked the other way round: each connection task keeps a clone of
        // `close_rx` until it finishes, so `close_tx.closed()` resolves when all are done.
        let (close_tx, close_rx) = watch::channel(());

        loop {
            let accepted = tokio::select! {
                conn = listener.accept() => Some(conn),
                _ = shutdown_tx.closed() => None,
            };
            let Some((stream, peer_addr)) = accepted else {
                trace!("signal received, not accepting new connections");
                break;
            };

            handle_connection::<L, M, S>(
                &mut make_service,
                &shutdown_tx,
                &close_rx,
                stream,
                peer_addr,
            )
            .await;
        }

        // Drop our own receiver (otherwise `close_tx.closed()` could never resolve) and close
        // the listener before waiting on the in-flight connections.
        drop(close_rx);
        drop(listener);

        trace!(
            "waiting for {} task(s) to finish",
            close_tx.receiver_count()
        );
        close_tx.closed().await;
    }
}
306
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<L, M, S, F> Debug for WithGracefulShutdown<L, M, S, F>
where
    L: Debug + 'static,
    M: Debug,
    F: Debug,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Only stored fields are formatted. `S` lives solely in `PhantomData`, so it is not
        // required to implement `Debug`; this matches the `Debug` impl for `Serve`.
        let Self {
            listener,
            make_service,
            signal,
            _marker: _,
        } = self;

        f.debug_struct("WithGracefulShutdown")
            .field("listener", listener)
            .field("make_service", make_service)
            .field("signal", signal)
            .finish()
    }
}
330
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<L, M, S, F> IntoFuture for WithGracefulShutdown<L, M, S, F>
where
    L: Listener,
    L::Addr: Debug,
    M: for<'a> Service<IncomingStream<'a, L>, Error = Infallible, Response = S> + Send + 'static,
    for<'a> <M as Service<IncomingStream<'a, L>>>::Future: Send,
    S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
    S::Future: Send,
    F: Future<Output = ()> + Send + 'static,
{
    type Output = io::Result<()>;
    type IntoFuture = private::ServeFuture;

    fn into_future(self) -> Self::IntoFuture {
        // Calling the `async fn` only builds its state machine; nothing runs until polled.
        // `run` resolves to `()` after shutdown completes, and this server never reports an
        // error, so the result is simply wrapped in `Ok`.
        let boxed: futures_util::future::BoxFuture<'static, io::Result<()>> =
            Box::pin(self.run().map(Ok::<(), io::Error>));
        private::ServeFuture(boxed)
    }
}
352
353async fn handle_connection<L, M, S>(
354    make_service: &mut M,
355    signal_tx: &watch::Sender<()>,
356    close_rx: &watch::Receiver<()>,
357    io: <L as Listener>::Io,
358    remote_addr: <L as Listener>::Addr,
359) where
360    L: Listener,
361    L::Addr: Debug,
362    M: for<'a> Service<IncomingStream<'a, L>, Error = Infallible, Response = S> + Send + 'static,
363    for<'a> <M as Service<IncomingStream<'a, L>>>::Future: Send,
364    S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
365    S::Future: Send,
366{
367    let io = TokioIo::new(io);
368
369    trace!("connection {remote_addr:?} accepted");
370
371    make_service
372        .ready()
373        .await
374        .unwrap_or_else(|err| match err {});
375
376    let tower_service = make_service
377        .call(IncomingStream {
378            io: &io,
379            remote_addr,
380        })
381        .await
382        .unwrap_or_else(|err| match err {})
383        .map_request(|req: Request<Incoming>| req.map(Body::new));
384
385    let hyper_service = TowerToHyperService::new(tower_service);
386    let signal_tx = signal_tx.clone();
387    let close_rx = close_rx.clone();
388
389    tokio::spawn(async move {
390        #[allow(unused_mut)]
391        let mut builder = Builder::new(TokioExecutor::new());
392        // CONNECT protocol needed for HTTP/2 websockets
393        #[cfg(feature = "http2")]
394        builder.http2().enable_connect_protocol();
395
396        let mut conn = pin!(builder.serve_connection_with_upgrades(io, hyper_service));
397        let mut signal_closed = pin!(signal_tx.closed().fuse());
398
399        loop {
400            tokio::select! {
401                result = conn.as_mut() => {
402                    if let Err(_err) = result {
403                        trace!("failed to serve connection: {_err:#}");
404                    }
405                    break;
406                }
407                _ = &mut signal_closed => {
408                    trace!("signal received in task, starting graceful shutdown");
409                    conn.as_mut().graceful_shutdown();
410                }
411            }
412        }
413
414        drop(close_rx);
415    });
416}
417
418/// An incoming stream.
419///
420/// Used with [`serve`] and [`IntoMakeServiceWithConnectInfo`].
421///
422/// [`IntoMakeServiceWithConnectInfo`]: crate::extract::connect_info::IntoMakeServiceWithConnectInfo
423#[derive(Debug)]
424pub struct IncomingStream<'a, L>
425where
426    L: Listener,
427{
428    io: &'a TokioIo<L::Io>,
429    remote_addr: L::Addr,
430}
431
432impl<L> IncomingStream<'_, L>
433where
434    L: Listener,
435{
436    /// Get a reference to the inner IO type.
437    pub fn io(&self) -> &L::Io {
438        self.io.inner()
439    }
440
441    /// Returns the remote address that this stream is bound to.
442    pub fn remote_addr(&self) -> &L::Addr {
443        &self.remote_addr
444    }
445}
446
447mod private {
448    use std::{
449        future::Future,
450        io,
451        pin::Pin,
452        task::{Context, Poll},
453    };
454
455    pub struct ServeFuture(pub(super) futures_util::future::BoxFuture<'static, io::Result<()>>);
456
457    impl Future for ServeFuture {
458        type Output = io::Result<()>;
459
460        #[inline]
461        fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
462            self.0.as_mut().poll(cx)
463        }
464    }
465
466    impl std::fmt::Debug for ServeFuture {
467        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
468            f.debug_struct("ServeFuture").finish_non_exhaustive()
469        }
470    }
471}
472
#[cfg(test)]
mod tests {
    use std::{
        future::{pending, IntoFuture as _},
        net::{IpAddr, Ipv4Addr},
    };

    use axum_core::{body::Body, extract::Request};
    use http::StatusCode;
    use hyper_util::rt::TokioIo;
    #[cfg(unix)]
    use tokio::net::UnixListener;
    use tokio::{
        io::{self, AsyncRead, AsyncWrite},
        net::TcpListener,
    };

    #[cfg(unix)]
    use super::IncomingStream;
    use super::{serve, Listener};
    #[cfg(unix)]
    use crate::extract::connect_info::Connected;
    use crate::{
        body::to_bytes,
        handler::{Handler, HandlerWithoutStateExt},
        routing::get,
        serve::ListenerExt,
        Router,
    };

    /// Type-level check that `serve` accepts every supported combination of listener and
    /// service. It is never executed; most of the `Serve` values are intentionally not awaited.
    // `dead_code`: never called. `unused_must_use`: `Serve` values are built but not awaited.
    #[allow(dead_code, unused_must_use)]
    async fn if_it_compiles_it_works() {
        #[derive(Clone, Debug)]
        struct UdsConnectInfo;

        #[cfg(unix)]
        impl Connected<IncomingStream<'_, UnixListener>> for UdsConnectInfo {
            fn connect_info(_stream: IncomingStream<'_, UnixListener>) -> Self {
                Self
            }
        }

        let router: Router = Router::new();

        let addr = "0.0.0.0:0";

        let tcp_nodelay_listener = || async {
            TcpListener::bind(addr).await.unwrap().tap_io(|tcp_stream| {
                if let Err(err) = tcp_stream.set_nodelay(true) {
                    eprintln!("failed to set TCP_NODELAY on incoming connection: {err:#}");
                }
            })
        };

        // router
        serve(TcpListener::bind(addr).await.unwrap(), router.clone());
        serve(tcp_nodelay_listener().await, router.clone())
            .await
            .unwrap();
        #[cfg(unix)]
        serve(UnixListener::bind("").unwrap(), router.clone());

        serve(
            TcpListener::bind(addr).await.unwrap(),
            router.clone().into_make_service(),
        );
        serve(
            tcp_nodelay_listener().await,
            router.clone().into_make_service(),
        );
        #[cfg(unix)]
        serve(
            UnixListener::bind("").unwrap(),
            router.clone().into_make_service(),
        );

        serve(
            TcpListener::bind(addr).await.unwrap(),
            router
                .clone()
                .into_make_service_with_connect_info::<std::net::SocketAddr>(),
        );
        serve(
            tcp_nodelay_listener().await,
            router
                .clone()
                .into_make_service_with_connect_info::<std::net::SocketAddr>(),
        );
        #[cfg(unix)]
        serve(
            UnixListener::bind("").unwrap(),
            router.into_make_service_with_connect_info::<UdsConnectInfo>(),
        );

        // method router
        serve(TcpListener::bind(addr).await.unwrap(), get(handler));
        serve(tcp_nodelay_listener().await, get(handler));
        #[cfg(unix)]
        serve(UnixListener::bind("").unwrap(), get(handler));

        serve(
            TcpListener::bind(addr).await.unwrap(),
            get(handler).into_make_service(),
        );
        serve(
            tcp_nodelay_listener().await,
            get(handler).into_make_service(),
        );
        #[cfg(unix)]
        serve(
            UnixListener::bind("").unwrap(),
            get(handler).into_make_service(),
        );

        serve(
            TcpListener::bind(addr).await.unwrap(),
            get(handler).into_make_service_with_connect_info::<std::net::SocketAddr>(),
        );
        serve(
            tcp_nodelay_listener().await,
            get(handler).into_make_service_with_connect_info::<std::net::SocketAddr>(),
        );
        #[cfg(unix)]
        serve(
            UnixListener::bind("").unwrap(),
            get(handler).into_make_service_with_connect_info::<UdsConnectInfo>(),
        );

        // handler
        serve(
            TcpListener::bind(addr).await.unwrap(),
            handler.into_service(),
        );
        serve(tcp_nodelay_listener().await, handler.into_service());
        #[cfg(unix)]
        serve(UnixListener::bind("").unwrap(), handler.into_service());

        serve(
            TcpListener::bind(addr).await.unwrap(),
            handler.with_state(()),
        );
        serve(tcp_nodelay_listener().await, handler.with_state(()));
        #[cfg(unix)]
        serve(UnixListener::bind("").unwrap(), handler.with_state(()));

        serve(
            TcpListener::bind(addr).await.unwrap(),
            handler.into_make_service(),
        );
        serve(tcp_nodelay_listener().await, handler.into_make_service());
        #[cfg(unix)]
        serve(UnixListener::bind("").unwrap(), handler.into_make_service());

        serve(
            TcpListener::bind(addr).await.unwrap(),
            handler.into_make_service_with_connect_info::<std::net::SocketAddr>(),
        );
        serve(
            tcp_nodelay_listener().await,
            handler.into_make_service_with_connect_info::<std::net::SocketAddr>(),
        );
        #[cfg(unix)]
        serve(
            UnixListener::bind("").unwrap(),
            handler.into_make_service_with_connect_info::<UdsConnectInfo>(),
        );
    }

    async fn handler() {}

    /// `Serve::local_addr` reports the address (with the OS-assigned port) the listener bound to.
    #[crate::test]
    async fn test_serve_local_addr() {
        let router: Router = Router::new();
        let addr = "0.0.0.0:0";

        let server = serve(TcpListener::bind(addr).await.unwrap(), router);
        let address = server.local_addr().unwrap();

        assert_eq!(address.ip(), IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)));
        assert_ne!(address.port(), 0);
    }

    /// `WithGracefulShutdown::local_addr` reports the same bound address as `Serve`.
    #[crate::test]
    async fn test_with_graceful_shutdown_local_addr() {
        let router: Router = Router::new();
        let addr = "0.0.0.0:0";

        let server = serve(TcpListener::bind(addr).await.unwrap(), router)
            .with_graceful_shutdown(pending());
        let address = server.local_addr().unwrap();

        assert_eq!(address.ip(), IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)));
        assert_ne!(address.port(), 0);
    }

    /// Converting `Serve` into a future must not require an ambient Tokio runtime.
    #[test]
    fn into_future_outside_tokio() {
        let router: Router = Router::new();
        let addr = "0.0.0.0:0";

        let rt = tokio::runtime::Builder::new_multi_thread()
            .enable_all()
            .build()
            .unwrap();

        let listener = rt.block_on(tokio::net::TcpListener::bind(addr)).unwrap();

        // Call Serve::into_future outside of a tokio context. This used to panic.
        _ = serve(listener, router).into_future();
    }

    /// A user-defined `Listener` over an in-memory duplex stream can serve real requests.
    #[crate::test]
    async fn serving_on_custom_io_type() {
        struct ReadyListener<T>(Option<T>);

        impl<T> Listener for ReadyListener<T>
        where
            T: AsyncRead + AsyncWrite + Unpin + Send + 'static,
        {
            type Io = T;
            type Addr = ();

            async fn accept(&mut self) -> (Self::Io, Self::Addr) {
                match self.0.take() {
                    Some(server) => (server, ()),
                    None => std::future::pending().await,
                }
            }

            fn local_addr(&self) -> io::Result<Self::Addr> {
                Ok(())
            }
        }

        let (client, server) = io::duplex(1024);
        let listener = ReadyListener(Some(server));

        let app = Router::new().route("/", get(|| async { "Hello, World!" }));

        tokio::spawn(serve(listener, app).into_future());

        let stream = TokioIo::new(client);
        let (mut sender, conn) = hyper::client::conn::http1::handshake(stream).await.unwrap();
        tokio::spawn(conn);

        let request = Request::builder().body(Body::empty()).unwrap();

        let response = sender.send_request(request).await.unwrap();
        assert_eq!(response.status(), StatusCode::OK);

        let body = Body::new(response.into_body());
        let body = to_bytes(body, usize::MAX).await.unwrap();
        let body = String::from_utf8(body.to_vec()).unwrap();
        assert_eq!(body, "Hello, World!");
    }
}
728}