redis/
script.rs

1#![cfg(feature = "script")]
2use sha1_smol::Sha1;
3
4use crate::cmd::cmd;
5use crate::connection::ConnectionLike;
6use crate::types::{ErrorKind, FromRedisValue, RedisResult, ToRedisArgs};
7use crate::Cmd;
8
/// A Lua script that is run on a Redis server by its SHA1 via `EVALSHA`.
///
/// The source and its digest are stored together so the digest is computed
/// once, at construction, instead of on every invocation.
#[derive(Clone, Debug)]
pub struct Script {
    /// Lua source, sent to the server with `SCRIPT LOAD` when it is missing there.
    code: String,
    /// Hex-encoded SHA1 digest of `code`, used as the `EVALSHA` identifier.
    hash: String,
}
15
16/// The script object represents a lua script that can be executed on the
17/// redis server.  The object itself takes care of automatic uploading and
18/// execution.  The script object itself can be shared and is immutable.
19///
20/// Example:
21///
22/// ```rust,no_run
23/// # let client = redis::Client::open("redis://127.0.0.1/").unwrap();
24/// # let mut con = client.get_connection().unwrap();
25/// let script = redis::Script::new(r"
26///     return tonumber(ARGV[1]) + tonumber(ARGV[2]);
27/// ");
28/// let result = script.arg(1).arg(2).invoke(&mut con);
29/// assert_eq!(result, Ok(3));
30/// ```
31impl Script {
32    /// Creates a new script object.
33    pub fn new(code: &str) -> Script {
34        let mut hash = Sha1::new();
35        hash.update(code.as_bytes());
36        Script {
37            code: code.to_string(),
38            hash: hash.digest().to_string(),
39        }
40    }
41
42    /// Returns the script's SHA1 hash in hexadecimal format.
43    pub fn get_hash(&self) -> &str {
44        &self.hash
45    }
46
47    /// Creates a script invocation object with a key filled in.
48    #[inline]
49    pub fn key<T: ToRedisArgs>(&self, key: T) -> ScriptInvocation<'_> {
50        ScriptInvocation {
51            script: self,
52            args: vec![],
53            keys: key.to_redis_args(),
54        }
55    }
56
57    /// Creates a script invocation object with an argument filled in.
58    #[inline]
59    pub fn arg<T: ToRedisArgs>(&self, arg: T) -> ScriptInvocation<'_> {
60        ScriptInvocation {
61            script: self,
62            args: arg.to_redis_args(),
63            keys: vec![],
64        }
65    }
66
67    /// Returns an empty script invocation object.  This is primarily useful
68    /// for programmatically adding arguments and keys because the type will
69    /// not change.  Normally you can use `arg` and `key` directly.
70    #[inline]
71    pub fn prepare_invoke(&self) -> ScriptInvocation<'_> {
72        ScriptInvocation {
73            script: self,
74            args: vec![],
75            keys: vec![],
76        }
77    }
78
79    /// Invokes the script directly without arguments.
80    #[inline]
81    pub fn invoke<T: FromRedisValue>(&self, con: &mut dyn ConnectionLike) -> RedisResult<T> {
82        ScriptInvocation {
83            script: self,
84            args: vec![],
85            keys: vec![],
86        }
87        .invoke(con)
88    }
89
90    /// Asynchronously invokes the script without arguments.
91    #[inline]
92    #[cfg(feature = "aio")]
93    pub async fn invoke_async<C, T>(&self, con: &mut C) -> RedisResult<T>
94    where
95        C: crate::aio::ConnectionLike,
96        T: FromRedisValue,
97    {
98        ScriptInvocation {
99            script: self,
100            args: vec![],
101            keys: vec![],
102        }
103        .invoke_async(con)
104        .await
105    }
106}
107
108/// Represents a prepared script call.
109pub struct ScriptInvocation<'a> {
110    script: &'a Script,
111    args: Vec<Vec<u8>>,
112    keys: Vec<Vec<u8>>,
113}
114
115/// This type collects keys and other arguments for the script so that it
116/// can be then invoked.  While the `Script` type itself holds the script,
117/// the `ScriptInvocation` holds the arguments that should be invoked until
118/// it's sent to the server.
119impl<'a> ScriptInvocation<'a> {
120    /// Adds a regular argument to the invocation.  This ends up as `ARGV[i]`
121    /// in the script.
122    #[inline]
123    pub fn arg<'b, T: ToRedisArgs>(&'b mut self, arg: T) -> &'b mut ScriptInvocation<'a>
124    where
125        'a: 'b,
126    {
127        arg.write_redis_args(&mut self.args);
128        self
129    }
130
131    /// Adds a key argument to the invocation.  This ends up as `KEYS[i]`
132    /// in the script.
133    #[inline]
134    pub fn key<'b, T: ToRedisArgs>(&'b mut self, key: T) -> &'b mut ScriptInvocation<'a>
135    where
136        'a: 'b,
137    {
138        key.write_redis_args(&mut self.keys);
139        self
140    }
141
142    /// Invokes the script and returns the result.
143    #[inline]
144    pub fn invoke<T: FromRedisValue>(&self, con: &mut dyn ConnectionLike) -> RedisResult<T> {
145        let eval_cmd = self.eval_cmd();
146        match eval_cmd.query(con) {
147            Ok(val) => Ok(val),
148            Err(err) => {
149                if err.kind() == ErrorKind::NoScriptError {
150                    self.load_cmd().exec(con)?;
151                    eval_cmd.query(con)
152                } else {
153                    Err(err)
154                }
155            }
156        }
157    }
158
159    /// Asynchronously invokes the script and returns the result.
160    #[inline]
161    #[cfg(feature = "aio")]
162    pub async fn invoke_async<T: FromRedisValue>(
163        &self,
164        con: &mut impl crate::aio::ConnectionLike,
165    ) -> RedisResult<T> {
166        let eval_cmd = self.eval_cmd();
167        match eval_cmd.query_async(con).await {
168            Ok(val) => {
169                // Return the value from the script evaluation
170                Ok(val)
171            }
172            Err(err) => {
173                // Load the script into Redis if the script hash wasn't there already
174                if err.kind() == ErrorKind::NoScriptError {
175                    self.load_cmd().exec_async(con).await?;
176                    eval_cmd.query_async(con).await
177                } else {
178                    Err(err)
179                }
180            }
181        }
182    }
183
184    /// Loads the script and returns the SHA1 of it.
185    #[inline]
186    pub fn load(&self, con: &mut dyn ConnectionLike) -> RedisResult<String> {
187        let hash: String = self.load_cmd().query(con)?;
188
189        debug_assert_eq!(hash, self.script.hash);
190
191        Ok(hash)
192    }
193
194    /// Asynchronously loads the script and returns the SHA1 of it.
195    #[inline]
196    #[cfg(feature = "aio")]
197    pub async fn load_async<C>(&self, con: &mut C) -> RedisResult<String>
198    where
199        C: crate::aio::ConnectionLike,
200    {
201        let hash: String = self.load_cmd().query_async(con).await?;
202
203        debug_assert_eq!(hash, self.script.hash);
204
205        Ok(hash)
206    }
207
208    /// Returns a command to load the script.
209    fn load_cmd(&self) -> Cmd {
210        let mut cmd = cmd("SCRIPT");
211        cmd.arg("LOAD").arg(self.script.code.as_bytes());
212        cmd
213    }
214
215    fn estimate_buflen(&self) -> usize {
216        self
217            .keys
218            .iter()
219            .chain(self.args.iter())
220            .fold(0, |acc, e| acc + e.len())
221            + 7 /* "EVALSHA".len() */
222            + self.script.hash.len()
223            + 4 /* Slots reserved for the length of keys. */
224    }
225
226    /// Returns a command to evaluate the command.
227    pub(crate) fn eval_cmd(&self) -> Cmd {
228        let args_len = 3 + self.keys.len() + self.args.len();
229        let mut cmd = Cmd::with_capacity(args_len, self.estimate_buflen());
230        cmd.arg("EVALSHA")
231            .arg(self.script.hash.as_bytes())
232            .arg(self.keys.len())
233            .arg(&*self.keys)
234            .arg(&*self.args);
235        cmd
236    }
237}
238
#[cfg(test)]
mod tests {
    use super::Script;

    #[test]
    fn script_eval_should_work() {
        // One key, no arguments: EVALSHA <sha1> 1 dummy
        let script = Script::new("return KEYS[1]");
        let invocation = script.key("dummy");

        // The buffer-size estimate must not be smaller than the capacity the
        // built command reports.
        let buflen_hint = invocation.estimate_buflen();
        let command = invocation.eval_cmd();
        assert!(buflen_hint >= command.capacity().1);

        let packed = command.get_packed_command();
        let wire = std::str::from_utf8(packed.as_slice()).unwrap();
        assert_eq!(
            "*4\r\n$7\r\nEVALSHA\r\n$40\r\n4a2267357833227dd98abdedb8cf24b15a986445\r\n$1\r\n1\r\n$5\r\ndummy\r\n",
            wire
        );
    }
}
256}