1use crate::extract::Request;
2use crate::extract::{rejection::*, FromRequest, RawForm};
3use axum_core::response::{IntoResponse, Response};
4use axum_core::RequestExt;
5use http::header::CONTENT_TYPE;
6use http::StatusCode;
7use serde_core::{de::DeserializeOwned, Serialize};
8
/// URL-encoded form extractor and response.
///
/// When extracting, the raw form data is deserialized into `T` with `serde_urlencoded`;
/// failures are reported as a [`FormRejection`]. When used as a response, `T` is
/// serialized with `serde_urlencoded` and sent with a
/// `Content-Type: application/x-www-form-urlencoded` header.
#[cfg_attr(docsrs, doc(cfg(feature = "form")))]
#[must_use]
#[derive(Clone, Copy, Debug, Default)]
pub struct Form<T>(pub T);
72
73impl<T, S> FromRequest<S> for Form<T>
74where
75 T: DeserializeOwned,
76 S: Send + Sync,
77{
78 type Rejection = FormRejection;
79
80 async fn from_request(req: Request, _state: &S) -> Result<Self, Self::Rejection> {
81 let is_get_or_head =
82 req.method() == http::Method::GET || req.method() == http::Method::HEAD;
83
84 match req.extract().await {
85 Ok(RawForm(bytes)) => {
86 let deserializer =
87 serde_urlencoded::Deserializer::new(form_urlencoded::parse(&bytes));
88 let value = serde_path_to_error::deserialize(deserializer).map_err(
89 |err| -> FormRejection {
90 if is_get_or_head {
91 FailedToDeserializeForm::from_err(err).into()
92 } else {
93 FailedToDeserializeFormBody::from_err(err).into()
94 }
95 },
96 )?;
97 Ok(Form(value))
98 }
99 Err(RawFormRejection::BytesRejection(r)) => Err(FormRejection::BytesRejection(r)),
100 Err(RawFormRejection::InvalidFormContentType(r)) => {
101 Err(FormRejection::InvalidFormContentType(r))
102 }
103 }
104 }
105}
106
107impl<T> IntoResponse for Form<T>
108where
109 T: Serialize,
110{
111 fn into_response(self) -> Response {
112 fn make_response(ser_result: Result<String, serde_urlencoded::ser::Error>) -> Response {
114 match ser_result {
115 Ok(body) => (
116 [(CONTENT_TYPE, mime::APPLICATION_WWW_FORM_URLENCODED.as_ref())],
117 body,
118 )
119 .into_response(),
120 Err(err) => (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()).into_response(),
121 }
122 }
123
124 make_response(serde_urlencoded::to_string(&self.0))
125 }
126}
// Lets `Form<T>` be used through its inner value; presumably this macro generates
// `Deref`/`DerefMut` impls targeting field `.0` — confirm against `axum_core::__impl_deref`.
axum_core::__impl_deref!(Form);
128
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        routing::{on, MethodFilter},
        test_helpers::TestClient,
        Router,
    };
    use axum_core::body::Body;
    use http::{Method, Request};
    use mime::APPLICATION_WWW_FORM_URLENCODED;
    use serde::{Deserialize, Serialize};
    use std::fmt::Debug;

    /// Sample payload with two optional numeric fields.
    #[derive(Debug, PartialEq, Serialize, Deserialize)]
    struct Pagination {
        size: Option<u64>,
        page: Option<u64>,
    }

    /// Shorthand constructor for `Pagination` test values.
    fn pagination(size: Option<u64>, page: Option<u64>) -> Pagination {
        Pagination { size, page }
    }

    /// Extracts `Form<T>` from a bodiless request to `uri` (builder default method,
    /// i.e. `GET`) and asserts the result equals `expected`.
    async fn check_query<T>(uri: impl AsRef<str>, expected: T)
    where
        T: DeserializeOwned + PartialEq + Debug,
    {
        let request = Request::builder()
            .uri(uri.as_ref())
            .body(Body::empty())
            .unwrap();
        let Form(extracted) = Form::<T>::from_request(request, &()).await.unwrap();
        assert_eq!(extracted, expected);
    }

    /// Sends `expected` as a form-urlencoded `POST` body and asserts that
    /// `Form<T>` extracts the same value back.
    async fn check_body<T>(expected: T)
    where
        T: Serialize + DeserializeOwned + PartialEq + Debug,
    {
        let encoded = serde_urlencoded::to_string(&expected).unwrap();
        let request = Request::builder()
            .uri("http://example.com/test")
            .method(Method::POST)
            .header(CONTENT_TYPE, APPLICATION_WWW_FORM_URLENCODED.as_ref())
            .body(Body::from(encoded))
            .unwrap();
        let Form(extracted) = Form::<T>::from_request(request, &()).await.unwrap();
        assert_eq!(extracted, expected);
    }

    #[crate::test]
    async fn test_form_query() {
        let cases = [
            ("http://example.com/test", pagination(None, None)),
            ("http://example.com/test?size=10", pagination(Some(10), None)),
            (
                "http://example.com/test?size=10&page=20",
                pagination(Some(10), Some(20)),
            ),
        ];
        for (uri, expected) in cases {
            check_query(uri, expected).await;
        }
    }

    #[crate::test]
    async fn test_form_body() {
        let cases = [
            pagination(None, None),
            pagination(Some(10), None),
            pagination(Some(10), Some(20)),
        ];
        for expected in cases {
            check_body(expected).await;
        }
    }

    #[crate::test]
    async fn test_incorrect_content_type() {
        // A valid form body, but labelled as JSON.
        let body = serde_urlencoded::to_string(&pagination(Some(10), None)).unwrap();
        let req = Request::builder()
            .uri("http://example.com/test")
            .method(Method::POST)
            .header(CONTENT_TYPE, mime::APPLICATION_JSON.as_ref())
            .body(Body::from(body))
            .unwrap();
        let rejection = Form::<Pagination>::from_request(req, &())
            .await
            .unwrap_err();
        assert!(matches!(
            rejection,
            FormRejection::InvalidFormContentType(InvalidFormContentType)
        ));
    }

    #[tokio::test]
    async fn deserialize_error_status_codes() {
        // `a` is only written by deserialization and never read, hence the allow.
        #[allow(dead_code)]
        #[derive(Deserialize)]
        struct Payload {
            a: i32,
        }

        let get_or_post = MethodFilter::GET.or(MethodFilter::POST);
        let app = Router::new().route("/", on(get_or_post, |_: Form<Payload>| async {}));
        let client = TestClient::new(app);

        // GET: the bad value arrives in the query string.
        let res = client.get("/?a=false").await;
        assert_eq!(res.status(), StatusCode::BAD_REQUEST);
        assert_eq!(
            res.text().await,
            "Failed to deserialize form: a: invalid digit found in string"
        );

        // POST: the bad value arrives in the request body.
        let res = client
            .post("/")
            .header(CONTENT_TYPE, APPLICATION_WWW_FORM_URLENCODED.as_ref())
            .body("a=false")
            .await;
        assert_eq!(res.status(), StatusCode::UNPROCESSABLE_ENTITY);
        assert_eq!(
            res.text().await,
            "Failed to deserialize form body: a: invalid digit found in string"
        );
    }
}