1use std::{error::Error, fmt, time::Duration};
2
3use bytes::Bytes;
4use futures_util::stream::Stream;
5use log::debug;
6use reqwest::{
7 Client as HttpClient,
8 ClientBuilder as HttpClientBuilder,
9 Error as HttpError,
10 RequestBuilder as HttpRequestBuilder,
11};
12use serde::de::DeserializeOwned;
13use tokio::time::sleep;
14
15use super::payload::{Payload, PayloadError};
16use crate::types::{Response, ResponseError};
17
/// Base URL used for API requests unless overridden with `Client::with_host`.
const DEFAULT_HOST: &str = "https://api.telegram.org";
/// Number of times `Client::execute` re-sends a request after a retry-after hint,
/// unless overridden with `Client::with_max_retries`.
const DEFAULT_MAX_RETRIES: u8 = 2;
20
21#[derive(Clone)]
23pub struct Client {
24 host: String,
25 http_client: HttpClient,
26 token: String,
27 max_retries: u8,
28}
29
30impl Client {
31 pub fn new<T>(token: T) -> Result<Self, ClientError>
37 where
38 T: Into<String>,
39 {
40 let client = HttpClientBuilder::new()
41 .use_rustls_tls()
42 .build()
43 .map_err(ClientError::BuildClient)?;
44 Ok(Self::with_http_client(client, token))
45 }
46
47 pub fn with_http_client<T>(http_client: HttpClient, token: T) -> Self
55 where
56 T: Into<String>,
57 {
58 Self {
59 http_client,
60 host: String::from(DEFAULT_HOST),
61 token: token.into(),
62 max_retries: DEFAULT_MAX_RETRIES,
63 }
64 }
65
66 pub fn with_host<T>(mut self, host: T) -> Self
72 where
73 T: Into<String>,
74 {
75 self.host = host.into();
76 self
77 }
78
79 pub fn with_max_retries(mut self, value: u8) -> Self {
85 self.max_retries = value;
86 self
87 }
88
89 pub async fn download_file<P>(
112 &self,
113 file_path: P,
114 ) -> Result<impl Stream<Item = Result<Bytes, HttpError>> + use<P>, DownloadFileError>
115 where
116 P: AsRef<str>,
117 {
118 let payload = Payload::empty(file_path.as_ref());
119 let url = payload.build_url(&format!("{}/file", &self.host), &self.token);
120 debug!("Downloading file from {url}");
121 let rep = self.http_client.get(&url).send().await?;
122 let status = rep.status();
123 if !status.is_success() {
124 Err(DownloadFileError::Response {
125 status: status.as_u16(),
126 text: rep.text().await?,
127 })
128 } else {
129 Ok(rep.bytes_stream())
130 }
131 }
132
133 pub async fn execute<M>(&self, method: M) -> Result<M::Response, ExecuteError>
144 where
145 M: Method,
146 M::Response: DeserializeOwned + Send + 'static,
147 {
148 let request = method
149 .into_payload()
150 .into_http_request_builder(&self.http_client, &self.host, &self.token)?;
151 let response = match send_request_retry(Box::new(request)).await? {
152 RetryResponse::Ok(response) => response,
153 RetryResponse::Retry {
154 mut request,
155 mut response,
156 mut retry_after,
157 } => {
158 for i in 0..self.max_retries {
159 debug!("Retry attempt {i}, sleeping for {retry_after} second(s)");
160 sleep(Duration::from_secs(retry_after)).await;
161 match send_request_retry(request).await? {
162 RetryResponse::Ok(new_response) => {
163 response = new_response;
164 break;
165 }
166 RetryResponse::Retry {
167 request: new_request,
168 response: new_response,
169 retry_after: new_retry_after,
170 } => {
171 request = new_request;
172 response = new_response;
173 retry_after = new_retry_after;
174 }
175 }
176 }
177 response
178 }
179 };
180 Ok(response.into_result()?)
181 }
182}
183
184enum RetryResponse<T> {
185 Ok(Response<T>),
186 Retry {
187 request: Box<HttpRequestBuilder>,
188 response: Response<T>,
189 retry_after: u64,
190 },
191}
192
193async fn send_request_retry<T>(request: Box<HttpRequestBuilder>) -> Result<RetryResponse<T>, ExecuteError>
194where
195 T: DeserializeOwned,
196{
197 Ok(match request.try_clone() {
198 Some(try_request) => {
199 let response = send_request(try_request).await?;
200 match response.retry_after() {
201 Some(retry_after) => RetryResponse::Retry {
202 request,
203 response,
204 retry_after,
205 },
206 None => RetryResponse::Ok(response),
207 }
208 }
209 None => {
210 debug!("Could not clone builder, sending request without retry");
211 RetryResponse::Ok(send_request(*request).await?)
212 }
213 })
214}
215
216async fn send_request<T>(request: HttpRequestBuilder) -> Result<Response<T>, ExecuteError>
217where
218 T: DeserializeOwned,
219{
220 let response = request.send().await?;
221 Ok(response.json::<Response<T>>().await?)
222}
223
224impl fmt::Debug for Client {
225 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
226 f.debug_struct("Client")
227 .field("http_client", &self.http_client)
228 .field("host", &self.host)
229 .field("token", &format_args!("..."))
230 .finish()
231 }
232}
233
234pub trait Method {
236 type Response;
238
239 fn into_payload(self) -> Payload;
241}
242
243#[derive(Debug)]
245pub enum ClientError {
246 BuildClient(HttpError),
248}
249
250impl Error for ClientError {
251 fn source(&self) -> Option<&(dyn Error + 'static)> {
252 Some(match self {
253 ClientError::BuildClient(err) => err,
254 })
255 }
256}
257
258impl fmt::Display for ClientError {
259 fn fmt(&self, out: &mut fmt::Formatter) -> fmt::Result {
260 match self {
261 ClientError::BuildClient(err) => write!(out, "can not build HTTP client: {err}"),
262 }
263 }
264}
265
266#[derive(Debug)]
269pub enum DownloadFileError {
270 Http(HttpError),
272 Response {
274 status: u16,
276 text: String,
278 },
279}
280
281impl From<HttpError> for DownloadFileError {
282 fn from(err: HttpError) -> Self {
283 Self::Http(err)
284 }
285}
286
287impl Error for DownloadFileError {
288 fn source(&self) -> Option<&(dyn Error + 'static)> {
289 match self {
290 DownloadFileError::Http(err) => Some(err),
291 _ => None,
292 }
293 }
294}
295
296impl fmt::Display for DownloadFileError {
297 fn fmt(&self, out: &mut fmt::Formatter) -> fmt::Result {
298 match self {
299 DownloadFileError::Http(err) => write!(out, "failed to download file: {err}"),
300 DownloadFileError::Response { status, text } => {
301 write!(out, "failed to download file: status={status} text={text}")
302 }
303 }
304 }
305}
306
307#[derive(Debug, derive_more::From)]
310pub enum ExecuteError {
311 Http(HttpError),
313 Payload(PayloadError),
315 Response(ResponseError),
317}
318
319impl Error for ExecuteError {
320 fn source(&self) -> Option<&(dyn Error + 'static)> {
321 use self::ExecuteError::*;
322 Some(match self {
323 Http(err) => err,
324 Payload(err) => err,
325 Response(err) => err,
326 })
327 }
328}
329
330impl fmt::Display for ExecuteError {
331 fn fmt(&self, out: &mut fmt::Formatter) -> fmt::Result {
332 use self::ExecuteError::*;
333 write!(
334 out,
335 "failed to execute method: {}",
336 match self {
337 Http(err) => err.to_string(),
338 Payload(err) => err.to_string(),
339 Response(err) => err.to_string(),
340 }
341 )
342 }
343}
344
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks constructor defaults and that the builder-style setters override them.
    #[test]
    fn api() {
        let client = Client::new("token").unwrap();
        assert_eq!(client.token, "token");
        assert_eq!(client.host, DEFAULT_HOST);
        assert_eq!(client.max_retries, DEFAULT_MAX_RETRIES);

        let client = Client::new("token")
            .unwrap()
            .with_host("https://example.com")
            .with_max_retries(1);
        assert_eq!(client.token, "token");
        assert_eq!(client.host, "https://example.com");
        assert_eq!(client.max_retries, 1);
    }

    /// The bot token must never appear in `Debug` output.
    #[test]
    fn debug_hides_token() {
        let client = Client::new("super-secret-token").unwrap();
        let repr = format!("{client:?}");
        assert!(!repr.contains("super-secret-token"));
    }
}