1use crate::extract::Request;
2use crate::extract::{rejection::*, FromRequest, RawForm};
3use axum_core::response::{IntoResponse, Response};
4use axum_core::RequestExt;
5use http::header::CONTENT_TYPE;
6use http::StatusCode;
7use serde::de::DeserializeOwned;
8use serde::Serialize;
9
/// URL encoded extractor and response.
///
/// # As extractor
///
/// `Form<T>` deserializes `application/x-www-form-urlencoded` data into `T`,
/// which must implement `serde::de::DeserializeOwned`. The raw bytes come from
/// [`RawForm`]: for `GET` requests that is the query string, while for `POST`
/// (and, presumably, other body-carrying methods) it is the request body, which
/// must be sent with `Content-Type: application/x-www-form-urlencoded`.
///
/// If the data cannot be deserialized into `T`, `GET` and `HEAD` requests are
/// rejected with `FailedToDeserializeForm` (`400 Bad Request`), and every other
/// method is rejected with `FailedToDeserializeFormBody`
/// (`422 Unprocessable Entity`). Failures to read the body or a wrong
/// `Content-Type` are reported through the corresponding [`FormRejection`]
/// variants.
///
/// # As response
///
/// When `T: serde::Serialize`, returning `Form<T>` serializes the value with
/// `serde_urlencoded` and sets `Content-Type: application/x-www-form-urlencoded`.
/// If serialization fails, a `500 Internal Server Error` response whose body is
/// the error message is produced instead.
#[cfg_attr(docsrs, doc(cfg(feature = "form")))]
#[derive(Debug, Clone, Copy, Default)]
#[must_use]
pub struct Form<T>(pub T);
73
74impl<T, S> FromRequest<S> for Form<T>
75where
76 T: DeserializeOwned,
77 S: Send + Sync,
78{
79 type Rejection = FormRejection;
80
81 async fn from_request(req: Request, _state: &S) -> Result<Self, Self::Rejection> {
82 let is_get_or_head =
83 req.method() == http::Method::GET || req.method() == http::Method::HEAD;
84
85 match req.extract().await {
86 Ok(RawForm(bytes)) => {
87 let deserializer =
88 serde_urlencoded::Deserializer::new(form_urlencoded::parse(&bytes));
89 let value = serde_path_to_error::deserialize(deserializer).map_err(
90 |err| -> FormRejection {
91 if is_get_or_head {
92 FailedToDeserializeForm::from_err(err).into()
93 } else {
94 FailedToDeserializeFormBody::from_err(err).into()
95 }
96 },
97 )?;
98 Ok(Form(value))
99 }
100 Err(RawFormRejection::BytesRejection(r)) => Err(FormRejection::BytesRejection(r)),
101 Err(RawFormRejection::InvalidFormContentType(r)) => {
102 Err(FormRejection::InvalidFormContentType(r))
103 }
104 }
105 }
106}
107
108impl<T> IntoResponse for Form<T>
109where
110 T: Serialize,
111{
112 fn into_response(self) -> Response {
113 fn make_response(ser_result: Result<String, serde_urlencoded::ser::Error>) -> Response {
115 match ser_result {
116 Ok(body) => (
117 [(CONTENT_TYPE, mime::APPLICATION_WWW_FORM_URLENCODED.as_ref())],
118 body,
119 )
120 .into_response(),
121 Err(err) => (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()).into_response(),
122 }
123 }
124
125 make_response(serde_urlencoded::to_string(&self.0))
126 }
127}
// NOTE(review): judging by its name, this macro presumably generates `Deref`/`DerefMut`
// impls that forward to the wrapped `T` — confirm against `axum_core`.
axum_core::__impl_deref!(Form);
129
#[cfg(test)]
mod tests {
    use crate::{
        routing::{on, MethodFilter},
        test_helpers::TestClient,
        Router,
    };

    use super::*;
    use axum_core::body::Body;
    use http::{Method, Request};
    use mime::APPLICATION_WWW_FORM_URLENCODED;
    use serde::{Deserialize, Serialize};
    use std::fmt::Debug;

    /// Sample payload with only optional fields, so that an empty form is valid.
    #[derive(Debug, PartialEq, Serialize, Deserialize)]
    struct Pagination {
        size: Option<u64>,
        page: Option<u64>,
    }

    /// Extracts `Form<T>` from a request whose form data is in the query string of `uri`.
    ///
    /// The request is built without an explicit method, so it uses the builder's
    /// default (`GET`).
    async fn check_query<T: DeserializeOwned + PartialEq + Debug>(uri: impl AsRef<str>, value: T) {
        let req = Request::builder()
            .uri(uri.as_ref())
            .body(Body::empty())
            .unwrap();
        assert_eq!(Form::<T>::from_request(req, &()).await.unwrap().0, value);
    }

    /// Round-trips `value` through a urlencoded `POST` body and asserts that it is
    /// extracted unchanged.
    async fn check_body<T: Serialize + DeserializeOwned + PartialEq + Debug>(value: T) {
        let req = Request::builder()
            .uri("http://example.com/test")
            .method(Method::POST)
            .header(CONTENT_TYPE, APPLICATION_WWW_FORM_URLENCODED.as_ref())
            .body(Body::from(serde_urlencoded::to_string(&value).unwrap()))
            .unwrap();
        assert_eq!(Form::<T>::from_request(req, &()).await.unwrap().0, value);
    }

    #[crate::test]
    async fn test_form_query() {
        check_query(
            "http://example.com/test",
            Pagination {
                size: None,
                page: None,
            },
        )
        .await;

        check_query(
            "http://example.com/test?size=10",
            Pagination {
                size: Some(10),
                page: None,
            },
        )
        .await;

        check_query(
            "http://example.com/test?size=10&page=20",
            Pagination {
                size: Some(10),
                page: Some(20),
            },
        )
        .await;
    }

    #[crate::test]
    async fn test_form_body() {
        check_body(Pagination {
            size: None,
            page: None,
        })
        .await;

        check_body(Pagination {
            size: Some(10),
            page: None,
        })
        .await;

        check_body(Pagination {
            size: Some(10),
            page: Some(20),
        })
        .await;
    }

    #[crate::test]
    async fn test_incorrect_content_type() {
        let req = Request::builder()
            .uri("http://example.com/test")
            .method(Method::POST)
            .header(CONTENT_TYPE, mime::APPLICATION_JSON.as_ref())
            .body(Body::from(
                serde_urlencoded::to_string(&Pagination {
                    size: Some(10),
                    page: None,
                })
                .unwrap(),
            ))
            .unwrap();
        assert!(matches!(
            Form::<Pagination>::from_request(req, &())
                .await
                .unwrap_err(),
            FormRejection::InvalidFormContentType(InvalidFormContentType)
        ));
    }

    /// Query-string (`GET`) deserialization failures map to 400.
    /// Body (`POST`) deserialization failures map to 422.
    #[crate::test]
    async fn deserialize_error_status_codes() {
        // `a` is only deserialized, never read: the handler discards the extractor.
        #[allow(dead_code)]
        #[derive(Deserialize)]
        struct Payload {
            a: i32,
        }

        let app = Router::new().route(
            "/",
            on(
                MethodFilter::GET.or(MethodFilter::POST),
                |_: Form<Payload>| async {},
            ),
        );

        let client = TestClient::new(app);

        let res = client.get("/?a=false").await;
        assert_eq!(res.status(), StatusCode::BAD_REQUEST);
        assert_eq!(
            res.text().await,
            "Failed to deserialize form: a: invalid digit found in string"
        );

        let res = client
            .post("/")
            .header(CONTENT_TYPE, APPLICATION_WWW_FORM_URLENCODED.as_ref())
            .body("a=false")
            .await;
        assert_eq!(res.status(), StatusCode::UNPROCESSABLE_ENTITY);
        assert_eq!(
            res.text().await,
            "Failed to deserialize form body: a: invalid digit found in string"
        );
    }
}