1use crate::{extract::rejection::*, response::IntoResponseParts};
2use axum_core::extract::OptionalFromRequestParts;
3use axum_core::{
4 extract::FromRequestParts,
5 response::{IntoResponse, Response, ResponseParts},
6};
7use http::{request::Parts, Extensions, Request};
8use std::{
9 convert::Infallible,
10 task::{Context, Poll},
11};
12use tower_service::Service;
13
/// Extractor, response part and layer for request/response extensions.
///
/// As an extractor it clones the value of type `T` stored in the request's
/// extensions. As a response (or response part) it inserts the wrapped value
/// into the response's extensions. Used as a `tower_layer::Layer`, it wraps a
/// service in an [`AddExtension`] that inserts a clone of the value into every
/// request.
#[must_use]
#[derive(Debug, Clone, Copy, Default)]
pub struct Extension<T>(pub T);
73
74impl<T> Extension<T>
75where
76 T: Clone + Send + Sync + 'static,
77{
78 fn from_extensions(extensions: &Extensions) -> Option<Self> {
79 extensions.get::<T>().cloned().map(Extension)
80 }
81}
82
83impl<T, S> FromRequestParts<S> for Extension<T>
84where
85 T: Clone + Send + Sync + 'static,
86 S: Send + Sync,
87{
88 type Rejection = ExtensionRejection;
89
90 async fn from_request_parts(req: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
91 Ok(Self::from_extensions(&req.extensions).ok_or_else(|| {
92 MissingExtension::from_err(format!(
93 "Extension of type `{}` was not found. Perhaps you forgot to add it? See `axum::Extension`.",
94 std::any::type_name::<T>()
95 ))
96 })?)
97 }
98}
99
100impl<T, S> OptionalFromRequestParts<S> for Extension<T>
101where
102 T: Clone + Send + Sync + 'static,
103 S: Send + Sync,
104{
105 type Rejection = Infallible;
106
107 async fn from_request_parts(
108 req: &mut Parts,
109 _state: &S,
110 ) -> Result<Option<Self>, Self::Rejection> {
111 Ok(Self::from_extensions(&req.extensions))
112 }
113}
114
// Presumably generates `Deref`/`DerefMut` impls that forward to the wrapped
// value, so an `Extension<T>` can be used like a `T`.
// NOTE(review): the macro is defined in `axum_core`; confirm the exact impls there.
axum_core::__impl_deref!(Extension);
116
117impl<T> IntoResponseParts for Extension<T>
118where
119 T: Clone + Send + Sync + 'static,
120{
121 type Error = Infallible;
122
123 fn into_response_parts(self, mut res: ResponseParts) -> Result<ResponseParts, Self::Error> {
124 res.extensions_mut().insert(self.0);
125 Ok(res)
126 }
127}
128
129impl<T> IntoResponse for Extension<T>
130where
131 T: Clone + Send + Sync + 'static,
132{
133 fn into_response(self) -> Response {
134 let mut res = ().into_response();
135 res.extensions_mut().insert(self.0);
136 res
137 }
138}
139
140impl<S, T> tower_layer::Layer<S> for Extension<T>
141where
142 T: Clone + Send + Sync + 'static,
143{
144 type Service = AddExtension<S, T>;
145
146 fn layer(&self, inner: S) -> Self::Service {
147 AddExtension {
148 inner,
149 value: self.0.clone(),
150 }
151 }
152}
153
/// Service produced by using [`Extension`] as a layer.
///
/// Inserts a clone of `value` into each request's extensions and then forwards
/// the request to `inner`.
#[derive(Debug, Clone, Copy)]
pub struct AddExtension<S, T> {
    /// Service that receives each request after the extension is inserted.
    pub(crate) inner: S,
    /// Value cloned into every request's extensions.
    pub(crate) value: T,
}
165
166impl<ResBody, S, T> Service<Request<ResBody>> for AddExtension<S, T>
167where
168 S: Service<Request<ResBody>>,
169 T: Clone + Send + Sync + 'static,
170{
171 type Response = S::Response;
172 type Error = S::Error;
173 type Future = S::Future;
174
175 #[inline]
176 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
177 self.inner.poll_ready(cx)
178 }
179
180 fn call(&mut self, mut req: Request<ResBody>) -> Self::Future {
181 req.extensions_mut().insert(self.value.clone());
182 self.inner.call(req)
183 }
184}
185
#[cfg(test)]
mod tests {
    use super::*;
    use crate::routing::get;
    use crate::test_helpers::TestClient;
    use crate::Router;
    use http::StatusCode;

    #[derive(Clone)]
    struct Foo(String);

    #[derive(Clone)]
    struct Bar(String);

    /// Covers required and optional extraction, both when the extension was
    /// added by the layer (`Foo`) and when it was not (`Bar`).
    #[crate::test]
    async fn extension_extractor() {
        async fn requires_foo(Extension(Foo(value)): Extension<Foo>) -> String {
            value
        }

        async fn optional_foo(extension: Option<Extension<Foo>>) -> String {
            match extension {
                Some(Extension(Foo(value))) => value,
                None => "none".to_owned(),
            }
        }

        async fn requires_bar(Extension(Bar(value)): Extension<Bar>) -> String {
            value
        }

        async fn optional_bar(extension: Option<Extension<Bar>>) -> String {
            match extension {
                Some(Extension(Bar(value))) => value,
                None => "none".to_owned(),
            }
        }

        // Only `Foo` is provided; `Bar` is deliberately missing.
        let app = Router::new()
            .route("/requires_foo", get(requires_foo))
            .route("/optional_foo", get(optional_foo))
            .route("/requires_bar", get(requires_bar))
            .route("/optional_bar", get(optional_bar))
            .layer(Extension(Foo("foo".to_owned())));

        let client = TestClient::new(app);

        let cases = [
            ("/requires_foo", StatusCode::OK, "foo"),
            ("/optional_foo", StatusCode::OK, "foo"),
            (
                "/requires_bar",
                StatusCode::INTERNAL_SERVER_ERROR,
                "Missing request extension: Extension of type `axum::extension::tests::Bar` was not found. Perhaps you forgot to add it? See `axum::Extension`.",
            ),
            ("/optional_bar", StatusCode::OK, "none"),
        ];

        for (path, expected_status, expected_body) in cases {
            let response = client.get(path).await;
            assert_eq!(response.status(), expected_status, "unexpected status for {}", path);
            assert_eq!(response.text().await, expected_body, "unexpected body for {}", path);
        }
    }
}