carapax/ratelimit/predicate/direct/mod.rs

1use std::sync::Arc;
2
3pub use governor::{Jitter, Quota};
4use governor::{
5    RateLimiter,
6    clock::DefaultClock,
7    middleware::NoOpMiddleware,
8    state::{InMemoryState, NotKeyed},
9};
10pub use nonzero_ext::nonzero;
11
12use crate::{
13    core::{Handler, PredicateResult},
14    ratelimit::{
15        jitter::NoJitter,
16        method::{MethodDiscard, MethodWait},
17    },
18};
19
20#[cfg(test)]
21mod tests;
22
23/// A predicate with a direct rate limiter.
24///
25/// Use this predicate when you need to limit all updates.
26#[derive(Clone)]
27pub struct DirectRateLimitPredicate<J, M> {
28    limiter: Arc<RateLimiter<NotKeyed, InMemoryState, DefaultClock, NoOpMiddleware>>,
29    jitter: J,
30    _method: M,
31}
32
33impl DirectRateLimitPredicate<NoJitter, MethodDiscard> {
34    /// Creates a new `DirectRateLimitPredicate` with the discard method.
35    ///
36    /// The predicate will stop update propagation when the rate limit is reached.
37    ///
38    /// # Arguments
39    ///
40    /// * `quota` - A rate limiting quota.
41    pub fn discard(quota: Quota) -> Self {
42        Self {
43            limiter: Arc::new(RateLimiter::direct(quota)),
44            jitter: NoJitter,
45            _method: MethodDiscard,
46        }
47    }
48}
49
50impl DirectRateLimitPredicate<NoJitter, MethodWait> {
51    /// Creates a new `DirectRateLimitPredicate` with the wait method.
52    ///
53    /// The predicate will pause update propagation when the rate limit is reached.
54    ///
55    /// # Arguments
56    ///
57    /// * `quota` - A rate limiting quota.
58    pub fn wait(quota: Quota) -> Self {
59        Self {
60            limiter: Arc::new(RateLimiter::direct(quota)),
61            jitter: NoJitter,
62            _method: MethodWait,
63        }
64    }
65}
66
67impl DirectRateLimitPredicate<Jitter, MethodWait> {
68    /// Creates a new `DirectRateLimitPredicate` with the wait method and jitter.
69    ///
70    /// Predicate will pause update propagation when the rate limit is reached.
71    ///
72    /// # Arguments
73    ///
74    /// * `quota` - A rate limiting quota.
75    /// * `jitter` - An interval specification for deviating from the nominal wait time.
76    pub fn wait_with_jitter(quota: Quota, jitter: Jitter) -> Self {
77        Self {
78            limiter: Arc::new(RateLimiter::direct(quota)),
79            jitter,
80            _method: MethodWait,
81        }
82    }
83}
84
85impl Handler<()> for DirectRateLimitPredicate<NoJitter, MethodDiscard> {
86    type Output = PredicateResult;
87
88    async fn handle(&self, (): ()) -> Self::Output {
89        match self.limiter.check() {
90            Ok(_) => PredicateResult::True,
91            Err(_) => {
92                log::info!("DirectRateLimitPredicate: update discarded");
93                PredicateResult::False
94            }
95        }
96    }
97}
98
99impl Handler<()> for DirectRateLimitPredicate<NoJitter, MethodWait> {
100    type Output = PredicateResult;
101
102    async fn handle(&self, (): ()) -> Self::Output {
103        self.limiter.until_ready().await;
104        PredicateResult::True
105    }
106}
107
108impl Handler<()> for DirectRateLimitPredicate<Jitter, MethodWait> {
109    type Output = PredicateResult;
110
111    async fn handle(&self, (): ()) -> Self::Output {
112        self.limiter.until_ready_with_jitter(self.jitter).await;
113        PredicateResult::True
114    }
115}