axum/
extension.rs

1use crate::{extract::rejection::*, response::IntoResponseParts};
2use axum_core::extract::OptionalFromRequestParts;
3use axum_core::{
4    extract::FromRequestParts,
5    response::{IntoResponse, Response, ResponseParts},
6};
7use http::{request::Parts, Extensions, Request};
8use std::{
9    convert::Infallible,
10    task::{Context, Poll},
11};
12use tower_service::Service;
13
/// A typed value carried in request or response extensions.
///
/// `Extension` can be used as an extractor, as (part of) a response, and as a
/// middleware layer that attaches a value to incoming requests.
///
/// # As extractor
///
/// A typical use is sharing state between handlers.
///
/// ```rust,no_run
/// use axum::{
///     Router,
///     Extension,
///     routing::get,
/// };
/// use std::sync::Arc;
///
/// // Some shared state used throughout our application
/// struct State {
///     // ...
/// }
///
/// async fn handler(state: Extension<Arc<State>>) {
///     // ...
/// }
///
/// let state = Arc::new(State { /* ... */ });
///
/// let app = Router::new().route("/", get(handler))
///     // Add middleware that inserts the state into the extensions of every
///     // incoming request.
///     .layer(Extension(state));
/// # let _: Router = app;
/// ```
///
/// When no value of the requested type is present, the request is rejected
/// with a `500 Internal Server Error` response. To make the extension
/// optional instead, extract `Option<Extension<T>>`.
///
/// # As response
///
/// Extensions attached to a response can be used to pass state to middleware.
///
/// ```rust
/// use axum::{
///     Extension,
///     response::IntoResponse,
/// };
///
/// async fn handler() -> (Extension<Foo>, &'static str) {
///     (
///         Extension(Foo("foo")),
///         "Hello, World!"
///     )
/// }
///
/// #[derive(Clone)]
/// struct Foo(&'static str);
/// ```
#[derive(Clone, Copy, Debug, Default)]
#[must_use]
pub struct Extension<T>(pub T);
73
74impl<T> Extension<T>
75where
76    T: Clone + Send + Sync + 'static,
77{
78    fn from_extensions(extensions: &Extensions) -> Option<Self> {
79        extensions.get::<T>().cloned().map(Extension)
80    }
81}
82
83impl<T, S> FromRequestParts<S> for Extension<T>
84where
85    T: Clone + Send + Sync + 'static,
86    S: Send + Sync,
87{
88    type Rejection = ExtensionRejection;
89
90    async fn from_request_parts(req: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
91        Ok(Self::from_extensions(&req.extensions).ok_or_else(|| {
92            MissingExtension::from_err(format!(
93                "Extension of type `{}` was not found. Perhaps you forgot to add it? See `axum::Extension`.",
94                std::any::type_name::<T>()
95            ))
96        })?)
97    }
98}
99
100impl<T, S> OptionalFromRequestParts<S> for Extension<T>
101where
102    T: Clone + Send + Sync + 'static,
103    S: Send + Sync,
104{
105    type Rejection = Infallible;
106
107    async fn from_request_parts(
108        req: &mut Parts,
109        _state: &S,
110    ) -> Result<Option<Self>, Self::Rejection> {
111        Ok(Self::from_extensions(&req.extensions))
112    }
113}
114
115axum_core::__impl_deref!(Extension);
116
117impl<T> IntoResponseParts for Extension<T>
118where
119    T: Clone + Send + Sync + 'static,
120{
121    type Error = Infallible;
122
123    fn into_response_parts(self, mut res: ResponseParts) -> Result<ResponseParts, Self::Error> {
124        res.extensions_mut().insert(self.0);
125        Ok(res)
126    }
127}
128
129impl<T> IntoResponse for Extension<T>
130where
131    T: Clone + Send + Sync + 'static,
132{
133    fn into_response(self) -> Response {
134        let mut res = ().into_response();
135        res.extensions_mut().insert(self.0);
136        res
137    }
138}
139
140impl<S, T> tower_layer::Layer<S> for Extension<T>
141where
142    T: Clone + Send + Sync + 'static,
143{
144    type Service = AddExtension<S, T>;
145
146    fn layer(&self, inner: S) -> Self::Service {
147        AddExtension {
148            inner,
149            value: self.0.clone(),
150        }
151    }
152}
153
/// Middleware that adds a shareable value to [request extensions].
///
/// It wraps an inner service and holds the value to attach; it is the service
/// produced when [`Extension`] is used as a layer.
///
/// See [Passing state from middleware to handlers](index.html#passing-state-from-middleware-to-handlers)
/// for more details.
///
/// [request extensions]: https://docs.rs/http/latest/http/struct.Extensions.html
#[derive(Debug, Clone, Copy)]
pub struct AddExtension<S, T> {
    // The wrapped service that requests are forwarded to.
    pub(crate) inner: S,
    // The value attached to requests.
    pub(crate) value: T,
}
165
166impl<ResBody, S, T> Service<Request<ResBody>> for AddExtension<S, T>
167where
168    S: Service<Request<ResBody>>,
169    T: Clone + Send + Sync + 'static,
170{
171    type Response = S::Response;
172    type Error = S::Error;
173    type Future = S::Future;
174
175    #[inline]
176    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
177        self.inner.poll_ready(cx)
178    }
179
180    fn call(&mut self, mut req: Request<ResBody>) -> Self::Future {
181        req.extensions_mut().insert(self.value.clone());
182        self.inner.call(req)
183    }
184}
185
#[cfg(test)]
mod tests {
    use super::*;
    use crate::routing::get;
    use crate::test_helpers::TestClient;
    use crate::Router;
    use http::StatusCode;

    // Extension type that the router's layer inserts into every request.
    #[derive(Clone)]
    struct Foo(String);

    // Extension type that is never inserted, used to exercise the
    // "missing extension" paths.
    #[derive(Clone)]
    struct Bar(String);

    /// Checks the required and optional extractor forms against both an
    /// extension that is present (`Foo`) and one that is absent (`Bar`).
    #[crate::test]
    async fn extension_extractor() {
        async fn requires_foo(Extension(foo): Extension<Foo>) -> String {
            foo.0
        }

        async fn optional_foo(extension: Option<Extension<Foo>>) -> String {
            // Build the fallback lazily so no `String` is allocated when the
            // extension is present.
            extension.map_or_else(|| "none".to_owned(), |Extension(foo)| foo.0)
        }

        async fn requires_bar(Extension(bar): Extension<Bar>) -> String {
            bar.0
        }

        async fn optional_bar(extension: Option<Extension<Bar>>) -> String {
            extension.map_or_else(|| "none".to_owned(), |Extension(bar)| bar.0)
        }

        // Only `Foo` is layered on; `Bar` is deliberately left out.
        let app = Router::new()
            .route("/requires_foo", get(requires_foo))
            .route("/optional_foo", get(optional_foo))
            .route("/requires_bar", get(requires_bar))
            .route("/optional_bar", get(optional_bar))
            .layer(Extension(Foo("foo".to_owned())));

        let client = TestClient::new(app);

        let response = client.get("/requires_foo").await;
        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(response.text().await, "foo");

        let response = client.get("/optional_foo").await;
        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(response.text().await, "foo");

        // A required extension that is missing rejects the request.
        let response = client.get("/requires_bar").await;
        assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);
        assert_eq!(response.text().await, "Missing request extension: Extension of type `axum::extension::tests::Bar` was not found. Perhaps you forgot to add it? See `axum::Extension`.");

        // An optional extension that is missing resolves to `None`.
        let response = client.get("/optional_bar").await;
        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(response.text().await, "none");
    }
}
243}