Skip to main content

tgbot/api/
client.rs

1use std::{error::Error, fmt, time::Duration};
2
3use bytes::Bytes;
4use futures_util::stream::Stream;
5use log::debug;
6use reqwest::{
7    Client as HttpClient,
8    ClientBuilder as HttpClientBuilder,
9    Error as HttpError,
10    RequestBuilder as HttpRequestBuilder,
11};
12use serde::de::DeserializeOwned;
13use tokio::time::sleep;
14
15use super::payload::{Payload, PayloadError};
16use crate::types::{Response, ResponseError};
17
/// How many times a request answered with a `retry_after` hint is resent
/// unless overridden with [`Client::with_max_retries`].
const DEFAULT_MAX_RETRIES: u8 = 2;
/// Base URL of the official Telegram Bot API server.
const DEFAULT_HOST: &str = "https://api.telegram.org";
20
21/// A client for interacting with the Telegram Bot API.
22#[derive(Clone)]
23pub struct Client {
24    host: String,
25    http_client: HttpClient,
26    token: String,
27    max_retries: u8,
28}
29
30impl Client {
31    /// Creates a new Telegram Bot API client with the provided bot token.
32    ///
33    /// # Arguments
34    ///
35    /// * `token` - A token associated with your bot.
36    pub fn new<T>(token: T) -> Result<Self, ClientError>
37    where
38        T: Into<String>,
39    {
40        let client = {
41            #[cfg(feature = "webpki-roots")]
42            {
43                let root_cert_store = rustls::RootCertStore::from_iter(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
44
45                let tls_config = rustls::ClientConfig::builder()
46                    .with_root_certificates(root_cert_store)
47                    .with_no_client_auth();
48
49                HttpClientBuilder::new()
50                    .tls_backend_preconfigured(tls_config)
51                    .build()
52                    .map_err(ClientError::BuildClient)?
53            }
54
55            #[cfg(not(feature = "webpki-roots"))]
56            {
57                HttpClientBuilder::new()
58                    .tls_backend_rustls()
59                    .build()
60                    .map_err(ClientError::BuildClient)?
61            }
62        };
63        Ok(Self::with_http_client(client, token))
64    }
65
66    /// Creates a new Telegram Bot API client with a custom HTTP client and bot token.
67    ///
68    /// # Arguments
69    ///
70    /// * `client` - An HTTP client.
71    /// * `token` - A token associated with your bot.
72    ///
73    pub fn with_http_client<T>(http_client: HttpClient, token: T) -> Self
74    where
75        T: Into<String>,
76    {
77        Self {
78            http_client,
79            host: String::from(DEFAULT_HOST),
80            token: token.into(),
81            max_retries: DEFAULT_MAX_RETRIES,
82        }
83    }
84
85    /// Overrides the default API host with a custom one.
86    ///
87    /// # Arguments
88    ///
89    /// * `host` - The new API host to use.
90    pub fn with_host<T>(mut self, host: T) -> Self
91    where
92        T: Into<String>,
93    {
94        self.host = host.into();
95        self
96    }
97
98    /// Overrides the default number of max retries.
99    ///
100    /// # Arguments
101    ///
102    /// * `value` - The new number of max retries
103    pub fn with_max_retries(mut self, value: u8) -> Self {
104        self.max_retries = value;
105        self
106    }
107
108    /// Downloads a file.
109    ///
110    /// Use [`crate::types::GetFile`] method to get a value for the `file_path` argument.
111    ///
112    /// # Arguments
113    ///
114    /// * `file_path` - The path to the file to be downloaded.
115    ///
116    /// # Example
117    ///
118    /// ```
119    /// # async fn download_file() {
120    /// use tgbot::api::Client;
121    /// use futures_util::stream::StreamExt;
122    /// let api = Client::new("token").unwrap();
123    /// let mut stream = api.download_file("path").await.unwrap();
124    /// while let Some(chunk) = stream.next().await {
125    ///     let chunk = chunk.unwrap();
126    ///     // write chunk to something...
127    /// }
128    /// # }
129    /// ```
130    pub async fn download_file<P>(
131        &self,
132        file_path: P,
133    ) -> Result<impl Stream<Item = Result<Bytes, HttpError>> + use<P>, DownloadFileError>
134    where
135        P: AsRef<str>,
136    {
137        let payload = Payload::empty(file_path.as_ref());
138        let url = payload.build_url(&format!("{}/file", &self.host), &self.token);
139        debug!("Downloading file from {url}");
140        let rep = self.http_client.get(&url).send().await?;
141        let status = rep.status();
142        if !status.is_success() {
143            Err(DownloadFileError::Response {
144                status: status.as_u16(),
145                text: rep.text().await?,
146            })
147        } else {
148            Ok(rep.bytes_stream())
149        }
150    }
151
152    /// Executes a method.
153    ///
154    /// # Arguments
155    ///
156    /// * `method` - The method to execute.
157    ///
158    /// # Notes
159    ///
160    /// The client will not retry a request on a timeout error if the request is not cloneable
161    /// (e.g. contains a stream).
162    pub async fn execute<M>(&self, method: M) -> Result<M::Response, ExecuteError>
163    where
164        M: Method,
165        M::Response: DeserializeOwned + Send + 'static,
166    {
167        let request = method
168            .into_payload()
169            .into_http_request_builder(&self.http_client, &self.host, &self.token)?;
170        let response = match send_request_retry(Box::new(request)).await? {
171            RetryResponse::Ok(response) => response,
172            RetryResponse::Retry {
173                mut request,
174                mut response,
175                mut retry_after,
176            } => {
177                for i in 0..self.max_retries {
178                    debug!("Retry attempt {i}, sleeping for {retry_after} second(s)");
179                    sleep(Duration::from_secs(retry_after)).await;
180                    match send_request_retry(request).await? {
181                        RetryResponse::Ok(new_response) => {
182                            response = new_response;
183                            break;
184                        }
185                        RetryResponse::Retry {
186                            request: new_request,
187                            response: new_response,
188                            retry_after: new_retry_after,
189                        } => {
190                            request = new_request;
191                            response = new_response;
192                            retry_after = new_retry_after;
193                        }
194                    }
195                }
196                response
197            }
198        };
199        Ok(response.into_result()?)
200    }
201}
202
203enum RetryResponse<T> {
204    Ok(Response<T>),
205    Retry {
206        request: Box<HttpRequestBuilder>,
207        response: Response<T>,
208        retry_after: u64,
209    },
210}
211
212async fn send_request_retry<T>(request: Box<HttpRequestBuilder>) -> Result<RetryResponse<T>, ExecuteError>
213where
214    T: DeserializeOwned,
215{
216    Ok(match request.try_clone() {
217        Some(try_request) => {
218            let response = send_request(try_request).await?;
219            match response.retry_after() {
220                Some(retry_after) => RetryResponse::Retry {
221                    request,
222                    response,
223                    retry_after,
224                },
225                None => RetryResponse::Ok(response),
226            }
227        }
228        None => {
229            debug!("Could not clone builder, sending request without retry");
230            RetryResponse::Ok(send_request(*request).await?)
231        }
232    })
233}
234
235async fn send_request<T>(request: HttpRequestBuilder) -> Result<Response<T>, ExecuteError>
236where
237    T: DeserializeOwned,
238{
239    let response = request.send().await?;
240    Ok(response.json::<Response<T>>().await?)
241}
242
243impl fmt::Debug for Client {
244    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
245        f.debug_struct("Client")
246            .field("http_client", &self.http_client)
247            .field("host", &self.host)
248            .field("token", &format_args!("..."))
249            .finish()
250    }
251}
252
253/// Represents an API method that can be executed by the Telegram Bot API client.
254pub trait Method {
255    /// The type representing a successful result in an API response.
256    type Response;
257
258    /// Converts the method into a payload for an HTTP request.
259    fn into_payload(self) -> Payload;
260}
261
262/// Represents general errors that can occur while working with the Telegram Bot API client.
263#[derive(Debug)]
264pub enum ClientError {
265    /// An error indicating a failure to build an HTTP client.
266    BuildClient(HttpError),
267}
268
269impl Error for ClientError {
270    fn source(&self) -> Option<&(dyn Error + 'static)> {
271        Some(match self {
272            ClientError::BuildClient(err) => err,
273        })
274    }
275}
276
277impl fmt::Display for ClientError {
278    fn fmt(&self, out: &mut fmt::Formatter) -> fmt::Result {
279        match self {
280            ClientError::BuildClient(err) => write!(out, "can not build HTTP client: {err}"),
281        }
282    }
283}
284
285/// Represents errors that can occur while attempting
286/// to download a file using the Telegram Bot API client.
287#[derive(Debug)]
288pub enum DownloadFileError {
289    /// An error indicating a failure to send an HTTP request.
290    Http(HttpError),
291    /// An error received from the server in response to the download request.
292    Response {
293        /// The HTTP status code received in the response.
294        status: u16,
295        /// The body of the response as a string.
296        text: String,
297    },
298}
299
300impl From<HttpError> for DownloadFileError {
301    fn from(err: HttpError) -> Self {
302        Self::Http(err)
303    }
304}
305
306impl Error for DownloadFileError {
307    fn source(&self) -> Option<&(dyn Error + 'static)> {
308        match self {
309            DownloadFileError::Http(err) => Some(err),
310            _ => None,
311        }
312    }
313}
314
315impl fmt::Display for DownloadFileError {
316    fn fmt(&self, out: &mut fmt::Formatter) -> fmt::Result {
317        match self {
318            DownloadFileError::Http(err) => write!(out, "failed to download file: {err}"),
319            DownloadFileError::Response { status, text } => {
320                write!(out, "failed to download file: status={status} text={text}")
321            }
322        }
323    }
324}
325
326/// Represents errors that can occur during the execution
327/// of a method using the Telegram Bot API client.
328#[derive(Debug, derive_more::From)]
329pub enum ExecuteError {
330    /// An error indicating a failure to send an HTTP request.
331    Http(HttpError),
332    /// An error indicating a failure to build an HTTP request payload.
333    Payload(PayloadError),
334    /// An error received from the Telegram server in response to the execution request.
335    Response(ResponseError),
336}
337
338impl Error for ExecuteError {
339    fn source(&self) -> Option<&(dyn Error + 'static)> {
340        use self::ExecuteError::*;
341        Some(match self {
342            Http(err) => err,
343            Payload(err) => err,
344            Response(err) => err,
345        })
346    }
347}
348
349impl fmt::Display for ExecuteError {
350    fn fmt(&self, out: &mut fmt::Formatter) -> fmt::Result {
351        use self::ExecuteError::*;
352        write!(
353            out,
354            "failed to execute method: {}",
355            match self {
356                Http(err) => err.to_string(),
357                Payload(err) => err.to_string(),
358                Response(err) => err.to_string(),
359            }
360        )
361    }
362}
363
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn api() {
        let client = Client::new("token").unwrap();
        assert_eq!(client.token, "token");
        assert_eq!(client.host, DEFAULT_HOST);
        assert_eq!(client.max_retries, DEFAULT_MAX_RETRIES);

        let client = Client::new("token")
            .unwrap()
            .with_host("https://example.com")
            .with_max_retries(1);
        assert_eq!(client.token, "token");
        assert_eq!(client.host, "https://example.com");
        assert_eq!(client.max_retries, 1);
    }

    #[test]
    fn debug_hides_token() {
        // The token must never show up in debug output (and hence in logs).
        let client = Client::new("123456:secret").unwrap();
        let output = format!("{client:?}");
        assert!(!output.contains("123456:secret"));
    }
}