1use std::{
4 convert::Infallible,
5 fmt::Debug,
6 future::{poll_fn, Future, IntoFuture},
7 io,
8 marker::PhantomData,
9 sync::Arc,
10};
11
12use axum_core::{body::Body, extract::Request, response::Response};
13use futures_util::{pin_mut, FutureExt};
14use hyper::body::Incoming;
15use hyper_util::rt::{TokioExecutor, TokioIo};
16#[cfg(any(feature = "http1", feature = "http2"))]
17use hyper_util::{server::conn::auto::Builder, service::TowerToHyperService};
18use tokio::sync::watch;
19use tower::ServiceExt as _;
20use tower_service::Service;
21
22mod listener;
23
24pub use self::listener::{Listener, ListenerExt, TapIo};
25
/// Build a [`Serve`] that accepts connections from `listener` and hands each
/// one to a service obtained from `make_service`.
///
/// This only assembles the value; no connection is accepted until the result
/// is awaited (or otherwise converted into a future and polled).
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
pub fn serve<L, M, S>(listener: L, make_service: M) -> Serve<L, M, S>
where
    L: Listener,
    M: for<'a> Service<IncomingStream<'a, L>, Error = Infallible, Response = S>,
    S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
    S::Future: Send,
{
    // `S` shows up only in the bounds above; the marker carries it into the
    // returned type without storing a value of it.
    let _marker = PhantomData;

    Serve {
        listener,
        make_service,
        _marker,
    }
}
110
/// A server that has been configured but not started yet.
///
/// Pairs a listener with the make-service that is invoked for every accepted
/// connection. It does nothing until it is awaited or polled.
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
#[must_use = "futures must be awaited or polled"]
pub struct Serve<L, M, S> {
    /// Source of incoming connections.
    listener: L,
    /// Produces the service that handles a given connection.
    make_service: M,
    /// Records the per-connection service type `S` without holding one.
    _marker: PhantomData<S>,
}
119
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<L, M, S> Serve<L, M, S>
where
    L: Listener,
{
    /// Attach a shutdown signal to this server.
    ///
    /// The listener and make-service are moved unchanged into the returned
    /// [`WithGracefulShutdown`]; `signal` is the future whose completion
    /// triggers the shutdown.
    pub fn with_graceful_shutdown<F>(self, signal: F) -> WithGracefulShutdown<L, M, S, F>
    where
        F: Future<Output = ()> + Send + 'static,
    {
        let Self {
            listener,
            make_service,
            _marker,
        } = self;

        WithGracefulShutdown {
            listener,
            make_service,
            signal,
            _marker,
        }
    }

    /// Local address as reported by the underlying listener.
    pub fn local_addr(&self) -> io::Result<L::Addr> {
        let Self { listener, .. } = self;
        listener.local_addr()
    }
}
168
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<L, M, S> Debug for Serve<L, M, S>
where
    L: Debug + 'static,
    M: Debug,
{
    /// Prints the `listener` and `make_service` fields; the marker is omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Destructure exhaustively so that a newly added field has to be
        // considered here before the code compiles again.
        let Self {
            listener,
            make_service,
            _marker: _,
        } = self;

        f.debug_struct("Serve")
            .field("listener", listener)
            .field("make_service", make_service)
            .finish()
    }
}
189
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<L, M, S> IntoFuture for Serve<L, M, S>
where
    L: Listener,
    L::Addr: Debug,
    M: for<'a> Service<IncomingStream<'a, L>, Error = Infallible, Response = S> + Send + 'static,
    for<'a> <M as Service<IncomingStream<'a, L>>>::Future: Send,
    S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
    S::Future: Send,
{
    type Output = io::Result<()>;
    type IntoFuture = private::ServeFuture;

    /// Turn the server into its future by reusing the graceful-shutdown path
    /// with a signal that never completes.
    fn into_future(self) -> Self::IntoFuture {
        let never = std::future::pending::<()>();
        let server = self.with_graceful_shutdown(never);
        server.into_future()
    }
}
208
/// A configured server together with the future that tells it when to shut
/// down. It does nothing until it is awaited or polled.
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
#[must_use = "futures must be awaited or polled"]
pub struct WithGracefulShutdown<L, M, S, F> {
    /// Source of incoming connections.
    listener: L,
    /// Produces the service that handles a given connection.
    make_service: M,
    /// Completion of this future starts the shutdown.
    signal: F,
    /// Records the per-connection service type `S` without holding one.
    _marker: PhantomData<S>,
}
218
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<L, M, S, F> WithGracefulShutdown<L, M, S, F>
where
    L: Listener,
{
    /// Local address as reported by the underlying listener.
    pub fn local_addr(&self) -> io::Result<L::Addr> {
        let Self { listener, .. } = self;
        listener.local_addr()
    }
}
229
// `S` is deliberately unbounded: it is only held as `PhantomData<S>` and is
// never formatted, so requiring `S: Debug` would needlessly hide this impl
// for services that are not `Debug`.
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<L, M, S, F> Debug for WithGracefulShutdown<L, M, S, F>
where
    L: Debug + 'static,
    M: Debug,
    F: Debug,
{
    /// Prints the `listener`, `make_service` and `signal` fields; the marker
    /// is omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Destructure exhaustively so that a newly added field has to be
        // considered here before the code compiles again.
        let Self {
            listener,
            make_service,
            signal,
            _marker: _,
        } = self;

        f.debug_struct("WithGracefulShutdown")
            .field("listener", listener)
            .field("make_service", make_service)
            .field("signal", signal)
            .finish()
    }
}
253
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<L, M, S, F> IntoFuture for WithGracefulShutdown<L, M, S, F>
where
    L: Listener,
    L::Addr: Debug,
    M: for<'a> Service<IncomingStream<'a, L>, Error = Infallible, Response = S> + Send + 'static,
    for<'a> <M as Service<IncomingStream<'a, L>>>::Future: Send,
    S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
    S::Future: Send,
    F: Future<Output = ()> + Send + 'static,
{
    type Output = io::Result<()>;
    type IntoFuture = private::ServeFuture;

    /// Box the server loop into a [`private::ServeFuture`].
    ///
    /// The loop itself returns nothing, so the future always resolves to
    /// `Ok(())` once it finishes.
    fn into_future(self) -> Self::IntoFuture {
        let server_loop = async move {
            self.run().await;
            Ok::<(), io::Error>(())
        };

        private::ServeFuture(Box::pin(server_loop))
    }
}
275
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<L, M, S, F> WithGracefulShutdown<L, M, S, F>
where
    L: Listener,
    L::Addr: Debug,
    M: for<'a> Service<IncomingStream<'a, L>, Error = Infallible, Response = S> + Send + 'static,
    for<'a> <M as Service<IncomingStream<'a, L>>>::Future: Send,
    S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
    S::Future: Send,
    F: Future<Output = ()> + Send + 'static,
{
    /// Drive the server: accept connections until the shutdown signal fires,
    /// then wait for every connection task that is still running.
    ///
    /// Two `watch` channels are used purely for their "closed" notification;
    /// no value is ever sent on either of them:
    ///
    /// * shutdown channel: the receiver lives in a task that awaits `signal`
    ///   and drops it afterwards, which resolves `closed()` on the sender for
    ///   the accept loop and for every connection task.
    /// * completion channel: every connection task owns a receiver clone and
    ///   drops it when done, so `closed()` on the sender resolves once no
    ///   receiver is left.
    async fn run(self) {
        let Self {
            mut listener,
            mut make_service,
            signal,
            _marker: _,
        } = self;

        let (shutdown_tx, shutdown_rx) = watch::channel(());
        // Shared with every connection task, hence the `Arc`.
        let shutdown_tx = Arc::new(shutdown_tx);
        tokio::spawn(async move {
            signal.await;
            trace!("received graceful shutdown signal. Telling tasks to shutdown");
            // Dropping the only receiver is what notifies the senders.
            drop(shutdown_rx);
        });

        let (done_tx, done_rx) = watch::channel(());

        loop {
            // Race the next connection against the shutdown notification.
            let (io, remote_addr) = tokio::select! {
                accepted = listener.accept() => accepted,
                _ = shutdown_tx.closed() => {
                    trace!("signal received, not accepting new connections");
                    break;
                }
            };

            let io = TokioIo::new(io);

            trace!("connection {remote_addr:?} accepted");

            // The error type is `Infallible`, so the empty match proves the
            // error arm can never be taken.
            poll_fn(|cx| make_service.poll_ready(cx))
                .await
                .unwrap_or_else(|err| match err {});

            let incoming = IncomingStream {
                io: &io,
                remote_addr,
            };
            let tower_service = make_service
                .call(incoming)
                .await
                .unwrap_or_else(|err| match err {})
                // Wrap the incoming hyper body in `Body` before it reaches the service.
                .map_request(|req: Request<Incoming>| req.map(Body::new));

            let hyper_service = TowerToHyperService::new(tower_service);

            // Per-connection handles, moved into the task below.
            let task_shutdown_tx = Arc::clone(&shutdown_tx);
            let task_done_rx = done_rx.clone();

            tokio::spawn(async move {
                #[allow(unused_mut)] // only mutated when the `http2` feature is enabled
                let mut builder = Builder::new(TokioExecutor::new());
                #[cfg(feature = "http2")]
                builder.http2().enable_connect_protocol();
                let conn = builder.serve_connection_with_upgrades(io, hyper_service);
                pin_mut!(conn);

                // Fused because the loop below polls this again after it has
                // already completed once.
                let shutdown_requested = task_shutdown_tx.closed().fuse();
                pin_mut!(shutdown_requested);

                loop {
                    tokio::select! {
                        result = conn.as_mut() => {
                            if let Err(_err) = result {
                                trace!("failed to serve connection: {_err:#}");
                            }
                            break;
                        }
                        _ = &mut shutdown_requested => {
                            trace!("signal received in task, starting graceful shutdown");
                            // No `break`: keep polling `conn` until it finishes.
                            conn.as_mut().graceful_shutdown();
                        }
                    }
                }

                // Marks this connection as finished for the completion channel.
                drop(task_done_rx);
            });
        }

        // Release the loop's own receiver so only connection tasks are counted,
        // and drop the listener now that nothing more will be accepted.
        drop(done_rx);
        drop(listener);

        trace!(
            "waiting for {} task(s) to finish",
            done_tx.receiver_count()
        );
        done_tx.closed().await;
    }
}
378
379#[derive(Debug)]
385pub struct IncomingStream<'a, L>
386where
387 L: Listener,
388{
389 io: &'a TokioIo<L::Io>,
390 remote_addr: L::Addr,
391}
392
393impl<L> IncomingStream<'_, L>
394where
395 L: Listener,
396{
397 pub fn io(&self) -> &L::Io {
399 self.io.inner()
400 }
401
402 pub fn remote_addr(&self) -> &L::Addr {
404 &self.remote_addr
405 }
406}
407
408mod private {
409 use std::{
410 future::Future,
411 io,
412 pin::Pin,
413 task::{Context, Poll},
414 };
415
416 pub struct ServeFuture(pub(super) futures_util::future::BoxFuture<'static, io::Result<()>>);
417
418 impl Future for ServeFuture {
419 type Output = io::Result<()>;
420
421 #[inline]
422 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
423 self.0.as_mut().poll(cx)
424 }
425 }
426
427 impl std::fmt::Debug for ServeFuture {
428 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
429 f.debug_struct("ServeFuture").finish_non_exhaustive()
430 }
431 }
432}
433
434#[cfg(test)]
435mod tests {
436 use std::{
437 future::{pending, IntoFuture as _},
438 net::{IpAddr, Ipv4Addr},
439 };
440
441 use axum_core::{body::Body, extract::Request};
442 use http::StatusCode;
443 use hyper_util::rt::TokioIo;
444 #[cfg(unix)]
445 use tokio::net::UnixListener;
446 use tokio::{
447 io::{self, AsyncRead, AsyncWrite},
448 net::TcpListener,
449 };
450
451 #[cfg(unix)]
452 use super::IncomingStream;
453 use super::{serve, Listener};
454 #[cfg(unix)]
455 use crate::extract::connect_info::Connected;
456 use crate::{
457 body::to_bytes,
458 handler::{Handler, HandlerWithoutStateExt},
459 routing::get,
460 serve::ListenerExt,
461 Router,
462 };
463
464 #[allow(dead_code, unused_must_use)]
465 async fn if_it_compiles_it_works() {
466 #[derive(Clone, Debug)]
467 struct UdsConnectInfo;
468
469 #[cfg(unix)]
470 impl Connected<IncomingStream<'_, UnixListener>> for UdsConnectInfo {
471 fn connect_info(_stream: IncomingStream<'_, UnixListener>) -> Self {
472 Self
473 }
474 }
475
476 let router: Router = Router::new();
477
478 let addr = "0.0.0.0:0";
479
480 let tcp_nodelay_listener = || async {
481 TcpListener::bind(addr).await.unwrap().tap_io(|tcp_stream| {
482 if let Err(err) = tcp_stream.set_nodelay(true) {
483 eprintln!("failed to set TCP_NODELAY on incoming connection: {err:#}");
484 }
485 })
486 };
487
488 serve(TcpListener::bind(addr).await.unwrap(), router.clone());
490 serve(tcp_nodelay_listener().await, router.clone())
491 .await
492 .unwrap();
493 #[cfg(unix)]
494 serve(UnixListener::bind("").unwrap(), router.clone());
495
496 serve(
497 TcpListener::bind(addr).await.unwrap(),
498 router.clone().into_make_service(),
499 );
500 serve(
501 tcp_nodelay_listener().await,
502 router.clone().into_make_service(),
503 );
504 #[cfg(unix)]
505 serve(
506 UnixListener::bind("").unwrap(),
507 router.clone().into_make_service(),
508 );
509
510 serve(
511 TcpListener::bind(addr).await.unwrap(),
512 router
513 .clone()
514 .into_make_service_with_connect_info::<std::net::SocketAddr>(),
515 );
516 serve(
517 tcp_nodelay_listener().await,
518 router
519 .clone()
520 .into_make_service_with_connect_info::<std::net::SocketAddr>(),
521 );
522 #[cfg(unix)]
523 serve(
524 UnixListener::bind("").unwrap(),
525 router.into_make_service_with_connect_info::<UdsConnectInfo>(),
526 );
527
528 serve(TcpListener::bind(addr).await.unwrap(), get(handler));
530 serve(tcp_nodelay_listener().await, get(handler));
531 #[cfg(unix)]
532 serve(UnixListener::bind("").unwrap(), get(handler));
533
534 serve(
535 TcpListener::bind(addr).await.unwrap(),
536 get(handler).into_make_service(),
537 );
538 serve(
539 tcp_nodelay_listener().await,
540 get(handler).into_make_service(),
541 );
542 #[cfg(unix)]
543 serve(
544 UnixListener::bind("").unwrap(),
545 get(handler).into_make_service(),
546 );
547
548 serve(
549 TcpListener::bind(addr).await.unwrap(),
550 get(handler).into_make_service_with_connect_info::<std::net::SocketAddr>(),
551 );
552 serve(
553 tcp_nodelay_listener().await,
554 get(handler).into_make_service_with_connect_info::<std::net::SocketAddr>(),
555 );
556 #[cfg(unix)]
557 serve(
558 UnixListener::bind("").unwrap(),
559 get(handler).into_make_service_with_connect_info::<UdsConnectInfo>(),
560 );
561
562 serve(
564 TcpListener::bind(addr).await.unwrap(),
565 handler.into_service(),
566 );
567 serve(tcp_nodelay_listener().await, handler.into_service());
568 #[cfg(unix)]
569 serve(UnixListener::bind("").unwrap(), handler.into_service());
570
571 serve(
572 TcpListener::bind(addr).await.unwrap(),
573 handler.with_state(()),
574 );
575 serve(tcp_nodelay_listener().await, handler.with_state(()));
576 #[cfg(unix)]
577 serve(UnixListener::bind("").unwrap(), handler.with_state(()));
578
579 serve(
580 TcpListener::bind(addr).await.unwrap(),
581 handler.into_make_service(),
582 );
583 serve(tcp_nodelay_listener().await, handler.into_make_service());
584 #[cfg(unix)]
585 serve(UnixListener::bind("").unwrap(), handler.into_make_service());
586
587 serve(
588 TcpListener::bind(addr).await.unwrap(),
589 handler.into_make_service_with_connect_info::<std::net::SocketAddr>(),
590 );
591 serve(
592 tcp_nodelay_listener().await,
593 handler.into_make_service_with_connect_info::<std::net::SocketAddr>(),
594 );
595 #[cfg(unix)]
596 serve(
597 UnixListener::bind("").unwrap(),
598 handler.into_make_service_with_connect_info::<UdsConnectInfo>(),
599 );
600 }
601
602 async fn handler() {}
603
604 #[crate::test]
605 async fn test_serve_local_addr() {
606 let router: Router = Router::new();
607 let addr = "0.0.0.0:0";
608
609 let server = serve(TcpListener::bind(addr).await.unwrap(), router.clone());
610 let address = server.local_addr().unwrap();
611
612 assert_eq!(address.ip(), IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)));
613 assert_ne!(address.port(), 0);
614 }
615
616 #[crate::test]
617 async fn test_with_graceful_shutdown_local_addr() {
618 let router: Router = Router::new();
619 let addr = "0.0.0.0:0";
620
621 let server = serve(TcpListener::bind(addr).await.unwrap(), router.clone())
622 .with_graceful_shutdown(pending());
623 let address = server.local_addr().unwrap();
624
625 assert_eq!(address.ip(), IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)));
626 assert_ne!(address.port(), 0);
627 }
628
629 #[test]
630 fn into_future_outside_tokio() {
631 let router: Router = Router::new();
632 let addr = "0.0.0.0:0";
633
634 let rt = tokio::runtime::Builder::new_multi_thread()
635 .enable_all()
636 .build()
637 .unwrap();
638
639 let listener = rt.block_on(tokio::net::TcpListener::bind(addr)).unwrap();
640
641 _ = serve(listener, router).into_future();
643 }
644
645 #[crate::test]
646 async fn serving_on_custom_io_type() {
647 struct ReadyListener<T>(Option<T>);
648
649 impl<T> Listener for ReadyListener<T>
650 where
651 T: AsyncRead + AsyncWrite + Unpin + Send + 'static,
652 {
653 type Io = T;
654 type Addr = ();
655
656 async fn accept(&mut self) -> (Self::Io, Self::Addr) {
657 match self.0.take() {
658 Some(server) => (server, ()),
659 None => std::future::pending().await,
660 }
661 }
662
663 fn local_addr(&self) -> io::Result<Self::Addr> {
664 Ok(())
665 }
666 }
667
668 let (client, server) = io::duplex(1024);
669 let listener = ReadyListener(Some(server));
670
671 let app = Router::new().route("/", get(|| async { "Hello, World!" }));
672
673 tokio::spawn(serve(listener, app).into_future());
674
675 let stream = TokioIo::new(client);
676 let (mut sender, conn) = hyper::client::conn::http1::handshake(stream).await.unwrap();
677 tokio::spawn(conn);
678
679 let request = Request::builder().body(Body::empty()).unwrap();
680
681 let response = sender.send_request(request).await.unwrap();
682 assert_eq!(response.status(), StatusCode::OK);
683
684 let body = Body::new(response.into_body());
685 let body = to_bytes(body, usize::MAX).await.unwrap();
686 let body = String::from_utf8(body.to_vec()).unwrap();
687 assert_eq!(body, "Hello, World!");
688 }
689}