1use std::{error::Error, fmt, num::ParseIntError, string::FromUtf8Error, time::SystemTimeError};
2
3use redis::{AsyncCommands, RedisError};
4
5use crate::{backend::SessionBackend, utils::now};
6
/// Session storage backed by Redis.
///
/// Every session keeps its values in its own Redis hash at `"{namespace}:{session_id}"`,
/// while a separate registry hash at `sessions_key` maps each session ID to the
/// timestamp recorded when the session was first written.
#[derive(Clone)]
pub struct RedisBackend<C> {
    /// Prefix placed in front of every per-session hash key.
    namespace: String,
    /// Key of the registry hash (session ID -> timestamp).
    sessions_key: String,
    /// Handle used to issue Redis commands.
    connection: C,
}
14
15impl<C> RedisBackend<C> {
16 pub fn new<N>(namespace: N, connection: C) -> Self
23 where
24 N: Into<String>,
25 {
26 let namespace = namespace.into();
27 let sessions_key = format!("{namespace}:__seance_sessions");
28 Self {
29 namespace,
30 sessions_key,
31 connection,
32 }
33 }
34
35 fn get_session_key(&self, session_id: &str) -> String {
36 format!("{}:{}", self.namespace, session_id)
37 }
38}
39
40impl<C> SessionBackend for RedisBackend<C>
41where
42 C: AsyncCommands,
43{
44 type Error = RedisBackendError;
45
46 async fn get_sessions(&mut self) -> Result<Vec<String>, Self::Error> {
47 self.connection
48 .hkeys(&self.sessions_key)
49 .await
50 .map_err(RedisBackendError::GetSessions)
51 }
52
53 async fn get_session_age(&mut self, session_id: &str) -> Result<Option<u64>, Self::Error> {
54 self.connection
55 .hget(&self.sessions_key, session_id)
56 .await
57 .map_err(RedisBackendError::GetSessionAge)
58 }
59
60 async fn remove_session(&mut self, session_id: &str) -> Result<(), Self::Error> {
61 let session_key = self.get_session_key(session_id);
62 self.connection
63 .del(session_key)
64 .await
65 .map_err(RedisBackendError::RemoveSession)
66 }
67
68 async fn read_value(&mut self, session_id: &str, key: &str) -> Result<Option<Vec<u8>>, Self::Error> {
69 let session_key = self.get_session_key(session_id);
70 let result: Option<Vec<u8>> = self
72 .connection
73 .hget(session_key, key)
74 .await
75 .map_err(RedisBackendError::ReadValue)?;
76 Ok(result)
77 }
78
79 async fn write_value(&mut self, session_id: &str, key: &str, value: &[u8]) -> Result<(), Self::Error> {
80 let session_key = self.get_session_key(session_id);
81 let len: i64 = self
82 .connection
83 .hlen(&session_key)
84 .await
85 .map_err(RedisBackendError::WriteValue)?;
86 if len == 0 {
87 let timestamp = format!("{}", now().map_err(RedisBackendError::SetSessionTimestamp)?);
88 let _: () = self
89 .connection
90 .hset(&self.sessions_key, session_id, timestamp)
91 .await
92 .map_err(RedisBackendError::WriteValue)?;
93 }
94 self.connection
95 .hset(session_key, key, value)
96 .await
97 .map_err(RedisBackendError::WriteValue)
98 }
99
100 async fn remove_value(&mut self, session_id: &str, key: &str) -> Result<(), Self::Error> {
101 let session_key = self.get_session_key(session_id);
102 self.connection
103 .hdel(session_key, key)
104 .await
105 .map_err(RedisBackendError::RemoveValue)
106 }
107}
108
109#[derive(Debug)]
111pub enum RedisBackendError {
112 GetSessions(RedisError),
114 GetSessionAge(RedisError),
116 ParseSessionAge(ParseIntError),
118 ParseSessionId(FromUtf8Error),
120 ReadValue(RedisError),
122 RemoveSession(RedisError),
124 RemoveValue(RedisError),
126 SessionAgeFromUtf8(FromUtf8Error),
128 SetSessionTimestamp(SystemTimeError),
132 WriteValue(RedisError),
134}
135
136impl fmt::Display for RedisBackendError {
137 fn fmt(&self, out: &mut fmt::Formatter) -> fmt::Result {
138 use self::RedisBackendError::*;
139 match self {
140 GetSessions(err) => write!(out, "failed to get sessions list: {err}"),
141 GetSessionAge(err) => write!(out, "failed to get session age: {err}"),
142 ParseSessionAge(err) => write!(out, "session age contains non-integer value: {err}"),
143 ParseSessionId(err) => write!(out, "session id contains non-utf8 string: {err}"),
144 ReadValue(err) => write!(out, "failed to read value: {err}"),
145 RemoveSession(err) => write!(out, "failed to remove session: {err}"),
146 RemoveValue(err) => write!(out, "failed to remove value: {err}"),
147 SessionAgeFromUtf8(err) => write!(out, "session age contains non-utf8 string: {err}"),
148 SetSessionTimestamp(err) => write!(out, "failed to set session timestamp: {err}"),
149 WriteValue(err) => write!(out, "failed to write value: {err}"),
150 }
151 }
152}
153
154impl Error for RedisBackendError {
155 fn source(&self) -> Option<&(dyn Error + 'static)> {
156 use self::RedisBackendError::*;
157 Some(match self {
158 GetSessions(err) => err,
159 GetSessionAge(err) => err,
160 ParseSessionAge(err) => err,
161 ParseSessionId(err) => err,
162 ReadValue(err) => err,
163 RemoveSession(err) => err,
164 RemoveValue(err) => err,
165 SessionAgeFromUtf8(err) => err,
166 SetSessionTimestamp(err) => err,
167 WriteValue(err) => err,
168 })
169 }
170}