openzeppelin_relayer/repositories/transaction_counter/
transaction_counter_in_memory.rs1use async_trait::async_trait;
8use dashmap::DashMap;
9
10use crate::repositories::{RepositoryError, TransactionCounterTrait};
11
12#[derive(Debug, Default, Clone)]
13pub struct InMemoryTransactionCounter {
14 store: DashMap<(String, String), u64>, }
16
17impl InMemoryTransactionCounter {
18 pub fn new() -> Self {
19 Self {
20 store: DashMap::new(),
21 }
22 }
23}
24
25#[async_trait]
26impl TransactionCounterTrait for InMemoryTransactionCounter {
27 async fn get(&self, relayer_id: &str, address: &str) -> Result<Option<u64>, RepositoryError> {
28 Ok(self
29 .store
30 .get(&(relayer_id.to_string(), address.to_string()))
31 .map(|n| *n))
32 }
33
34 async fn get_and_increment(
35 &self,
36 relayer_id: &str,
37 address: &str,
38 ) -> Result<u64, RepositoryError> {
39 let mut entry = self
40 .store
41 .entry((relayer_id.to_string(), address.to_string()))
42 .or_insert(0);
43 let current = *entry;
44 *entry += 1;
45 Ok(current)
46 }
47
48 async fn decrement(&self, relayer_id: &str, address: &str) -> Result<u64, RepositoryError> {
49 let mut entry = self
50 .store
51 .get_mut(&(relayer_id.to_string(), address.to_string()))
52 .ok_or_else(|| RepositoryError::NotFound(format!("Counter not found for {address}")))?;
53 if *entry > 0 {
54 *entry -= 1;
55 }
56 Ok(*entry)
57 }
58
59 async fn set(
60 &self,
61 relayer_id: &str,
62 address: &str,
63 value: u64,
64 ) -> Result<(), RepositoryError> {
65 self.store
66 .insert((relayer_id.to_string(), address.to_string()), value);
67 Ok(())
68 }
69}
70
#[cfg(test)]
mod tests {
    use super::*;

    /// `decrement` on a pair that was never stored reports `NotFound`.
    #[tokio::test]
    async fn test_decrement_not_found() {
        let store = InMemoryTransactionCounter::new();
        let result = store.decrement("nonexistent", "0x1234").await;
        assert!(matches!(result, Err(RepositoryError::NotFound(_))));
    }

    /// Round-trips `set`, `get`, `get_and_increment` and `decrement` on a single key.
    #[tokio::test]
    async fn test_nonce_store() {
        let store = InMemoryTransactionCounter::new();
        let relayer_id = "relayer_1";
        let address = "0x1234";

        assert_eq!(store.get(relayer_id, address).await.unwrap(), None);

        store.set(relayer_id, address, 100).await.unwrap();
        assert_eq!(store.get(relayer_id, address).await.unwrap(), Some(100));

        assert_eq!(
            store.get_and_increment(relayer_id, address).await.unwrap(),
            100
        );
        assert_eq!(store.get(relayer_id, address).await.unwrap(), Some(101));

        assert_eq!(store.decrement(relayer_id, address).await.unwrap(), 100);
        assert_eq!(store.get(relayer_id, address).await.unwrap(), Some(100));
    }

    /// Counters are isolated per `(relayer_id, address)` pair.
    #[tokio::test]
    async fn test_multiple_relayers() {
        let store = InMemoryTransactionCounter::new();

        store.set("relayer_1", "0x1234", 100).await.unwrap();
        store.set("relayer_1", "0x5678", 200).await.unwrap();
        store.set("relayer_2", "0x1234", 300).await.unwrap();

        assert_eq!(store.get("relayer_1", "0x1234").await.unwrap(), Some(100));
        assert_eq!(store.get("relayer_1", "0x5678").await.unwrap(), Some(200));
        assert_eq!(store.get("relayer_2", "0x1234").await.unwrap(), Some(300));

        assert_eq!(
            store
                .get_and_increment("relayer_1", "0x1234")
                .await
                .unwrap(),
            100
        );
        assert_eq!(
            store
                .get_and_increment("relayer_1", "0x1234")
                .await
                .unwrap(),
            101
        );
        assert_eq!(
            store
                .get_and_increment("relayer_1", "0x5678")
                .await
                .unwrap(),
            200
        );
        assert_eq!(
            store
                .get_and_increment("relayer_1", "0x5678")
                .await
                .unwrap(),
            201
        );
        assert_eq!(store.get("relayer_2", "0x1234").await.unwrap(), Some(300));
    }

    /// `get_and_increment` on a never-seen pair starts the counter at `0`.
    #[tokio::test]
    async fn test_get_and_increment_initializes_missing_counter() {
        let store = InMemoryTransactionCounter::new();

        assert_eq!(store.get("relayer_1", "0xabcd").await.unwrap(), None);
        assert_eq!(store.get_and_increment("relayer_1", "0xabcd").await.unwrap(), 0);
        assert_eq!(store.get_and_increment("relayer_1", "0xabcd").await.unwrap(), 1);
        assert_eq!(store.get("relayer_1", "0xabcd").await.unwrap(), Some(2));
    }

    /// `decrement` on a zero counter stays at `0` instead of underflowing.
    #[tokio::test]
    async fn test_decrement_saturates_at_zero() {
        let store = InMemoryTransactionCounter::new();

        store.set("relayer_1", "0x1234", 1).await.unwrap();
        assert_eq!(store.decrement("relayer_1", "0x1234").await.unwrap(), 0);
        assert_eq!(store.decrement("relayer_1", "0x1234").await.unwrap(), 0);
        assert_eq!(store.get("relayer_1", "0x1234").await.unwrap(), Some(0));
    }
}