1use std::{
4 convert::Infallible,
5 fmt::Debug,
6 future::{Future, IntoFuture},
7 io,
8 marker::PhantomData,
9 pin::pin,
10};
11
12use axum_core::{body::Body, extract::Request, response::Response};
13use futures_util::FutureExt;
14use hyper::body::Incoming;
15use hyper_util::rt::{TokioExecutor, TokioIo};
16#[cfg(any(feature = "http1", feature = "http2"))]
17use hyper_util::{server::conn::auto::Builder, service::TowerToHyperService};
18use tokio::sync::watch;
19use tower::ServiceExt as _;
20use tower_service::Service;
21
22mod listener;
23
24pub use self::listener::{Listener, ListenerExt, TapIo};
25
/// Build a [`Serve`] that serves connections accepted from `listener`.
///
/// For every accepted connection `make_service` is called with an
/// [`IncomingStream`] describing it and must produce a service `S`; that
/// service then handles the requests arriving on the connection.
///
/// This function only assembles the value. Nothing is accepted until the
/// returned [`Serve`] is driven, e.g. by `.await`ing it through its
/// [`IntoFuture`] implementation.
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
pub fn serve<L, M, S>(listener: L, make_service: M) -> Serve<L, M, S>
where
    L: Listener,
    M: for<'a> Service<IncomingStream<'a, L>, Error = Infallible, Response = S>,
    S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
    S::Future: Send,
{
    // `S` is never stored; the marker only pins the type parameter.
    let _marker = PhantomData;

    Serve {
        listener,
        make_service,
        _marker,
    }
}
110
/// Server value returned by [`serve`].
///
/// It bundles the listener with the make-service. The accept loop starts once
/// the value is turned into a future (see the [`IntoFuture`] implementation).
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
#[must_use = "futures must be awaited or polled"]
pub struct Serve<L, M, S> {
    /// Source of incoming connections.
    listener: L,
    /// Called once per accepted connection to obtain the service for it.
    make_service: M,
    /// Records the per-connection service type `S` without storing a value.
    _marker: PhantomData<S>,
}
119
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<L, M, S> Serve<L, M, S>
where
    L: Listener,
{
    /// Turn this server into one that shuts down gracefully.
    ///
    /// The listener and make-service are moved unchanged into the returned
    /// [`WithGracefulShutdown`]. Once `signal` completes, that server stops
    /// accepting new connections and waits for the connections it already
    /// spawned to finish.
    pub fn with_graceful_shutdown<F>(self, signal: F) -> WithGracefulShutdown<L, M, S, F>
    where
        F: Future<Output = ()> + Send + 'static,
    {
        let Self {
            listener,
            make_service,
            _marker,
        } = self;

        WithGracefulShutdown {
            listener,
            make_service,
            signal,
            _marker,
        }
    }

    /// Return the local address reported by the underlying listener.
    pub fn local_addr(&self) -> io::Result<L::Addr> {
        let Self { listener, .. } = self;
        listener.local_addr()
    }
}
168
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<L, M, S> Serve<L, M, S>
where
    L: Listener,
    L::Addr: Debug,
    M: for<'a> Service<IncomingStream<'a, L>, Error = Infallible, Response = S> + Send + 'static,
    for<'a> <M as Service<IncomingStream<'a, L>>>::Future: Send,
    S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
    S::Future: Send,
{
    /// Accept connections forever, passing each one to [`handle_connection`].
    ///
    /// The return type is `!`: the loop has no exit.
    async fn run(self) -> ! {
        let Self {
            mut listener,
            mut make_service,
            _marker: _,
        } = self;

        // `handle_connection` expects a shutdown sender and a "task finished"
        // receiver. This server never shuts down, so the opposite halves are
        // only kept bound (not `_`) to stop either channel from closing.
        let (shutdown_tx, _shutdown_rx) = watch::channel(());
        let (_done_tx, done_rx) = watch::channel(());

        loop {
            let (stream, peer_addr) = listener.accept().await;
            handle_connection(&mut make_service, &shutdown_tx, &done_rx, stream, peer_addr).await;
        }
    }
}
195
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<L, M, S> Debug for Serve<L, M, S>
where
    L: Debug + 'static,
    M: Debug,
{
    /// Format as a `Serve` struct showing the listener and make-service.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Destructure exhaustively so a newly added field cannot be forgotten here.
        let Self {
            listener,
            make_service,
            _marker: _,
        } = self;

        f.debug_struct("Serve")
            .field("listener", listener)
            .field("make_service", make_service)
            .finish()
    }
}
216
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<L, M, S> IntoFuture for Serve<L, M, S>
where
    L: Listener,
    L::Addr: Debug,
    M: for<'a> Service<IncomingStream<'a, L>, Error = Infallible, Response = S> + Send + 'static,
    for<'a> <M as Service<IncomingStream<'a, L>>>::Future: Send,
    S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
    S::Future: Send,
{
    type Output = io::Result<()>;
    type IntoFuture = private::ServeFuture;

    /// Box the accept loop into a [`private::ServeFuture`].
    ///
    /// Only the future is built here; the loop starts when it is first polled.
    fn into_future(self) -> Self::IntoFuture {
        // `run` returns `!`, so this future never produces a value; the
        // `io::Result<()>` output only fixes the declared `Output` type.
        let serve_forever: futures_util::future::BoxFuture<'static, io::Result<()>> =
            Box::pin(async move { self.run().await });

        private::ServeFuture(serve_forever)
    }
}
234
/// Server with graceful shutdown, created by [`Serve::with_graceful_shutdown`].
///
/// Like [`Serve`], it only runs once it is turned into a future.
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
#[must_use = "futures must be awaited or polled"]
pub struct WithGracefulShutdown<L, M, S, F> {
    /// Source of incoming connections.
    listener: L,
    /// Called once per accepted connection to obtain the service for it.
    make_service: M,
    /// Future whose completion starts the shutdown.
    signal: F,
    /// Records the per-connection service type `S` without storing a value.
    _marker: PhantomData<S>,
}
244
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<L, M, S, F> WithGracefulShutdown<L, M, S, F>
where
    L: Listener,
{
    /// Return the local address reported by the underlying listener.
    pub fn local_addr(&self) -> io::Result<L::Addr> {
        let Self { listener, .. } = self;
        listener.local_addr()
    }
}
255
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<L, M, S, F> WithGracefulShutdown<L, M, S, F>
where
    L: Listener,
    L::Addr: Debug,
    M: for<'a> Service<IncomingStream<'a, L>, Error = Infallible, Response = S> + Send + 'static,
    for<'a> <M as Service<IncomingStream<'a, L>>>::Future: Send,
    S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
    S::Future: Send,
    F: Future<Output = ()> + Send + 'static,
{
    /// Accept connections until the shutdown signal completes, then wait for
    /// every spawned connection task to finish.
    async fn run(self) {
        let Self {
            mut listener,
            mut make_service,
            signal,
            _marker: _,
        } = self;

        // Shutdown is broadcast by *closing* this channel: a background task
        // owns the only receiver and drops it once `signal` completes.
        let (shutdown_tx, shutdown_rx) = watch::channel(());
        tokio::spawn(async move {
            signal.await;
            trace!("received graceful shutdown signal. Telling tasks to shutdown");
            drop(shutdown_rx);
        });

        // Each connection task holds a clone of `done_rx` until it finishes,
        // so `done_tx.closed()` below doubles as "all tasks are done".
        let (done_tx, done_rx) = watch::channel(());

        loop {
            let accepted = tokio::select! {
                conn = listener.accept() => Some(conn),
                _ = shutdown_tx.closed() => None,
            };

            let Some((stream, peer_addr)) = accepted else {
                trace!("signal received, not accepting new connections");
                break;
            };

            handle_connection(&mut make_service, &shutdown_tx, &done_rx, stream, peer_addr).await;
        }

        // Release our own receiver, otherwise the wait below never ends.
        drop(done_rx);
        drop(listener);

        trace!(
            "waiting for {} task(s) to finish",
            done_tx.receiver_count()
        );
        done_tx.closed().await;
    }
}
306
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<L, M, S, F> Debug for WithGracefulShutdown<L, M, S, F>
where
    L: Debug + 'static,
    M: Debug,
    S: Debug,
    F: Debug,
{
    /// Format as a `WithGracefulShutdown` struct showing the listener,
    /// make-service and signal.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Destructure exhaustively so a newly added field cannot be forgotten here.
        let Self {
            listener,
            make_service,
            signal,
            _marker: _,
        } = self;

        let mut out = f.debug_struct("WithGracefulShutdown");
        out.field("listener", listener);
        out.field("make_service", make_service);
        out.field("signal", signal);
        out.finish()
    }
}
330
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<L, M, S, F> IntoFuture for WithGracefulShutdown<L, M, S, F>
where
    L: Listener,
    L::Addr: Debug,
    M: for<'a> Service<IncomingStream<'a, L>, Error = Infallible, Response = S> + Send + 'static,
    for<'a> <M as Service<IncomingStream<'a, L>>>::Future: Send,
    S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
    S::Future: Send,
    F: Future<Output = ()> + Send + 'static,
{
    type Output = io::Result<()>;
    type IntoFuture = private::ServeFuture;

    /// Box the server loop into a [`private::ServeFuture`].
    ///
    /// The future resolves to `Ok(())` after `run` has returned, i.e. after
    /// the graceful shutdown has completed.
    fn into_future(self) -> Self::IntoFuture {
        let until_shutdown: futures_util::future::BoxFuture<'static, io::Result<()>> =
            Box::pin(async move {
                self.run().await;
                Ok(())
            });

        private::ServeFuture(until_shutdown)
    }
}
352
/// Create the service for one accepted connection and spawn a task serving it.
///
/// * `make_service` - produces the per-connection service; it is awaited for
///   readiness and called on the caller's task, before anything is spawned.
/// * `signal_tx` - shutdown channel; the spawned task starts a graceful
///   shutdown of the connection once this sender reports `closed()`.
/// * `close_rx` - a clone is held by the spawned task and dropped when the
///   task ends, which lets the owner of the matching sender observe that all
///   connection tasks have finished.
/// * `io` / `remote_addr` - the accepted connection and its peer address.
///
/// The function returns as soon as the task has been spawned; it does not
/// wait for the connection to complete.
async fn handle_connection<L, M, S>(
    make_service: &mut M,
    signal_tx: &watch::Sender<()>,
    close_rx: &watch::Receiver<()>,
    io: <L as Listener>::Io,
    remote_addr: <L as Listener>::Addr,
) where
    L: Listener,
    L::Addr: Debug,
    M: for<'a> Service<IncomingStream<'a, L>, Error = Infallible, Response = S> + Send + 'static,
    for<'a> <M as Service<IncomingStream<'a, L>>>::Future: Send,
    S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
    S::Future: Send,
{
    // Adapt the listener's IO type to the IO traits hyper works with.
    let io = TokioIo::new(io);

    trace!("connection {remote_addr:?} accepted");

    // The error type is `Infallible`, so the empty `match` is exhaustive and
    // this can never actually fail.
    make_service
        .ready()
        .await
        .unwrap_or_else(|err| match err {});

    // The make-service only borrows the IO; ownership moves into the spawned
    // task below. Incoming hyper bodies are wrapped into `Body` so the service
    // sees the `Request` type it is declared for.
    let tower_service = make_service
        .call(IncomingStream {
            io: &io,
            remote_addr,
        })
        .await
        .unwrap_or_else(|err| match err {})
        .map_request(|req: Request<Incoming>| req.map(Body::new));

    let hyper_service = TowerToHyperService::new(tower_service);
    // Owned handles for the spawned task.
    let signal_tx = signal_tx.clone();
    let close_rx = close_rx.clone();

    tokio::spawn(async move {
        // `builder` is only mutated when the `http2` feature is enabled.
        #[allow(unused_mut)]
        let mut builder = Builder::new(TokioExecutor::new());
        #[cfg(feature = "http2")]
        builder.http2().enable_connect_protocol();

        let mut conn = pin!(builder.serve_connection_with_upgrades(io, hyper_service));
        // Fused so the loop can keep selecting on it after it has completed.
        let mut signal_closed = pin!(signal_tx.closed().fuse());

        loop {
            tokio::select! {
                // The connection finished, successfully or not: stop.
                result = conn.as_mut() => {
                    if let Err(_err) = result {
                        trace!("failed to serve connection: {_err:#}");
                    }
                    break;
                }
                // Shutdown requested: ask the connection to wind down, then
                // keep looping until `conn` itself completes.
                _ = &mut signal_closed => {
                    trace!("signal received in task, starting graceful shutdown");
                    conn.as_mut().graceful_shutdown();
                }
            }
        }

        // Dropping the receiver marks this task as finished for whoever is
        // waiting on the sender's `closed()`.
        drop(close_rx);
    });
}
417
/// An accepted connection as seen by the make-service.
///
/// It borrows the connection's IO and owns the remote address; see
/// [`IncomingStream::io`] and [`IncomingStream::remote_addr`].
#[derive(Debug)]
pub struct IncomingStream<'a, L>
where
    L: Listener,
{
    /// Borrowed IO of the connection, wrapped in `TokioIo`.
    io: &'a TokioIo<L::Io>,
    /// Address of the remote peer as reported by the listener.
    remote_addr: L::Addr,
}
431
432impl<L> IncomingStream<'_, L>
433where
434 L: Listener,
435{
436 pub fn io(&self) -> &L::Io {
438 self.io.inner()
439 }
440
441 pub fn remote_addr(&self) -> &L::Addr {
443 &self.remote_addr
444 }
445}
446
447mod private {
448 use std::{
449 future::Future,
450 io,
451 pin::Pin,
452 task::{Context, Poll},
453 };
454
455 pub struct ServeFuture(pub(super) futures_util::future::BoxFuture<'static, io::Result<()>>);
456
457 impl Future for ServeFuture {
458 type Output = io::Result<()>;
459
460 #[inline]
461 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
462 self.0.as_mut().poll(cx)
463 }
464 }
465
466 impl std::fmt::Debug for ServeFuture {
467 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
468 f.debug_struct("ServeFuture").finish_non_exhaustive()
469 }
470 }
471}
472
473#[cfg(test)]
474mod tests {
475 use std::{
476 future::{pending, IntoFuture as _},
477 net::{IpAddr, Ipv4Addr},
478 };
479
480 use axum_core::{body::Body, extract::Request};
481 use http::StatusCode;
482 use hyper_util::rt::TokioIo;
483 #[cfg(unix)]
484 use tokio::net::UnixListener;
485 use tokio::{
486 io::{self, AsyncRead, AsyncWrite},
487 net::TcpListener,
488 };
489
490 #[cfg(unix)]
491 use super::IncomingStream;
492 use super::{serve, Listener};
493 #[cfg(unix)]
494 use crate::extract::connect_info::Connected;
495 use crate::{
496 body::to_bytes,
497 handler::{Handler, HandlerWithoutStateExt},
498 routing::get,
499 serve::ListenerExt,
500 Router,
501 };
502
    /// Compile-time check that `serve` accepts every supported combination of
    /// listener (plain TCP, TCP with `tap_io`, Unix socket) and service
    /// (router, method router, handler, and their make-service variants).
    ///
    /// The function carries no test attribute and is marked `dead_code`; it
    /// only has to type-check.
    #[allow(dead_code, unused_must_use)]
    async fn if_it_compiles_it_works() {
        #[derive(Clone, Debug)]
        struct UdsConnectInfo;

        #[cfg(unix)]
        impl Connected<IncomingStream<'_, UnixListener>> for UdsConnectInfo {
            fn connect_info(_stream: IncomingStream<'_, UnixListener>) -> Self {
                Self
            }
        }

        let router: Router = Router::new();

        let addr = "0.0.0.0:0";

        // TCP listener that tries to set TCP_NODELAY on every accepted stream.
        let tcp_nodelay_listener = || async {
            TcpListener::bind(addr).await.unwrap().tap_io(|tcp_stream| {
                if let Err(err) = tcp_stream.set_nodelay(true) {
                    eprintln!("failed to set TCP_NODELAY on incoming connection: {err:#}");
                }
            })
        };

        // router
        serve(TcpListener::bind(addr).await.unwrap(), router.clone());
        serve(tcp_nodelay_listener().await, router.clone())
            .await
            .unwrap();
        #[cfg(unix)]
        serve(UnixListener::bind("").unwrap(), router.clone());

        // router as make-service
        serve(
            TcpListener::bind(addr).await.unwrap(),
            router.clone().into_make_service(),
        );
        serve(
            tcp_nodelay_listener().await,
            router.clone().into_make_service(),
        );
        #[cfg(unix)]
        serve(
            UnixListener::bind("").unwrap(),
            router.clone().into_make_service(),
        );

        // router as make-service with connect info
        serve(
            TcpListener::bind(addr).await.unwrap(),
            router
                .clone()
                .into_make_service_with_connect_info::<std::net::SocketAddr>(),
        );
        serve(
            tcp_nodelay_listener().await,
            router
                .clone()
                .into_make_service_with_connect_info::<std::net::SocketAddr>(),
        );
        #[cfg(unix)]
        serve(
            UnixListener::bind("").unwrap(),
            router.into_make_service_with_connect_info::<UdsConnectInfo>(),
        );

        // method router
        serve(TcpListener::bind(addr).await.unwrap(), get(handler));
        serve(tcp_nodelay_listener().await, get(handler));
        #[cfg(unix)]
        serve(UnixListener::bind("").unwrap(), get(handler));

        // method router as make-service
        serve(
            TcpListener::bind(addr).await.unwrap(),
            get(handler).into_make_service(),
        );
        serve(
            tcp_nodelay_listener().await,
            get(handler).into_make_service(),
        );
        #[cfg(unix)]
        serve(
            UnixListener::bind("").unwrap(),
            get(handler).into_make_service(),
        );

        // method router as make-service with connect info
        serve(
            TcpListener::bind(addr).await.unwrap(),
            get(handler).into_make_service_with_connect_info::<std::net::SocketAddr>(),
        );
        serve(
            tcp_nodelay_listener().await,
            get(handler).into_make_service_with_connect_info::<std::net::SocketAddr>(),
        );
        #[cfg(unix)]
        serve(
            UnixListener::bind("").unwrap(),
            get(handler).into_make_service_with_connect_info::<UdsConnectInfo>(),
        );

        // handler as service
        serve(
            TcpListener::bind(addr).await.unwrap(),
            handler.into_service(),
        );
        serve(tcp_nodelay_listener().await, handler.into_service());
        #[cfg(unix)]
        serve(UnixListener::bind("").unwrap(), handler.into_service());

        // handler with state
        serve(
            TcpListener::bind(addr).await.unwrap(),
            handler.with_state(()),
        );
        serve(tcp_nodelay_listener().await, handler.with_state(()));
        #[cfg(unix)]
        serve(UnixListener::bind("").unwrap(), handler.with_state(()));

        // handler as make-service
        serve(
            TcpListener::bind(addr).await.unwrap(),
            handler.into_make_service(),
        );
        serve(tcp_nodelay_listener().await, handler.into_make_service());
        #[cfg(unix)]
        serve(UnixListener::bind("").unwrap(), handler.into_make_service());

        // handler as make-service with connect info
        serve(
            TcpListener::bind(addr).await.unwrap(),
            handler.into_make_service_with_connect_info::<std::net::SocketAddr>(),
        );
        serve(
            tcp_nodelay_listener().await,
            handler.into_make_service_with_connect_info::<std::net::SocketAddr>(),
        );
        #[cfg(unix)]
        serve(
            UnixListener::bind("").unwrap(),
            handler.into_make_service_with_connect_info::<UdsConnectInfo>(),
        );
    }
640
    /// No-op handler used by `if_it_compiles_it_works` as a minimal handler.
    async fn handler() {}
642
643 #[crate::test]
644 async fn test_serve_local_addr() {
645 let router: Router = Router::new();
646 let addr = "0.0.0.0:0";
647
648 let server = serve(TcpListener::bind(addr).await.unwrap(), router.clone());
649 let address = server.local_addr().unwrap();
650
651 assert_eq!(address.ip(), IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)));
652 assert_ne!(address.port(), 0);
653 }
654
655 #[crate::test]
656 async fn test_with_graceful_shutdown_local_addr() {
657 let router: Router = Router::new();
658 let addr = "0.0.0.0:0";
659
660 let server = serve(TcpListener::bind(addr).await.unwrap(), router.clone())
661 .with_graceful_shutdown(pending());
662 let address = server.local_addr().unwrap();
663
664 assert_eq!(address.ip(), IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)));
665 assert_ne!(address.port(), 0);
666 }
667
668 #[test]
669 fn into_future_outside_tokio() {
670 let router: Router = Router::new();
671 let addr = "0.0.0.0:0";
672
673 let rt = tokio::runtime::Builder::new_multi_thread()
674 .enable_all()
675 .build()
676 .unwrap();
677
678 let listener = rt.block_on(tokio::net::TcpListener::bind(addr)).unwrap();
679
680 _ = serve(listener, router).into_future();
682 }
683
684 #[crate::test]
685 async fn serving_on_custom_io_type() {
686 struct ReadyListener<T>(Option<T>);
687
688 impl<T> Listener for ReadyListener<T>
689 where
690 T: AsyncRead + AsyncWrite + Unpin + Send + 'static,
691 {
692 type Io = T;
693 type Addr = ();
694
695 async fn accept(&mut self) -> (Self::Io, Self::Addr) {
696 match self.0.take() {
697 Some(server) => (server, ()),
698 None => std::future::pending().await,
699 }
700 }
701
702 fn local_addr(&self) -> io::Result<Self::Addr> {
703 Ok(())
704 }
705 }
706
707 let (client, server) = io::duplex(1024);
708 let listener = ReadyListener(Some(server));
709
710 let app = Router::new().route("/", get(|| async { "Hello, World!" }));
711
712 tokio::spawn(serve(listener, app).into_future());
713
714 let stream = TokioIo::new(client);
715 let (mut sender, conn) = hyper::client::conn::http1::handshake(stream).await.unwrap();
716 tokio::spawn(conn);
717
718 let request = Request::builder().body(Body::empty()).unwrap();
719
720 let response = sender.send_request(request).await.unwrap();
721 assert_eq!(response.status(), StatusCode::OK);
722
723 let body = Body::new(response.into_body());
724 let body = to_bytes(body, usize::MAX).await.unwrap();
725 let body = String::from_utf8(body.to_vec()).unwrap();
726 assert_eq!(body, "Hello, World!");
727 }
728}