openzeppelin_relayer/repositories/api_key/api_key_in_memory.rs

//! This module provides an in-memory implementation of the API key repository.
//!
//! The `InMemoryApiKeyRepository` struct is used to store and retrieve API keys
//! and their permissions.
5use crate::{
6    models::{ApiKeyRepoModel, PaginationQuery},
7    repositories::{ApiKeyRepositoryTrait, PaginatedResult, RepositoryError},
8};
9
10use async_trait::async_trait;
11
12use std::collections::HashMap;
13use tokio::sync::{Mutex, MutexGuard};
14
15#[derive(Debug)]
16pub struct InMemoryApiKeyRepository {
17    store: Mutex<HashMap<String, ApiKeyRepoModel>>,
18}
19
20impl Clone for InMemoryApiKeyRepository {
21    fn clone(&self) -> Self {
22        // Try to get the current data, or use empty HashMap if lock fails
23        let data = self
24            .store
25            .try_lock()
26            .map(|guard| guard.clone())
27            .unwrap_or_else(|_| HashMap::new());
28
29        Self {
30            store: Mutex::new(data),
31        }
32    }
33}
34
35impl InMemoryApiKeyRepository {
36    pub fn new() -> Self {
37        Self {
38            store: Mutex::new(HashMap::new()),
39        }
40    }
41
42    async fn acquire_lock<T>(lock: &Mutex<T>) -> Result<MutexGuard<T>, RepositoryError> {
43        Ok(lock.lock().await)
44    }
45}
46
47impl Default for InMemoryApiKeyRepository {
48    fn default() -> Self {
49        Self::new()
50    }
51}
52
53#[async_trait]
54impl ApiKeyRepositoryTrait for InMemoryApiKeyRepository {
55    async fn create(&self, api_key: ApiKeyRepoModel) -> Result<ApiKeyRepoModel, RepositoryError> {
56        let mut store = Self::acquire_lock(&self.store).await?;
57        store.insert(api_key.id.clone(), api_key.clone());
58        Ok(api_key)
59    }
60
61    async fn get_by_id(&self, id: &str) -> Result<Option<ApiKeyRepoModel>, RepositoryError> {
62        let store = Self::acquire_lock(&self.store).await?;
63        Ok(store.get(id).cloned())
64    }
65
66    async fn list_paginated(
67        &self,
68        query: PaginationQuery,
69    ) -> Result<PaginatedResult<ApiKeyRepoModel>, RepositoryError> {
70        let total = self.count().await?;
71        let start = ((query.page - 1) * query.per_page) as usize;
72
73        let items = self
74            .store
75            .lock()
76            .await
77            .values()
78            .skip(start)
79            .take(query.per_page as usize)
80            .cloned()
81            .collect();
82
83        Ok(PaginatedResult {
84            items,
85            total: total as u64,
86            page: query.page,
87            per_page: query.per_page,
88        })
89    }
90
91    async fn count(&self) -> Result<usize, RepositoryError> {
92        let store = self.store.lock().await;
93        Ok(store.len())
94    }
95
96    async fn list_permissions(&self, api_key_id: &str) -> Result<Vec<String>, RepositoryError> {
97        let store = self.store.lock().await;
98        let api_key = store
99            .get(api_key_id)
100            .ok_or(RepositoryError::NotFound(format!(
101                "Api key with id {api_key_id} not found"
102            )))?;
103        Ok(api_key.permissions.clone())
104    }
105
106    async fn delete_by_id(&self, api_key_id: &str) -> Result<(), RepositoryError> {
107        let mut store = self.store.lock().await;
108        store.remove(api_key_id);
109        Ok(())
110    }
111
112    async fn has_entries(&self) -> Result<bool, RepositoryError> {
113        let store = Self::acquire_lock(&self.store).await?;
114        Ok(!store.is_empty())
115    }
116
117    async fn drop_all_entries(&self) -> Result<(), RepositoryError> {
118        let mut store = Self::acquire_lock(&self.store).await?;
119        store.clear();
120        Ok(())
121    }
122}
123
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use chrono::Utc;

    use crate::models::SecretString;

    use super::*;

    /// Builds an api key fixture with the given id and permissions.
    fn make_api_key(id: &str, permissions: &[&str]) -> ApiKeyRepoModel {
        ApiKeyRepoModel {
            id: id.to_string(),
            value: SecretString::new("test-value"),
            name: "test-name".to_string(),
            allowed_origins: vec!["*".to_string()],
            permissions: permissions.iter().map(|p| p.to_string()).collect(),
            created_at: Utc::now().to_string(),
        }
    }

    /// Builds an api key fixture with a single `relayer:all:execute` permission.
    fn default_api_key(id: &str) -> ApiKeyRepoModel {
        make_api_key(id, &["relayer:all:execute"])
    }

    #[tokio::test]
    async fn test_create_returns_stored_api_key() {
        let repo = InMemoryApiKeyRepository::new();
        let api_key = default_api_key("test-api-key");

        let created = repo.create(api_key.clone()).await.unwrap();
        assert_eq!(created, api_key);
    }

    #[tokio::test]
    async fn test_get_by_id() {
        let repo = InMemoryApiKeyRepository::new();
        let api_key = default_api_key("test-api-key");
        repo.create(api_key.clone()).await.unwrap();

        assert_eq!(repo.get_by_id("test-api-key").await.unwrap(), Some(api_key));
    }

    #[tokio::test]
    async fn test_get_nonexistent_api_key() {
        let repo = InMemoryApiKeyRepository::new();

        let result = repo.get_by_id("test-api-key").await;
        assert!(matches!(result, Ok(None)));
    }

    #[tokio::test]
    async fn test_count() {
        let repo = InMemoryApiKeyRepository::new();
        assert_eq!(repo.count().await.unwrap(), 0);

        repo.create(default_api_key("test-api-key1")).await.unwrap();
        repo.create(default_api_key("test-api-key2")).await.unwrap();
        assert_eq!(repo.count().await.unwrap(), 2);
    }

    #[tokio::test]
    async fn test_list_paginated_api_keys() {
        let repo = InMemoryApiKeyRepository::new();
        repo.create(default_api_key("test-api-key1")).await.unwrap();
        repo.create(default_api_key("test-api-key2")).await.unwrap();

        let result = repo
            .list_paginated(PaginationQuery {
                page: 1,
                per_page: 2,
            })
            .await
            .unwrap();

        assert_eq!(result.items.len(), 2);
        assert_eq!(result.total, 2);
        assert_eq!(result.page, 1);
        assert_eq!(result.per_page, 2);
    }

    #[tokio::test]
    async fn test_list_paginated_pages_cover_all_api_keys() {
        let repo = InMemoryApiKeyRepository::new();
        for id in ["test-api-key1", "test-api-key2", "test-api-key3"] {
            repo.create(default_api_key(id)).await.unwrap();
        }

        let first = repo
            .list_paginated(PaginationQuery {
                page: 1,
                per_page: 2,
            })
            .await
            .unwrap();
        let second = repo
            .list_paginated(PaginationQuery {
                page: 2,
                per_page: 2,
            })
            .await
            .unwrap();
        let third = repo
            .list_paginated(PaginationQuery {
                page: 3,
                per_page: 2,
            })
            .await
            .unwrap();

        assert_eq!(first.total, 3);
        assert_eq!(first.items.len(), 2);
        assert_eq!(second.items.len(), 1);
        assert!(third.items.is_empty());

        // Consecutive pages must neither overlap nor skip any key.
        let ids: HashSet<String> = first
            .items
            .iter()
            .chain(second.items.iter())
            .map(|api_key| api_key.id.clone())
            .collect();
        assert_eq!(ids.len(), 3);
    }

    #[tokio::test]
    async fn test_has_entries() {
        let repo = InMemoryApiKeyRepository::new();
        assert!(!repo.has_entries().await.unwrap());

        repo.create(default_api_key("test-api-key")).await.unwrap();
        assert!(repo.has_entries().await.unwrap());
    }

    #[tokio::test]
    async fn test_delete_by_id_api_key() {
        let repo = InMemoryApiKeyRepository::new();
        repo.create(default_api_key("test-api-key")).await.unwrap();
        assert!(repo.has_entries().await.unwrap());

        repo.delete_by_id("test-api-key").await.unwrap();
        assert!(!repo.has_entries().await.unwrap());
        assert_eq!(repo.get_by_id("test-api-key").await.unwrap(), None);
    }

    #[tokio::test]
    async fn test_list_permissions_api_key() {
        let repo = InMemoryApiKeyRepository::new();
        repo.create(make_api_key(
            "test-api-key",
            &["relayer:all:execute", "relayer:all:read"],
        ))
        .await
        .unwrap();

        let permissions = repo.list_permissions("test-api-key").await.unwrap();
        assert_eq!(permissions, vec!["relayer:all:execute", "relayer:all:read"]);
    }

    #[tokio::test]
    async fn test_list_permissions_nonexistent_api_key() {
        let repo = InMemoryApiKeyRepository::new();

        let result = repo.list_permissions("missing-api-key").await;
        assert!(matches!(result, Err(RepositoryError::NotFound(_))));
    }

    #[tokio::test]
    async fn test_drop_all_entries() {
        let repo = InMemoryApiKeyRepository::new();
        repo.create(default_api_key("test-api-key1")).await.unwrap();
        repo.create(default_api_key("test-api-key2")).await.unwrap();
        assert!(repo.has_entries().await.unwrap());

        repo.drop_all_entries().await.unwrap();
        assert!(!repo.has_entries().await.unwrap());
        assert_eq!(repo.count().await.unwrap(), 0);
    }

    #[tokio::test]
    async fn test_clone_copies_entries_independently() {
        let repo = InMemoryApiKeyRepository::new();
        let api_key = default_api_key("test-api-key");
        repo.create(api_key.clone()).await.unwrap();

        let cloned = repo.clone();
        assert_eq!(
            cloned.get_by_id("test-api-key").await.unwrap(),
            Some(api_key)
        );

        // The clone owns its own map: clearing the original leaves it intact.
        repo.drop_all_entries().await.unwrap();
        assert!(cloned.has_entries().await.unwrap());
    }
}