// openzeppelin_relayer/repositories/transaction_counter/transaction_counter_redis.rs
use super::TransactionCounterTrait;
9use crate::models::RepositoryError;
10use crate::repositories::redis_base::RedisRepository;
11use async_trait::async_trait;
12use redis::aio::ConnectionManager;
13use redis::AsyncCommands;
14use std::fmt;
15use std::sync::Arc;
16use tracing::debug;
17
/// Fixed key segment placed directly after the repository's key prefix in every counter key.
const COUNTER_PREFIX: &str = "transaction_counter";
19
20#[derive(Clone)]
21pub struct RedisTransactionCounter {
22 pub client: Arc<ConnectionManager>,
23 pub key_prefix: String,
24}
25
26impl RedisRepository for RedisTransactionCounter {}
27
28impl fmt::Debug for RedisTransactionCounter {
29 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
30 f.debug_struct("RedisTransactionCounter")
31 .field("key_prefix", &self.key_prefix)
32 .finish()
33 }
34}
35
36impl RedisTransactionCounter {
37 pub fn new(
38 connection_manager: Arc<ConnectionManager>,
39 key_prefix: String,
40 ) -> Result<Self, RepositoryError> {
41 if key_prefix.is_empty() {
42 return Err(RepositoryError::InvalidData(
43 "Redis key prefix cannot be empty".to_string(),
44 ));
45 }
46
47 Ok(Self {
48 client: connection_manager,
49 key_prefix,
50 })
51 }
52
53 fn counter_key(&self, relayer_id: &str, address: &str) -> String {
55 format!(
56 "{}:{}:{}:{}",
57 self.key_prefix, COUNTER_PREFIX, relayer_id, address
58 )
59 }
60}
61
62#[async_trait]
63impl TransactionCounterTrait for RedisTransactionCounter {
64 async fn get(&self, relayer_id: &str, address: &str) -> Result<Option<u64>, RepositoryError> {
65 if relayer_id.is_empty() {
66 return Err(RepositoryError::InvalidData(
67 "Relayer ID cannot be empty".to_string(),
68 ));
69 }
70
71 if address.is_empty() {
72 return Err(RepositoryError::InvalidData(
73 "Address cannot be empty".to_string(),
74 ));
75 }
76
77 let key = self.counter_key(relayer_id, address);
78 debug!(relayer_id = %relayer_id, address = %address, "getting counter for relayer and address");
79
80 let mut conn = self.client.as_ref().clone();
81
82 let value: Option<u64> = conn
83 .get(&key)
84 .await
85 .map_err(|e| self.map_redis_error(e, "get_counter"))?;
86
87 debug!(value = ?value, "retrieved counter value");
88 Ok(value)
89 }
90
91 async fn get_and_increment(
92 &self,
93 relayer_id: &str,
94 address: &str,
95 ) -> Result<u64, RepositoryError> {
96 if relayer_id.is_empty() {
97 return Err(RepositoryError::InvalidData(
98 "Relayer ID cannot be empty".to_string(),
99 ));
100 }
101
102 if address.is_empty() {
103 return Err(RepositoryError::InvalidData(
104 "Address cannot be empty".to_string(),
105 ));
106 }
107
108 let key = self.counter_key(relayer_id, address);
109 debug!(relayer_id = %relayer_id, address = %address, "getting and incrementing counter for relayer and address");
110
111 let mut conn = self.client.as_ref().clone();
112
113 let new_value: u64 = conn
115 .incr(&key, 1)
116 .await
117 .map_err(|e| self.map_redis_error(e, "get_and_increment"))?;
118
119 let current = new_value.saturating_sub(1);
120
121 debug!(from = %current, to = %(current + 1), "counter incremented");
122 Ok(current)
123 }
124
125 async fn decrement(&self, relayer_id: &str, address: &str) -> Result<u64, RepositoryError> {
126 if relayer_id.is_empty() {
127 return Err(RepositoryError::InvalidData(
128 "Relayer ID cannot be empty".to_string(),
129 ));
130 }
131
132 if address.is_empty() {
133 return Err(RepositoryError::InvalidData(
134 "Address cannot be empty".to_string(),
135 ));
136 }
137
138 let key = self.counter_key(relayer_id, address);
139 debug!(relayer_id = %relayer_id, address = %address, "decrementing counter for relayer and address");
140
141 let mut conn = self.client.as_ref().clone();
142
143 let exists: bool = conn
145 .exists(&key)
146 .await
147 .map_err(|e| self.map_redis_error(e, "check_counter_exists"))?;
148
149 if !exists {
150 return Err(RepositoryError::NotFound(format!(
151 "Counter not found for relayer {relayer_id} and address {address}"
152 )));
153 }
154
155 let new_value: i64 = conn
157 .decr(&key, 1)
158 .await
159 .map_err(|e| self.map_redis_error(e, "decrement_counter"))?;
160
161 let new_value = if new_value < 0 {
162 let _: () = conn
164 .set(&key, 0)
165 .await
166 .map_err(|e| self.map_redis_error(e, "correct_negative_counter"))?;
167 0u64
168 } else {
169 new_value as u64
170 };
171
172 debug!(new_value = %new_value, "counter decremented");
173 Ok(new_value)
174 }
175
176 async fn set(
177 &self,
178 relayer_id: &str,
179 address: &str,
180 value: u64,
181 ) -> Result<(), RepositoryError> {
182 if relayer_id.is_empty() {
183 return Err(RepositoryError::InvalidData(
184 "Relayer ID cannot be empty".to_string(),
185 ));
186 }
187
188 if address.is_empty() {
189 return Err(RepositoryError::InvalidData(
190 "Address cannot be empty".to_string(),
191 ));
192 }
193
194 let key = self.counter_key(relayer_id, address);
195 debug!(relayer_id = %relayer_id, address = %address, value = %value, "setting counter for relayer and address");
196
197 let mut conn = self.client.as_ref().clone();
198
199 let _: () = conn
200 .set(&key, value)
201 .await
202 .map_err(|e| self.map_redis_error(e, "set_counter"))?;
203
204 debug!(value = %value, "counter set");
205 Ok(())
206 }
207}
208
#[cfg(test)]
mod tests {
    use super::*;
    use redis::aio::ConnectionManager;
    use std::sync::Arc;
    use uuid::Uuid;

    /// Connects to the Redis instance named by `REDIS_URL` (or a local default) and wraps it
    /// in a counter repository using the `test_counter` key prefix.
    async fn setup_test_repo() -> RedisTransactionCounter {
        let redis_url =
            std::env::var("REDIS_URL").unwrap_or_else(|_| "redis://127.0.0.1:6379".to_string());
        let client = redis::Client::open(redis_url).expect("Failed to create Redis client");
        let connection_manager = ConnectionManager::new(client)
            .await
            .expect("Failed to create Redis connection manager");

        RedisTransactionCounter::new(Arc::new(connection_manager), "test_counter".to_string())
            .expect("Failed to create Redis transaction counter")
    }

    /// Fresh identifier per call so tests never touch each other's keys or keys left behind
    /// by an earlier run.
    fn unique_id() -> String {
        Uuid::new_v4().to_string()
    }

    #[tokio::test]
    #[ignore = "Requires active Redis instance"]
    async fn test_get_nonexistent_counter() {
        let repo = setup_test_repo().await;
        let relayer_id = unique_id();

        let result = repo.get(&relayer_id, "0x1234").await.unwrap();
        assert_eq!(result, None);
    }

    #[tokio::test]
    #[ignore = "Requires active Redis instance"]
    async fn test_set_and_get_counter() {
        let repo = setup_test_repo().await;
        let relayer_id = unique_id();
        let address = unique_id();

        repo.set(&relayer_id, &address, 100).await.unwrap();

        let result = repo.get(&relayer_id, &address).await.unwrap();
        assert_eq!(result, Some(100));
    }

    #[tokio::test]
    #[ignore = "Requires active Redis instance"]
    async fn test_get_and_increment() {
        let repo = setup_test_repo().await;
        let relayer_id = unique_id();
        let address = unique_id();

        // A missing counter starts at 0 and is created by the first call.
        let result = repo.get_and_increment(&relayer_id, &address).await.unwrap();
        assert_eq!(result, 0);

        let current = repo.get(&relayer_id, &address).await.unwrap();
        assert_eq!(current, Some(1));

        let result = repo.get_and_increment(&relayer_id, &address).await.unwrap();
        assert_eq!(result, 1);

        let current = repo.get(&relayer_id, &address).await.unwrap();
        assert_eq!(current, Some(2));
    }

    #[tokio::test]
    #[ignore = "Requires active Redis instance"]
    async fn test_decrement() {
        let repo = setup_test_repo().await;
        let relayer_id = unique_id();
        let address = unique_id();

        repo.set(&relayer_id, &address, 5).await.unwrap();

        let result = repo.decrement(&relayer_id, &address).await.unwrap();
        assert_eq!(result, 4);

        let current = repo.get(&relayer_id, &address).await.unwrap();
        assert_eq!(current, Some(4));
    }

    #[tokio::test]
    #[ignore = "Requires active Redis instance"]
    async fn test_decrement_floor_at_zero() {
        let repo = setup_test_repo().await;
        let relayer_id = unique_id();
        let address = unique_id();

        repo.set(&relayer_id, &address, 0).await.unwrap();

        // Decrementing an existing counter that is already 0 must leave it at 0.
        let result = repo.decrement(&relayer_id, &address).await.unwrap();
        assert_eq!(result, 0);

        let current = repo.get(&relayer_id, &address).await.unwrap();
        assert_eq!(current, Some(0));
    }

    #[tokio::test]
    #[ignore = "Requires active Redis instance"]
    async fn test_decrement_not_found() {
        let repo = setup_test_repo().await;
        let relayer_id = unique_id();

        let result = repo.decrement(&relayer_id, "0x1234").await;
        assert!(matches!(result, Err(RepositoryError::NotFound(_))));
    }

    #[tokio::test]
    #[ignore = "Requires active Redis instance"]
    async fn test_empty_validation() {
        let repo = setup_test_repo().await;

        // Every operation must reject an empty relayer id and an empty address.
        let result = repo.get("", "0x1234").await;
        assert!(matches!(result, Err(RepositoryError::InvalidData(_))));
        let result = repo.get("relayer", "").await;
        assert!(matches!(result, Err(RepositoryError::InvalidData(_))));

        let result = repo.get_and_increment("", "0x1234").await;
        assert!(matches!(result, Err(RepositoryError::InvalidData(_))));
        let result = repo.get_and_increment("relayer", "").await;
        assert!(matches!(result, Err(RepositoryError::InvalidData(_))));

        let result = repo.decrement("", "0x1234").await;
        assert!(matches!(result, Err(RepositoryError::InvalidData(_))));
        let result = repo.decrement("relayer", "").await;
        assert!(matches!(result, Err(RepositoryError::InvalidData(_))));

        let result = repo.set("", "0x1234", 1).await;
        assert!(matches!(result, Err(RepositoryError::InvalidData(_))));
        let result = repo.set("relayer", "", 1).await;
        assert!(matches!(result, Err(RepositoryError::InvalidData(_))));
    }

    #[tokio::test]
    #[ignore = "Requires active Redis instance"]
    async fn test_multiple_relayers() {
        let repo = setup_test_repo().await;
        let relayer_1 = unique_id();
        let relayer_2 = unique_id();
        let address_1 = unique_id();
        let address_2 = unique_id();

        repo.set(&relayer_1, &address_1, 100).await.unwrap();
        repo.set(&relayer_1, &address_2, 200).await.unwrap();
        repo.set(&relayer_2, &address_1, 300).await.unwrap();

        // Each (relayer, address) pair keeps its own value.
        assert_eq!(repo.get(&relayer_1, &address_1).await.unwrap(), Some(100));
        assert_eq!(repo.get(&relayer_1, &address_2).await.unwrap(), Some(200));
        assert_eq!(repo.get(&relayer_2, &address_1).await.unwrap(), Some(300));

        // Incrementing one pair must not disturb the others.
        assert_eq!(
            repo.get_and_increment(&relayer_1, &address_1)
                .await
                .unwrap(),
            100
        );
        assert_eq!(
            repo.get_and_increment(&relayer_1, &address_1)
                .await
                .unwrap(),
            101
        );
        assert_eq!(
            repo.get_and_increment(&relayer_1, &address_2)
                .await
                .unwrap(),
            200
        );
        assert_eq!(
            repo.get_and_increment(&relayer_1, &address_2)
                .await
                .unwrap(),
            201
        );
        assert_eq!(repo.get(&relayer_2, &address_1).await.unwrap(), Some(300));
    }

    #[tokio::test]
    #[ignore = "Requires active Redis instance"]
    async fn test_concurrent_get_and_increment() {
        let repo = setup_test_repo().await;
        let relayer_id = unique_id();
        let address = unique_id();

        repo.set(&relayer_id, &address, 100).await.unwrap();

        let handles: Vec<_> = (0..10)
            .map(|_| {
                let repo = repo.clone();
                let relayer_id = relayer_id.clone();
                let address = address.clone();
                tokio::spawn(
                    async move { repo.get_and_increment(&relayer_id, &address).await.unwrap() },
                )
            })
            .collect();

        let mut results = Vec::with_capacity(handles.len());
        for handle in handles {
            results.push(handle.await.unwrap());
        }

        // Completion order is arbitrary; every task must still have seen a distinct value.
        results.sort_unstable();

        let expected: Vec<u64> = (100..110).collect();
        assert_eq!(results, expected);

        let final_value = repo.get(&relayer_id, &address).await.unwrap();
        assert_eq!(final_value, Some(110));
    }
}