1#![cfg(feature = "script")]
2use sha1_smol::Sha1;
3
4use crate::cmd::cmd;
5use crate::connection::ConnectionLike;
6use crate::types::{ErrorKind, FromRedisValue, RedisResult, ToRedisArgs};
7use crate::Cmd;
8
/// A Lua script that can be executed on a Redis server.
///
/// The SHA1 hash of the source is computed once, when the script is
/// constructed, so invocations can go through `EVALSHA` and only fall back to
/// uploading the source (`SCRIPT LOAD`) when the server does not know it.
#[derive(Clone, Debug)]
pub struct Script {
    /// The Lua source code, sent verbatim by `SCRIPT LOAD`.
    code: String,
    /// Hex-encoded SHA1 digest of `code`, used as the `EVALSHA` identifier.
    hash: String,
}
15
16impl Script {
32 pub fn new(code: &str) -> Script {
34 let mut hash = Sha1::new();
35 hash.update(code.as_bytes());
36 Script {
37 code: code.to_string(),
38 hash: hash.digest().to_string(),
39 }
40 }
41
42 pub fn get_hash(&self) -> &str {
44 &self.hash
45 }
46
47 pub(crate) fn load_cmd(&self) -> Cmd {
49 let mut cmd = cmd("SCRIPT");
50 cmd.arg("LOAD").arg(self.code.as_bytes());
51 cmd
52 }
53
54 #[inline]
56 pub fn load(&self, con: &mut dyn ConnectionLike) -> RedisResult<String> {
57 let hash: String = self.load_cmd().query(con)?;
58
59 debug_assert_eq!(hash, self.hash);
60
61 Ok(hash)
62 }
63
64 #[inline]
66 #[cfg(feature = "aio")]
67 pub async fn load_async<C>(&self, con: &mut C) -> RedisResult<String>
68 where
69 C: crate::aio::ConnectionLike,
70 {
71 let hash: String = self.load_cmd().query_async(con).await?;
72
73 debug_assert_eq!(hash, self.hash);
74
75 Ok(hash)
76 }
77
78 #[inline]
80 pub fn key<T: ToRedisArgs>(&self, key: T) -> ScriptInvocation<'_> {
81 ScriptInvocation {
82 script: self,
83 args: vec![],
84 keys: key.to_redis_args(),
85 }
86 }
87
88 #[inline]
90 pub fn arg<T: ToRedisArgs>(&self, arg: T) -> ScriptInvocation<'_> {
91 ScriptInvocation {
92 script: self,
93 args: arg.to_redis_args(),
94 keys: vec![],
95 }
96 }
97
98 #[inline]
102 pub fn prepare_invoke(&self) -> ScriptInvocation<'_> {
103 ScriptInvocation {
104 script: self,
105 args: vec![],
106 keys: vec![],
107 }
108 }
109
110 #[inline]
112 pub fn invoke<T: FromRedisValue>(&self, con: &mut dyn ConnectionLike) -> RedisResult<T> {
113 ScriptInvocation {
114 script: self,
115 args: vec![],
116 keys: vec![],
117 }
118 .invoke(con)
119 }
120
121 #[inline]
123 #[cfg(feature = "aio")]
124 pub async fn invoke_async<C, T>(&self, con: &mut C) -> RedisResult<T>
125 where
126 C: crate::aio::ConnectionLike,
127 T: FromRedisValue,
128 {
129 ScriptInvocation {
130 script: self,
131 args: vec![],
132 keys: vec![],
133 }
134 .invoke_async(con)
135 .await
136 }
137}
138
139pub struct ScriptInvocation<'a> {
141 script: &'a Script,
142 args: Vec<Vec<u8>>,
143 keys: Vec<Vec<u8>>,
144}
145
146impl<'a> ScriptInvocation<'a> {
151 #[inline]
154 pub fn arg<'b, T: ToRedisArgs>(&'b mut self, arg: T) -> &'b mut ScriptInvocation<'a>
155 where
156 'a: 'b,
157 {
158 arg.write_redis_args(&mut self.args);
159 self
160 }
161
162 #[inline]
165 pub fn key<'b, T: ToRedisArgs>(&'b mut self, key: T) -> &'b mut ScriptInvocation<'a>
166 where
167 'a: 'b,
168 {
169 key.write_redis_args(&mut self.keys);
170 self
171 }
172
173 #[inline]
175 pub fn invoke<T: FromRedisValue>(&self, con: &mut dyn ConnectionLike) -> RedisResult<T> {
176 let eval_cmd = self.eval_cmd();
177 match eval_cmd.query(con) {
178 Ok(val) => Ok(val),
179 Err(err) => {
180 if err.kind() == ErrorKind::NoScriptError {
181 self.load(con)?;
182 eval_cmd.query(con)
183 } else {
184 Err(err)
185 }
186 }
187 }
188 }
189
190 #[inline]
192 #[cfg(feature = "aio")]
193 pub async fn invoke_async<T: FromRedisValue>(
194 &self,
195 con: &mut impl crate::aio::ConnectionLike,
196 ) -> RedisResult<T> {
197 let eval_cmd = self.eval_cmd();
198 match eval_cmd.query_async(con).await {
199 Ok(val) => {
200 Ok(val)
202 }
203 Err(err) => {
204 if err.kind() == ErrorKind::NoScriptError {
206 self.load_async(con).await?;
207 eval_cmd.query_async(con).await
208 } else {
209 Err(err)
210 }
211 }
212 }
213 }
214
215 #[inline]
217 pub fn load(&self, con: &mut dyn ConnectionLike) -> RedisResult<String> {
218 self.script.load(con)
219 }
220
221 #[inline]
223 #[cfg(feature = "aio")]
224 pub async fn load_async<C>(&self, con: &mut C) -> RedisResult<String>
225 where
226 C: crate::aio::ConnectionLike,
227 {
228 self.script.load_async(con).await
229 }
230
231 fn estimate_buflen(&self) -> usize {
232 self
233 .keys
234 .iter()
235 .chain(self.args.iter())
236 .fold(0, |acc, e| acc + e.len())
237 + 7 + self.script.hash.len()
239 + 4 }
241
242 pub(crate) fn eval_cmd(&self) -> Cmd {
244 let args_len = 3 + self.keys.len() + self.args.len();
245 let mut cmd = Cmd::with_capacity(args_len, self.estimate_buflen());
246 cmd.arg("EVALSHA")
247 .arg(self.script.hash.as_bytes())
248 .arg(self.keys.len())
249 .arg(&*self.keys)
250 .arg(&*self.args);
251 cmd
252 }
253}
254
#[cfg(test)]
mod tests {
    use super::Script;

    /// A single-key `EVALSHA` invocation must fit in its pre-sized buffer
    /// and pack to the exact expected RESP bytes.
    #[test]
    fn script_eval_should_work() {
        let expected = "*4\r\n$7\r\nEVALSHA\r\n$40\r\n4a2267357833227dd98abdedb8cf24b15a986445\r\n$1\r\n1\r\n$5\r\ndummy\r\n";

        let lua = Script::new("return KEYS[1]");
        let call = lua.key("dummy");

        let buf_estimate = call.estimate_buflen();
        let packed_cmd = call.eval_cmd();
        let buf_capacity = packed_cmd.capacity().1;
        assert!(buf_estimate >= buf_capacity);

        let wire_bytes = packed_cmd.get_packed_command();
        let wire_text = std::str::from_utf8(wire_bytes.as_slice()).unwrap();
        assert_eq!(expected, wire_text);
    }
}