1#![cfg(feature = "script")]
2use sha1_smol::Sha1;
3
4use crate::cmd::cmd;
5use crate::connection::ConnectionLike;
6use crate::types::{ErrorKind, FromRedisValue, RedisResult, ToRedisArgs};
7use crate::Cmd;
8
/// A Lua script that can be run on a Redis server.
///
/// Holds the script source together with its precomputed SHA1 digest, so that
/// invocations can be sent by digest and the source only needs to be uploaded
/// when the server does not already know it.
#[derive(Clone, Debug)]
pub struct Script {
    /// Lua source text as given to [`Script::new`].
    code: String,
    /// Hex-encoded SHA1 digest of `code`.
    hash: String,
}
15
16impl Script {
32 pub fn new(code: &str) -> Script {
34 let mut hash = Sha1::new();
35 hash.update(code.as_bytes());
36 Script {
37 code: code.to_string(),
38 hash: hash.digest().to_string(),
39 }
40 }
41
42 pub fn get_hash(&self) -> &str {
44 &self.hash
45 }
46
47 #[inline]
49 pub fn key<T: ToRedisArgs>(&self, key: T) -> ScriptInvocation<'_> {
50 ScriptInvocation {
51 script: self,
52 args: vec![],
53 keys: key.to_redis_args(),
54 }
55 }
56
57 #[inline]
59 pub fn arg<T: ToRedisArgs>(&self, arg: T) -> ScriptInvocation<'_> {
60 ScriptInvocation {
61 script: self,
62 args: arg.to_redis_args(),
63 keys: vec![],
64 }
65 }
66
67 #[inline]
71 pub fn prepare_invoke(&self) -> ScriptInvocation<'_> {
72 ScriptInvocation {
73 script: self,
74 args: vec![],
75 keys: vec![],
76 }
77 }
78
79 #[inline]
81 pub fn invoke<T: FromRedisValue>(&self, con: &mut dyn ConnectionLike) -> RedisResult<T> {
82 ScriptInvocation {
83 script: self,
84 args: vec![],
85 keys: vec![],
86 }
87 .invoke(con)
88 }
89
90 #[inline]
92 #[cfg(feature = "aio")]
93 pub async fn invoke_async<C, T>(&self, con: &mut C) -> RedisResult<T>
94 where
95 C: crate::aio::ConnectionLike,
96 T: FromRedisValue,
97 {
98 ScriptInvocation {
99 script: self,
100 args: vec![],
101 keys: vec![],
102 }
103 .invoke_async(con)
104 .await
105 }
106}
107
108pub struct ScriptInvocation<'a> {
110 script: &'a Script,
111 args: Vec<Vec<u8>>,
112 keys: Vec<Vec<u8>>,
113}
114
115impl<'a> ScriptInvocation<'a> {
120 #[inline]
123 pub fn arg<'b, T: ToRedisArgs>(&'b mut self, arg: T) -> &'b mut ScriptInvocation<'a>
124 where
125 'a: 'b,
126 {
127 arg.write_redis_args(&mut self.args);
128 self
129 }
130
131 #[inline]
134 pub fn key<'b, T: ToRedisArgs>(&'b mut self, key: T) -> &'b mut ScriptInvocation<'a>
135 where
136 'a: 'b,
137 {
138 key.write_redis_args(&mut self.keys);
139 self
140 }
141
142 #[inline]
144 pub fn invoke<T: FromRedisValue>(&self, con: &mut dyn ConnectionLike) -> RedisResult<T> {
145 let eval_cmd = self.eval_cmd();
146 match eval_cmd.query(con) {
147 Ok(val) => Ok(val),
148 Err(err) => {
149 if err.kind() == ErrorKind::NoScriptError {
150 self.load_cmd().exec(con)?;
151 eval_cmd.query(con)
152 } else {
153 Err(err)
154 }
155 }
156 }
157 }
158
159 #[inline]
161 #[cfg(feature = "aio")]
162 pub async fn invoke_async<T: FromRedisValue>(
163 &self,
164 con: &mut impl crate::aio::ConnectionLike,
165 ) -> RedisResult<T> {
166 let eval_cmd = self.eval_cmd();
167 match eval_cmd.query_async(con).await {
168 Ok(val) => {
169 Ok(val)
171 }
172 Err(err) => {
173 if err.kind() == ErrorKind::NoScriptError {
175 self.load_cmd().exec_async(con).await?;
176 eval_cmd.query_async(con).await
177 } else {
178 Err(err)
179 }
180 }
181 }
182 }
183
184 #[inline]
186 pub fn load(&self, con: &mut dyn ConnectionLike) -> RedisResult<String> {
187 let hash: String = self.load_cmd().query(con)?;
188
189 debug_assert_eq!(hash, self.script.hash);
190
191 Ok(hash)
192 }
193
194 #[inline]
196 #[cfg(feature = "aio")]
197 pub async fn load_async<C>(&self, con: &mut C) -> RedisResult<String>
198 where
199 C: crate::aio::ConnectionLike,
200 {
201 let hash: String = self.load_cmd().query_async(con).await?;
202
203 debug_assert_eq!(hash, self.script.hash);
204
205 Ok(hash)
206 }
207
208 fn load_cmd(&self) -> Cmd {
210 let mut cmd = cmd("SCRIPT");
211 cmd.arg("LOAD").arg(self.script.code.as_bytes());
212 cmd
213 }
214
215 fn estimate_buflen(&self) -> usize {
216 self
217 .keys
218 .iter()
219 .chain(self.args.iter())
220 .fold(0, |acc, e| acc + e.len())
221 + 7 + self.script.hash.len()
223 + 4 }
225
226 pub(crate) fn eval_cmd(&self) -> Cmd {
228 let args_len = 3 + self.keys.len() + self.args.len();
229 let mut cmd = Cmd::with_capacity(args_len, self.estimate_buflen());
230 cmd.arg("EVALSHA")
231 .arg(self.script.hash.as_bytes())
232 .arg(self.keys.len())
233 .arg(&*self.keys)
234 .arg(&*self.args);
235 cmd
236 }
237}
238
#[cfg(test)]
mod tests {
    use super::Script;

    /// A single key and no arguments produce a four-element `EVALSHA` command,
    /// and the buffer estimate covers the data actually written.
    #[test]
    fn script_eval_should_work() {
        let script = Script::new("return KEYS[1]");
        let invocation = script.key("dummy");
        let estimated_buflen = invocation.estimate_buflen();
        let cmd = invocation.eval_cmd();
        assert!(estimated_buflen >= cmd.capacity().1);
        let expected = "*4\r\n$7\r\nEVALSHA\r\n$40\r\n4a2267357833227dd98abdedb8cf24b15a986445\r\n$1\r\n1\r\n$5\r\ndummy\r\n";
        assert_eq!(
            expected,
            std::str::from_utf8(cmd.get_packed_command().as_slice()).unwrap()
        );
    }

    /// Keys and arguments added in interleaved order are still emitted as
    /// `<numkeys> <keys...> <args...>`, each group keeping its insertion order.
    #[test]
    fn script_eval_places_keys_before_args() {
        let script = Script::new("return ARGV[1]");
        let mut invocation = script.arg("a");
        invocation.key("k1").arg("b").key("k2");
        let estimated_buflen = invocation.estimate_buflen();
        let cmd = invocation.eval_cmd();
        assert!(estimated_buflen >= cmd.capacity().1);
        let expected = format!(
            "*7\r\n$7\r\nEVALSHA\r\n$40\r\n{}\r\n$1\r\n2\r\n$2\r\nk1\r\n$2\r\nk2\r\n$1\r\na\r\n$1\r\nb\r\n",
            script.get_hash()
        );
        assert_eq!(
            expected,
            std::str::from_utf8(cmd.get_packed_command().as_slice()).unwrap()
        );
    }
}
256}