openzeppelin_relayer/repositories/transaction_counter/
transaction_counter_redis.rs

//! Redis implementation of the transaction counter.
//!
//! This module provides a Redis-based implementation of the `TransactionCounterTrait`,
//! allowing transaction counters to be stored in and retrieved from a Redis database.
//! Each counter lives under the key `{prefix}:transaction_counter:{relayer_id}:{address}`,
//! and increments use Redis's atomic `INCR`, so concurrent callers never receive the same value.
7
8use super::TransactionCounterTrait;
9use crate::models::RepositoryError;
10use crate::repositories::redis_base::RedisRepository;
11use async_trait::async_trait;
12use redis::aio::ConnectionManager;
13use redis::AsyncCommands;
14use std::fmt;
15use std::sync::Arc;
16use tracing::debug;
17
/// Fixed key segment placed between the configured prefix and the relayer/address parts,
/// giving counter keys the shape `{prefix}:transaction_counter:{relayer_id}:{address}`.
const COUNTER_PREFIX: &str = "transaction_counter";
19
/// Redis-backed [`TransactionCounterTrait`] implementation.
///
/// Each counter is a single Redis value read with `GET`, written with `SET` and
/// updated with `INCR`/`DECR`, stored under
/// `{key_prefix}:transaction_counter:{relayer_id}:{address}`.
#[derive(Clone)]
pub struct RedisTransactionCounter {
    /// Shared Redis connection manager; each operation clones it to issue commands.
    pub client: Arc<ConnectionManager>,
    /// First segment of every counter key; `new` rejects an empty prefix.
    pub key_prefix: String,
}
25
/// Opts into the shared Redis repository helpers. The trait methods in this file call
/// `self.map_redis_error(..)` to turn `redis` errors into [`RepositoryError`]s.
/// NOTE(review): `map_redis_error` is presumably a default method of `RedisRepository` —
/// confirm in `repositories::redis_base`.
impl RedisRepository for RedisTransactionCounter {}
27
28impl fmt::Debug for RedisTransactionCounter {
29    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
30        f.debug_struct("RedisTransactionCounter")
31            .field("key_prefix", &self.key_prefix)
32            .finish()
33    }
34}
35
36impl RedisTransactionCounter {
37    pub fn new(
38        connection_manager: Arc<ConnectionManager>,
39        key_prefix: String,
40    ) -> Result<Self, RepositoryError> {
41        if key_prefix.is_empty() {
42            return Err(RepositoryError::InvalidData(
43                "Redis key prefix cannot be empty".to_string(),
44            ));
45        }
46
47        Ok(Self {
48            client: connection_manager,
49            key_prefix,
50        })
51    }
52
53    /// Generate key for transaction counter: {prefix}:transaction_counter:{relayer_id}:{address}
54    fn counter_key(&self, relayer_id: &str, address: &str) -> String {
55        format!(
56            "{}:{}:{}:{}",
57            self.key_prefix, COUNTER_PREFIX, relayer_id, address
58        )
59    }
60}
61
62#[async_trait]
63impl TransactionCounterTrait for RedisTransactionCounter {
64    async fn get(&self, relayer_id: &str, address: &str) -> Result<Option<u64>, RepositoryError> {
65        if relayer_id.is_empty() {
66            return Err(RepositoryError::InvalidData(
67                "Relayer ID cannot be empty".to_string(),
68            ));
69        }
70
71        if address.is_empty() {
72            return Err(RepositoryError::InvalidData(
73                "Address cannot be empty".to_string(),
74            ));
75        }
76
77        let key = self.counter_key(relayer_id, address);
78        debug!(relayer_id = %relayer_id, address = %address, "getting counter for relayer and address");
79
80        let mut conn = self.client.as_ref().clone();
81
82        let value: Option<u64> = conn
83            .get(&key)
84            .await
85            .map_err(|e| self.map_redis_error(e, "get_counter"))?;
86
87        debug!(value = ?value, "retrieved counter value");
88        Ok(value)
89    }
90
91    async fn get_and_increment(
92        &self,
93        relayer_id: &str,
94        address: &str,
95    ) -> Result<u64, RepositoryError> {
96        if relayer_id.is_empty() {
97            return Err(RepositoryError::InvalidData(
98                "Relayer ID cannot be empty".to_string(),
99            ));
100        }
101
102        if address.is_empty() {
103            return Err(RepositoryError::InvalidData(
104                "Address cannot be empty".to_string(),
105            ));
106        }
107
108        let key = self.counter_key(relayer_id, address);
109        debug!(relayer_id = %relayer_id, address = %address, "getting and incrementing counter for relayer and address");
110
111        let mut conn = self.client.as_ref().clone();
112
113        // Use Redis INCR for atomic increment
114        let new_value: u64 = conn
115            .incr(&key, 1)
116            .await
117            .map_err(|e| self.map_redis_error(e, "get_and_increment"))?;
118
119        let current = new_value.saturating_sub(1);
120
121        debug!(from = %current, to = %(current + 1), "counter incremented");
122        Ok(current)
123    }
124
125    async fn decrement(&self, relayer_id: &str, address: &str) -> Result<u64, RepositoryError> {
126        if relayer_id.is_empty() {
127            return Err(RepositoryError::InvalidData(
128                "Relayer ID cannot be empty".to_string(),
129            ));
130        }
131
132        if address.is_empty() {
133            return Err(RepositoryError::InvalidData(
134                "Address cannot be empty".to_string(),
135            ));
136        }
137
138        let key = self.counter_key(relayer_id, address);
139        debug!(relayer_id = %relayer_id, address = %address, "decrementing counter for relayer and address");
140
141        let mut conn = self.client.as_ref().clone();
142
143        // Check if counter exists first
144        let exists: bool = conn
145            .exists(&key)
146            .await
147            .map_err(|e| self.map_redis_error(e, "check_counter_exists"))?;
148
149        if !exists {
150            return Err(RepositoryError::NotFound(format!(
151                "Counter not found for relayer {relayer_id} and address {address}"
152            )));
153        }
154
155        // Use Redis DECR and correct if it goes below 0
156        let new_value: i64 = conn
157            .decr(&key, 1)
158            .await
159            .map_err(|e| self.map_redis_error(e, "decrement_counter"))?;
160
161        let new_value = if new_value < 0 {
162            // Correct negative values back to 0
163            let _: () = conn
164                .set(&key, 0)
165                .await
166                .map_err(|e| self.map_redis_error(e, "correct_negative_counter"))?;
167            0u64
168        } else {
169            new_value as u64
170        };
171
172        debug!(new_value = %new_value, "counter decremented");
173        Ok(new_value)
174    }
175
176    async fn set(
177        &self,
178        relayer_id: &str,
179        address: &str,
180        value: u64,
181    ) -> Result<(), RepositoryError> {
182        if relayer_id.is_empty() {
183            return Err(RepositoryError::InvalidData(
184                "Relayer ID cannot be empty".to_string(),
185            ));
186        }
187
188        if address.is_empty() {
189            return Err(RepositoryError::InvalidData(
190                "Address cannot be empty".to_string(),
191            ));
192        }
193
194        let key = self.counter_key(relayer_id, address);
195        debug!(relayer_id = %relayer_id, address = %address, value = %value, "setting counter for relayer and address");
196
197        let mut conn = self.client.as_ref().clone();
198
199        let _: () = conn
200            .set(&key, value)
201            .await
202            .map_err(|e| self.map_redis_error(e, "set_counter"))?;
203
204        debug!(value = %value, "counter set");
205        Ok(())
206    }
207}
208
#[cfg(test)]
mod tests {
    use super::*;
    use redis::aio::ConnectionManager;
    use std::sync::Arc;
    use uuid::Uuid;

    /// Connects to the Redis instance at `REDIS_URL` (default `redis://127.0.0.1:6379`)
    /// and returns a counter repository using the `test_counter` key prefix.
    async fn setup_test_repo() -> RedisTransactionCounter {
        let redis_url =
            std::env::var("REDIS_URL").unwrap_or_else(|_| "redis://127.0.0.1:6379".to_string());
        let client = redis::Client::open(redis_url).expect("Failed to create Redis client");
        let connection_manager = ConnectionManager::new(client)
            .await
            .expect("Failed to create Redis connection manager");

        RedisTransactionCounter::new(Arc::new(connection_manager), "test_counter".to_string())
            .expect("Failed to create Redis transaction counter")
    }

    /// Returns a fresh random identifier so each test touches only its own keys and is
    /// unaffected by data left over from earlier runs.
    fn unique_id() -> String {
        Uuid::new_v4().to_string()
    }

    #[tokio::test]
    #[ignore = "Requires active Redis instance"]
    async fn test_get_nonexistent_counter() {
        let repo = setup_test_repo().await;
        let result = repo.get(&unique_id(), "0x1234").await.unwrap();
        assert_eq!(result, None);
    }

    #[tokio::test]
    #[ignore = "Requires active Redis instance"]
    async fn test_set_and_get_counter() {
        let repo = setup_test_repo().await;
        let relayer_id = unique_id();
        let address = unique_id();

        repo.set(&relayer_id, &address, 100).await.unwrap();
        let result = repo.get(&relayer_id, &address).await.unwrap();
        assert_eq!(result, Some(100));
    }

    #[tokio::test]
    #[ignore = "Requires active Redis instance"]
    async fn test_get_and_increment() {
        let repo = setup_test_repo().await;
        let relayer_id = unique_id();
        let address = unique_id();

        // First increment should return 0 and set to 1
        let result = repo.get_and_increment(&relayer_id, &address).await.unwrap();
        assert_eq!(result, 0);

        let current = repo.get(&relayer_id, &address).await.unwrap();
        assert_eq!(current, Some(1));

        // Second increment should return 1 and set to 2
        let result = repo.get_and_increment(&relayer_id, &address).await.unwrap();
        assert_eq!(result, 1);

        let current = repo.get(&relayer_id, &address).await.unwrap();
        assert_eq!(current, Some(2));
    }

    #[tokio::test]
    #[ignore = "Requires active Redis instance"]
    async fn test_decrement() {
        let repo = setup_test_repo().await;
        let relayer_id = unique_id();
        let address = unique_id();

        // Set initial value
        repo.set(&relayer_id, &address, 5).await.unwrap();

        // Decrement should return 4
        let result = repo.decrement(&relayer_id, &address).await.unwrap();
        assert_eq!(result, 4);

        let current = repo.get(&relayer_id, &address).await.unwrap();
        assert_eq!(current, Some(4));
    }

    #[tokio::test]
    #[ignore = "Requires active Redis instance"]
    async fn test_decrement_clamps_at_zero() {
        let repo = setup_test_repo().await;
        let relayer_id = unique_id();
        let address = unique_id();

        repo.set(&relayer_id, &address, 0).await.unwrap();

        // Decrementing a zero counter must not go negative.
        let result = repo.decrement(&relayer_id, &address).await.unwrap();
        assert_eq!(result, 0);
        assert_eq!(repo.get(&relayer_id, &address).await.unwrap(), Some(0));
    }

    #[tokio::test]
    #[ignore = "Requires active Redis instance"]
    async fn test_decrement_not_found() {
        let repo = setup_test_repo().await;
        let relayer_id = unique_id();

        let result = repo.decrement(&relayer_id, "0x1234").await;
        assert!(matches!(result, Err(RepositoryError::NotFound(_))));

        // A failed decrement must not create the counter as a side effect.
        assert_eq!(repo.get(&relayer_id, "0x1234").await.unwrap(), None);
    }

    #[tokio::test]
    #[ignore = "Requires active Redis instance"]
    async fn test_empty_validation() {
        let repo = setup_test_repo().await;

        // Empty relayer_id is rejected by every operation
        assert!(matches!(
            repo.get("", "0x1234").await,
            Err(RepositoryError::InvalidData(_))
        ));
        assert!(matches!(
            repo.get_and_increment("", "0x1234").await,
            Err(RepositoryError::InvalidData(_))
        ));
        assert!(matches!(
            repo.decrement("", "0x1234").await,
            Err(RepositoryError::InvalidData(_))
        ));
        assert!(matches!(
            repo.set("", "0x1234", 1).await,
            Err(RepositoryError::InvalidData(_))
        ));

        // Empty address is rejected by every operation
        assert!(matches!(
            repo.get("relayer", "").await,
            Err(RepositoryError::InvalidData(_))
        ));
        assert!(matches!(
            repo.get_and_increment("relayer", "").await,
            Err(RepositoryError::InvalidData(_))
        ));
        assert!(matches!(
            repo.decrement("relayer", "").await,
            Err(RepositoryError::InvalidData(_))
        ));
        assert!(matches!(
            repo.set("relayer", "", 1).await,
            Err(RepositoryError::InvalidData(_))
        ));
    }

    #[tokio::test]
    #[ignore = "Requires active Redis instance"]
    async fn test_multiple_relayers() {
        let repo = setup_test_repo().await;
        let relayer_1 = unique_id();
        let relayer_2 = unique_id();
        let address_1 = unique_id();
        let address_2 = unique_id();

        // Set different values for different relayer/address combinations
        repo.set(&relayer_1, &address_1, 100).await.unwrap();
        repo.set(&relayer_1, &address_2, 200).await.unwrap();
        repo.set(&relayer_2, &address_1, 300).await.unwrap();

        // Verify independent counters
        assert_eq!(repo.get(&relayer_1, &address_1).await.unwrap(), Some(100));
        assert_eq!(repo.get(&relayer_1, &address_2).await.unwrap(), Some(200));
        assert_eq!(repo.get(&relayer_2, &address_1).await.unwrap(), Some(300));

        // Verify independent increments
        assert_eq!(
            repo.get_and_increment(&relayer_1, &address_1)
                .await
                .unwrap(),
            100
        );
        assert_eq!(
            repo.get_and_increment(&relayer_1, &address_1)
                .await
                .unwrap(),
            101
        );
        assert_eq!(
            repo.get_and_increment(&relayer_1, &address_2)
                .await
                .unwrap(),
            200
        );
        assert_eq!(
            repo.get_and_increment(&relayer_1, &address_2)
                .await
                .unwrap(),
            201
        );
        assert_eq!(repo.get(&relayer_2, &address_1).await.unwrap(), Some(300));
    }

    #[tokio::test]
    #[ignore = "Requires active Redis instance"]
    async fn test_concurrent_get_and_increment() {
        let repo = setup_test_repo().await;
        let relayer_id = unique_id();
        let address = unique_id();

        // Set initial value
        repo.set(&relayer_id, &address, 100).await.unwrap();

        // Create multiple concurrent tasks that increment the counter
        let handles: Vec<_> = (0..10)
            .map(|_| {
                let repo = repo.clone();
                let relayer_id = relayer_id.clone();
                let address = address.clone();
                tokio::spawn(
                    async move { repo.get_and_increment(&relayer_id, &address).await.unwrap() },
                )
            })
            .collect();

        // Wait for all tasks to complete and collect results
        let mut results = Vec::new();
        for handle in handles {
            results.push(handle.await.unwrap());
        }

        // Sort results to check they are sequential
        results.sort();

        // Verify we get exactly the values 100-109 (no duplicates, no gaps)
        let expected: Vec<u64> = (100..110).collect();
        assert_eq!(results, expected);

        // Verify final value is 110
        let final_value = repo.get(&relayer_id, &address).await.unwrap();
        assert_eq!(final_value, Some(110));
    }
}