1use crate::errors::ParsingError;
4use crate::types::{FromRedisValue, RedisWrite, ToRedisArgs, Value};
5
// Builds a `ParsingError` whose message pairs the `Debug` rendering of a
// human-readable description (`$det`) with the `Debug` rendering of the value
// that could not be converted (`$v`). Both arguments are only borrowed.
macro_rules! not_convertible_error {
    ($v:expr, $det:expr) => {
        ParsingError::from(format!(
            "{detail:?} (response was {response:?})",
            detail = $det,
            response = $v
        ))
    };
}
11
/// A single rule of the Redis ACL rule language, as accepted by `ACL SETUSER`.
///
/// Every variant serialises (via [`ToRedisArgs`]) to exactly one command
/// argument; the textual form is given in each variant's documentation.
#[derive(Clone, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum Rule {
    /// Enables the user (`on`).
    On,
    /// Disables the user (`off`).
    Off,

    /// Allows a single command (`+<command>`).
    AddCommand(String),
    /// Disallows a single command (`-<command>`).
    RemoveCommand(String),
    /// Allows every command of a category (`+@<category>`).
    AddCategory(String),
    /// Disallows every command of a category (`-@<category>`).
    RemoveCategory(String),
    /// Allows all commands (`allcommands`).
    AllCommands,
    /// Disallows all commands (`nocommands`).
    NoCommands,

    /// Adds a clear-text password (`><password>`).
    AddPass(String),
    /// Removes a clear-text password (`<<password>`).
    RemovePass(String),
    /// Adds a password given as its hash (`#<hash>`).
    AddHashedPass(String),
    /// Removes a password given as its hash (`!<hash>`).
    RemoveHashedPass(String),
    /// Marks the user as not requiring a password (`nopass`).
    NoPass,
    /// Flushes the user's passwords (`resetpass`).
    ResetPass,

    /// Grants access to keys matching a glob pattern (`~<pattern>`).
    Pattern(String),
    /// Grants access to all keys (`allkeys`).
    AllKeys,
    /// Flushes the user's key patterns (`resetkeys`).
    ResetKeys,

    /// Grants access to Pub/Sub channels matching a pattern (`&<pattern>`).
    Channel(String),
    /// Flushes the user's channel patterns (`resetchannels`).
    ResetChannels,
    /// A selector grouping several rules, written as `(<rule> <rule> ...)`.
    Selector(Vec<Rule>),

    /// Resets the user to its initial state (`reset`).
    Reset,

    /// Any other rule, written verbatim.
    Other(String),
}
76
77impl ToRedisArgs for Rule {
78 fn write_redis_args<W>(&self, out: &mut W)
79 where
80 W: ?Sized + RedisWrite,
81 {
82 use self::Rule::*;
83
84 match self {
85 On => out.write_arg(b"on"),
86 Off => out.write_arg(b"off"),
87
88 AddCommand(cmd) => out.write_arg_fmt(format_args!("+{cmd}")),
89 RemoveCommand(cmd) => out.write_arg_fmt(format_args!("-{cmd}")),
90 AddCategory(cat) => out.write_arg_fmt(format_args!("+@{cat}")),
91 RemoveCategory(cat) => out.write_arg_fmt(format_args!("-@{cat}")),
92 AllCommands => out.write_arg(b"allcommands"),
93 NoCommands => out.write_arg(b"nocommands"),
94
95 AddPass(pass) => out.write_arg_fmt(format_args!(">{pass}")),
96 RemovePass(pass) => out.write_arg_fmt(format_args!("<{pass}")),
97 AddHashedPass(pass) => out.write_arg_fmt(format_args!("#{pass}")),
98 RemoveHashedPass(pass) => out.write_arg_fmt(format_args!("!{pass}")),
99 NoPass => out.write_arg(b"nopass"),
100 ResetPass => out.write_arg(b"resetpass"),
101
102 Pattern(pat) => out.write_arg_fmt(format_args!("~{pat}")),
103 AllKeys => out.write_arg(b"allkeys"),
104 ResetKeys => out.write_arg(b"resetkeys"),
105 Channel(pat) => out.write_arg_fmt(format_args!("&{pat}")),
106 ResetChannels => out.write_arg(b"resetchannels"),
107 Selector(sel) => out.write_arg_fmt(format_args!(
108 "({})",
109 sel.iter()
110 .flat_map(|r| r
111 .to_redis_args()
112 .into_iter()
113 .map(|x| String::from_utf8_lossy(&x).to_string()))
114 .collect::<Vec<String>>()
115 .join(" ")
116 )),
117 Reset => out.write_arg(b"reset"),
118
119 Other(rule) => out.write_arg(rule.as_bytes()),
120 };
121 }
122}
123
/// Parsed description of an ACL user, built from a reply made of the fields
/// `flags`, `passwords`, `commands`, `keys`, `channels` and `selectors`.
///
/// Fields that are absent from the reply keep their `Default` value (an empty
/// `Vec`).
#[derive(Debug, Default, Eq, PartialEq)]
pub struct AclInfo {
    /// Rules parsed from the `flags` field. `on`, `off`, `allkeys`,
    /// `allcommands` and `nopass` map to their dedicated [`Rule`] variants;
    /// other flags are kept as [`Rule::Other`].
    pub flags: Vec<Rule>,
    /// One [`Rule::AddHashedPass`] per entry of the `passwords` field.
    pub passwords: Vec<Rule>,
    /// Command and category additions/removals parsed from the `commands`
    /// string.
    pub commands: Vec<Rule>,
    /// Key patterns as [`Rule::Pattern`], with `*` reported as
    /// [`Rule::AllKeys`].
    pub keys: Vec<Rule>,
    /// Pub/Sub channel patterns as [`Rule::Channel`], without a leading `&`.
    pub channels: Vec<Rule>,
    /// The rules of every selector (flags, commands, channels, then keys of
    /// each one), concatenated into a single list.
    ///
    /// NOTE(review): the boundaries between individual selectors are not
    /// preserved in this representation.
    pub selectors: Vec<Rule>,
}
167impl AclInfo {
168 fn handle_pair(&mut self, name: &Value, value: &Value) -> Result<(), ParsingError> {
169 let key = match name {
171 Value::BulkString(bs) => {
172 let mut s = std::str::from_utf8(bs)?.trim().to_owned();
174 if s.len() >= 2 && s.starts_with('"') && s.ends_with('"') {
175 s = s[1..s.len() - 1].to_owned();
176 }
177 s
178 }
179 _ => {
180 return Err(not_convertible_error!(
181 name,
182 "Expect a bulk string key name"
183 ));
184 }
185 };
186 match key.as_str() {
187 "flags" => {
188 let f = value
189 .as_sequence()
190 .ok_or_else(|| {
191 not_convertible_error!(value, "Expect an array response of ACL flags")
192 })?
193 .iter()
194 .map(|flag| match flag {
195 Value::BulkString(flag) => match flag.as_slice() {
196 b"on" => Ok(Rule::On),
197 b"off" => Ok(Rule::Off),
198 b"allkeys" => Ok(Rule::AllKeys),
199 b"allcommands" => Ok(Rule::AllCommands),
200 b"nopass" => Ok(Rule::NoPass),
201 other => Ok(Rule::Other(String::from_utf8_lossy(other).into_owned())),
202 },
203 _ => Err(not_convertible_error!(
204 flag,
205 "Expect an arbitrary binary data"
206 )),
207 })
208 .collect::<Result<_, _>>()?;
209 self.flags = f;
210 }
211 "passwords" => {
212 let p = value
213 .as_sequence()
214 .ok_or_else(|| {
215 not_convertible_error!(value, "Expect an array response of ACL passwords")
216 })?
217 .iter()
218 .map(|pass| {
219 let s = String::from_redis_value_ref(pass)?;
220 Ok(Rule::AddHashedPass(s))
221 })
222 .collect::<Result<_, ParsingError>>()?;
223 self.passwords = p;
224 }
225 "commands" => {
226 let cmds = match value {
227 Value::BulkString(cmd) => std::str::from_utf8(cmd)?,
228 _ => {
229 return Err(not_convertible_error!(
230 value,
231 "Expect a valid UTF8 string for commands"
232 ));
233 }
234 }
235 .split_terminator(' ')
236 .map(|cmd| match cmd {
237 x if x.starts_with("+@") => Ok(Rule::AddCategory(x[2..].to_owned())),
238 x if x.starts_with("-@") => Ok(Rule::RemoveCategory(x[2..].to_owned())),
239 x if x.starts_with('+') => Ok(Rule::AddCommand(x[1..].to_owned())),
240 x if x.starts_with('-') => Ok(Rule::RemoveCommand(x[1..].to_owned())),
241 _ => Err(not_convertible_error!(
242 cmd,
243 "Expect a command addition/removal"
244 )),
245 })
246 .collect::<Result<_, _>>()?;
247 self.commands = cmds;
248 }
249 "keys" => {
250 let parsed = match value {
251 Value::Array(arr) => arr
252 .iter()
253 .map(|pat| {
254 let s = String::from_redis_value_ref(pat)?;
255 match s.as_str() {
256 "*" => Ok(Rule::AllKeys),
257 _ => Ok(Rule::Pattern(s)),
258 }
259 })
260 .collect::<Result<_, ParsingError>>()?,
261 Value::BulkString(bs) => {
262 let mut s = std::str::from_utf8(bs)?;
263 s = s.trim();
264 if s.len() >= 2 && s.starts_with('"') && s.ends_with('"') {
265 s = &s[1..s.len() - 1];
266 }
267 s.split_whitespace()
268 .map(|tok| {
269 let tok = if let Some(tok) = tok.strip_prefix('~') {
270 tok
271 } else {
272 tok
273 };
274 match tok {
275 "*" => Ok(Rule::AllKeys),
276 _ => Ok(Rule::Pattern(tok.to_owned())),
277 }
278 })
279 .collect::<Result<_, ParsingError>>()?
280 }
281 other => {
282 return Err(not_convertible_error!(
283 other,
284 "Expect an array or bulk-string of keys"
285 ));
286 }
287 };
288 self.keys = parsed;
289 }
290 "channels" => {
291 let parsed = match value {
292 Value::Array(arr) | Value::Set(arr) => arr
293 .iter()
294 .map(|pat| {
295 let s = String::from_redis_value_ref(pat)?;
296 let s = if let Some(s) = s.strip_prefix('&') {
297 s.to_owned()
298 } else {
299 s
300 };
301 Ok(Rule::Channel(s))
302 })
303 .collect::<Result<_, ParsingError>>()?,
304 Value::BulkString(bs) => {
305 let mut s = std::str::from_utf8(bs)?;
306 s = s.trim();
307 if s.len() >= 2 && s.starts_with('"') && s.ends_with('"') {
308 s = &s[1..s.len() - 1];
309 }
310 s.split_whitespace()
311 .map(|tok| {
312 let tok = if let Some(tok) = tok.strip_prefix('&') {
313 tok
314 } else {
315 tok
316 };
317 Ok(Rule::Channel(tok.to_owned()))
318 })
319 .collect::<Result<_, ParsingError>>()?
320 }
321 other => {
322 return Err(not_convertible_error!(
323 other,
324 "Expect an array or bulk-string of channels"
325 ));
326 }
327 };
328 self.channels = parsed;
329 }
330 "selectors" => {
331 let parsed = match value {
332 Value::Array(arr) | Value::Set(arr) => arr
336 .iter()
337 .map(|pat| {
338 let acl: AclInfo = FromRedisValue::from_redis_value_ref(pat)?;
339 let selector = acl
340 .flags
341 .into_iter()
342 .chain(acl.commands)
343 .chain(acl.channels)
344 .chain(acl.keys)
345 .collect();
346 Ok(selector)
347 })
348 .collect::<Result<Vec<Vec<Rule>>, ParsingError>>()?,
349 other => {
350 return Err(not_convertible_error!(
351 other,
352 "Expect an array or bulk-string of selectors"
353 ));
354 }
355 };
356 self.selectors = parsed.into_iter().flatten().collect();
357 }
358 _ => {}
359 }
360 Ok(())
361 }
362}
363impl FromRedisValue for AclInfo {
364 fn from_redis_value(v: Value) -> Result<Self, ParsingError> {
365 let mut acl = AclInfo::default();
366 if let Some(map_iter) = v.as_map_iter() {
369 for (name, value) in map_iter {
370 acl.handle_pair(name, value)?;
371 }
372 } else if let Some(seq) = v.as_sequence() {
373 if seq.len() % 2 != 0 {
375 return Err(not_convertible_error!(v, ""));
376 }
377 for chunk in seq.chunks(2) {
378 let name = &chunk[0];
379 let value = &chunk[1];
380 acl.handle_pair(name, value)?;
381 }
382 } else {
383 return Err(not_convertible_error!(v, ""));
384 }
385 Ok(acl)
386 }
387}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `$rule` serialises to exactly one argument equal to `$arg`.
    macro_rules! assert_args {
        ($rule:expr, $arg:expr) => {
            assert_eq!($rule.to_redis_args(), vec![$arg.to_vec()]);
        };
    }

    #[test]
    fn test_rule_to_arg() {
        use self::Rule::*;

        assert_args!(On, b"on");
        assert_args!(Off, b"off");
        assert_args!(AddCommand("set".to_owned()), b"+set");
        assert_args!(RemoveCommand("set".to_owned()), b"-set");
        assert_args!(AddCategory("hyperloglog".to_owned()), b"+@hyperloglog");
        assert_args!(RemoveCategory("hyperloglog".to_owned()), b"-@hyperloglog");
        assert_args!(AllCommands, b"allcommands");
        assert_args!(NoCommands, b"nocommands");
        assert_args!(AddPass("mypass".to_owned()), b">mypass");
        assert_args!(RemovePass("mypass".to_owned()), b"<mypass");
        assert_args!(
            AddHashedPass(
                "c3ab8ff13720e8ad9047dd39466b3c8974e592c2fa383d4a3960714caef0c4f2".to_owned()
            ),
            b"#c3ab8ff13720e8ad9047dd39466b3c8974e592c2fa383d4a3960714caef0c4f2"
        );
        assert_args!(
            RemoveHashedPass(
                "c3ab8ff13720e8ad9047dd39466b3c8974e592c2fa383d4a3960714caef0c4f2".to_owned()
            ),
            b"!c3ab8ff13720e8ad9047dd39466b3c8974e592c2fa383d4a3960714caef0c4f2"
        );
        assert_args!(NoPass, b"nopass");
        assert_args!(ResetPass, b"resetpass");
        assert_args!(Pattern("pat:*".to_owned()), b"~pat:*");
        assert_args!(AllKeys, b"allkeys");
        assert_args!(ResetKeys, b"resetkeys");
        assert_args!(Reset, b"reset");
        assert_args!(ResetChannels, b"resetchannels");
        assert_args!(Other("resetchannels".to_owned()), b"resetchannels");
        assert_args!(Channel("asynq:cancel".to_owned()), b"&asynq:cancel");
        assert_args!(
            Selector(vec![
                AddCommand("SET".to_string()),
                Pattern("key2".to_string())
            ]),
            b"(+SET ~key2)"
        );
    }

    #[test]
    fn test_from_redis_value() {
        let redis_value = Value::Array(vec![
            Value::BulkString("flags".into()),
            Value::Array(vec![
                Value::BulkString("on".into()),
                Value::BulkString("allchannels".into()),
            ]),
            Value::BulkString("passwords".into()),
            Value::Array(vec![]),
            Value::BulkString("commands".into()),
            Value::BulkString("-@all +get".into()),
            Value::BulkString("keys".into()),
            Value::Array(vec![Value::BulkString("pat:*".into())]),
            Value::BulkString("channels".into()),
            Value::Array(vec![Value::BulkString("&asynq:cancel".into())]),
            Value::BulkString("selectors".into()),
            Value::Array(vec![Value::Array(vec![
                Value::BulkString("commands".into()),
                Value::BulkString("-@all +get".into()),
                Value::BulkString("keys".into()),
                Value::BulkString("~key2".into()),
                Value::BulkString("channels".into()),
                Value::BulkString("".into()),
            ])]),
        ]);
        let acl_info = AclInfo::from_redis_value_ref(&redis_value).expect("Parse successfully");

        assert_eq!(
            acl_info,
            AclInfo {
                flags: vec![Rule::On, Rule::Other("allchannels".into())],
                passwords: vec![],
                commands: vec![
                    Rule::RemoveCategory("all".to_owned()),
                    Rule::AddCommand("get".to_owned()),
                ],
                keys: vec![Rule::Pattern("pat:*".to_owned())],
                channels: vec![Rule::Channel("asynq:cancel".to_owned())],
                selectors: vec![
                    Rule::RemoveCategory("all".to_owned()),
                    Rule::AddCommand("get".to_owned()),
                    Rule::Pattern("key2".to_owned()),
                ],
            }
        );
    }

    /// Keys and channels may also arrive as space-separated bulk strings.
    #[test]
    fn test_from_redis_value_string_forms() {
        let redis_value = Value::Array(vec![
            Value::BulkString("flags".into()),
            Value::Array(vec![
                Value::BulkString("off".into()),
                Value::BulkString("allkeys".into()),
                Value::BulkString("allcommands".into()),
                Value::BulkString("nopass".into()),
            ]),
            Value::BulkString("passwords".into()),
            Value::Array(vec![Value::BulkString("deadbeef".into())]),
            Value::BulkString("commands".into()),
            Value::BulkString("+@read -@write +get -set".into()),
            Value::BulkString("keys".into()),
            Value::BulkString("~* ~pat:*".into()),
            Value::BulkString("channels".into()),
            Value::BulkString("&foo bar".into()),
            Value::BulkString("selectors".into()),
            Value::Array(vec![]),
        ]);
        let acl_info = AclInfo::from_redis_value_ref(&redis_value).expect("Parse successfully");

        assert_eq!(
            acl_info,
            AclInfo {
                flags: vec![Rule::Off, Rule::AllKeys, Rule::AllCommands, Rule::NoPass],
                passwords: vec![Rule::AddHashedPass("deadbeef".to_owned())],
                commands: vec![
                    Rule::AddCategory("read".to_owned()),
                    Rule::RemoveCategory("write".to_owned()),
                    Rule::AddCommand("get".to_owned()),
                    Rule::RemoveCommand("set".to_owned()),
                ],
                keys: vec![Rule::AllKeys, Rule::Pattern("pat:*".to_owned())],
                channels: vec![
                    Rule::Channel("foo".to_owned()),
                    Rule::Channel("bar".to_owned()),
                ],
                selectors: vec![],
            }
        );
    }

    /// Field names and string values are trimmed and unquoted; unknown fields
    /// are ignored and leave every field at its default.
    #[test]
    fn test_from_redis_value_quoted_and_unknown_fields() {
        let redis_value = Value::Array(vec![
            Value::BulkString(" \"keys\" ".into()),
            Value::BulkString("\"~foo\"".into()),
            Value::BulkString("unknown-field".into()),
            Value::BulkString("ignored".into()),
        ]);
        let acl_info = AclInfo::from_redis_value_ref(&redis_value).expect("Parse successfully");

        assert_eq!(
            acl_info,
            AclInfo {
                keys: vec![Rule::Pattern("foo".to_owned())],
                ..Default::default()
            }
        );
    }

    #[test]
    fn test_from_redis_value_rejects_malformed_replies() {
        // A field name without a matching value.
        let odd = Value::Array(vec![Value::BulkString("flags".into())]);
        assert!(AclInfo::from_redis_value_ref(&odd).is_err());

        // Field names must be bulk strings.
        let bad_name = Value::Array(vec![Value::Int(1), Value::Array(vec![])]);
        assert!(AclInfo::from_redis_value_ref(&bad_name).is_err());

        // Every command token must start with `+` or `-`.
        let bad_command = Value::Array(vec![
            Value::BulkString("commands".into()),
            Value::BulkString("get".into()),
        ]);
        assert!(AclInfo::from_redis_value_ref(&bad_command).is_err());
    }
}