openzeppelin_relayer/repositories/plugin/
plugin_in_memory.rs

//! This module provides an in-memory implementation of the plugin repository.
//!
//! The `InMemoryPluginRepository` struct stores and retrieves plugin models,
//! including their script paths, for later execution.
5use crate::{
6    models::{PaginationQuery, PluginModel},
7    repositories::{PaginatedResult, PluginRepositoryTrait, RepositoryError},
8};
9
10use async_trait::async_trait;
11
12use std::collections::HashMap;
13use tokio::sync::{Mutex, MutexGuard};
14
15#[derive(Debug)]
16pub struct InMemoryPluginRepository {
17    store: Mutex<HashMap<String, PluginModel>>,
18}
19
20impl Clone for InMemoryPluginRepository {
21    fn clone(&self) -> Self {
22        // Try to get the current data, or use empty HashMap if lock fails
23        let data = self
24            .store
25            .try_lock()
26            .map(|guard| guard.clone())
27            .unwrap_or_else(|_| HashMap::new());
28
29        Self {
30            store: Mutex::new(data),
31        }
32    }
33}
34
35impl InMemoryPluginRepository {
36    pub fn new() -> Self {
37        Self {
38            store: Mutex::new(HashMap::new()),
39        }
40    }
41
42    pub async fn get_by_id(&self, id: &str) -> Result<Option<PluginModel>, RepositoryError> {
43        let store = Self::acquire_lock(&self.store).await?;
44        Ok(store.get(id).cloned())
45    }
46
47    async fn acquire_lock<T>(lock: &Mutex<T>) -> Result<MutexGuard<T>, RepositoryError> {
48        Ok(lock.lock().await)
49    }
50}
51
52impl Default for InMemoryPluginRepository {
53    fn default() -> Self {
54        Self::new()
55    }
56}
57
58#[async_trait]
59impl PluginRepositoryTrait for InMemoryPluginRepository {
60    async fn get_by_id(&self, id: &str) -> Result<Option<PluginModel>, RepositoryError> {
61        let store = Self::acquire_lock(&self.store).await?;
62        Ok(store.get(id).cloned())
63    }
64
65    async fn add(&self, plugin: PluginModel) -> Result<(), RepositoryError> {
66        let mut store = Self::acquire_lock(&self.store).await?;
67        store.insert(plugin.id.clone(), plugin);
68        Ok(())
69    }
70
71    async fn list_paginated(
72        &self,
73        query: PaginationQuery,
74    ) -> Result<PaginatedResult<PluginModel>, RepositoryError> {
75        let total = self.count().await?;
76        let start = ((query.page - 1) * query.per_page) as usize;
77
78        let items = self
79            .store
80            .lock()
81            .await
82            .values()
83            .skip(start)
84            .take(query.per_page as usize)
85            .cloned()
86            .collect();
87
88        Ok(PaginatedResult {
89            items,
90            total: total as u64,
91            page: query.page,
92            per_page: query.per_page,
93        })
94    }
95
96    async fn count(&self) -> Result<usize, RepositoryError> {
97        let store = self.store.lock().await;
98        Ok(store.len())
99    }
100
101    async fn has_entries(&self) -> Result<bool, RepositoryError> {
102        let store = Self::acquire_lock(&self.store).await?;
103        Ok(!store.is_empty())
104    }
105
106    async fn drop_all_entries(&self) -> Result<(), RepositoryError> {
107        let mut store = Self::acquire_lock(&self.store).await?;
108        store.clear();
109        Ok(())
110    }
111}
112
#[cfg(test)]
mod tests {
    use crate::{config::PluginFileConfig, constants::DEFAULT_PLUGIN_TIMEOUT_SECONDS};

    use super::*;
    use std::{sync::Arc, time::Duration};

    /// Builds a plugin with the default timeout and logs/traces disabled.
    fn make_plugin(id: &str, path: &str) -> PluginModel {
        PluginModel {
            id: id.to_string(),
            path: path.to_string(),
            timeout: Duration::from_secs(DEFAULT_PLUGIN_TIMEOUT_SECONDS),
            emit_logs: false,
            emit_traces: false,
        }
    }

    #[tokio::test]
    async fn test_in_memory_plugin_repository() {
        let plugin_repository = Arc::new(InMemoryPluginRepository::new());

        // Test add and get_by_id
        let plugin = make_plugin("test-plugin", "test-path");
        plugin_repository.add(plugin.clone()).await.unwrap();
        assert_eq!(
            plugin_repository.get_by_id("test-plugin").await.unwrap(),
            Some(plugin)
        );
    }

    #[tokio::test]
    async fn test_get_nonexistent_plugin() {
        let plugin_repository = Arc::new(InMemoryPluginRepository::new());

        let result = plugin_repository.get_by_id("test-plugin").await;
        assert!(matches!(result, Ok(None)));
    }

    #[tokio::test]
    async fn test_try_from() {
        let plugin = PluginFileConfig {
            id: "test-plugin".to_string(),
            path: "test-path".to_string(),
            timeout: None,
            emit_logs: false,
            emit_traces: false,
        };
        let result = PluginModel::try_from(plugin);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), make_plugin("test-plugin", "test-path"));
    }

    #[tokio::test]
    async fn test_get_by_id() {
        let plugin_repository = Arc::new(InMemoryPluginRepository::new());

        let plugin = make_plugin("test-plugin", "test-path");
        plugin_repository.add(plugin.clone()).await.unwrap();
        assert_eq!(
            plugin_repository.get_by_id("test-plugin").await.unwrap(),
            Some(plugin.clone())
        );

        // The inherent `get_by_id` shadows the trait method on method-call
        // syntax; call the trait implementation explicitly so it is covered too.
        assert_eq!(
            <InMemoryPluginRepository as PluginRepositoryTrait>::get_by_id(
                &plugin_repository,
                "test-plugin"
            )
            .await
            .unwrap(),
            Some(plugin)
        );
    }

    #[tokio::test]
    async fn test_add_replaces_existing_id() {
        let plugin_repository = Arc::new(InMemoryPluginRepository::new());

        plugin_repository
            .add(make_plugin("test-plugin", "old-path"))
            .await
            .unwrap();
        plugin_repository
            .add(make_plugin("test-plugin", "new-path"))
            .await
            .unwrap();

        assert_eq!(plugin_repository.count().await.unwrap(), 1);
        assert_eq!(
            plugin_repository.get_by_id("test-plugin").await.unwrap(),
            Some(make_plugin("test-plugin", "new-path"))
        );
    }

    #[tokio::test]
    async fn test_list_paginated() {
        let plugin_repository = Arc::new(InMemoryPluginRepository::new());

        plugin_repository
            .add(make_plugin("test-plugin1", "test-path1"))
            .await
            .unwrap();
        plugin_repository
            .add(make_plugin("test-plugin2", "test-path2"))
            .await
            .unwrap();

        let query = PaginationQuery {
            page: 1,
            per_page: 2,
        };

        let result = plugin_repository.list_paginated(query).await;
        assert!(result.is_ok());
        let result = result.unwrap();
        assert_eq!(result.items.len(), 2);
        assert_eq!(result.total, 2);
        assert_eq!(result.page, 1);
        assert_eq!(result.per_page, 2);
    }

    #[tokio::test]
    async fn test_list_paginated_pages_cover_all_entries() {
        let plugin_repository = Arc::new(InMemoryPluginRepository::new());

        plugin_repository
            .add(make_plugin("test-plugin1", "test-path1"))
            .await
            .unwrap();
        plugin_repository
            .add(make_plugin("test-plugin2", "test-path2"))
            .await
            .unwrap();

        // Walking the pages one item at a time must visit every plugin exactly once.
        let mut seen = Vec::new();
        for page in 1..=2 {
            let result = plugin_repository
                .list_paginated(PaginationQuery { page, per_page: 1 })
                .await
                .unwrap();
            assert_eq!(result.items.len(), 1);
            assert_eq!(result.total, 2);
            seen.extend(result.items.into_iter().map(|plugin| plugin.id));
        }
        seen.sort();
        assert_eq!(
            seen,
            vec!["test-plugin1".to_string(), "test-plugin2".to_string()]
        );

        // A page past the end is empty but still reports the total.
        let past_end = plugin_repository
            .list_paginated(PaginationQuery {
                page: 3,
                per_page: 1,
            })
            .await
            .unwrap();
        assert!(past_end.items.is_empty());
        assert_eq!(past_end.total, 2);
    }

    #[tokio::test]
    async fn test_clone_snapshots_entries() {
        let original = InMemoryPluginRepository::new();
        original
            .add(make_plugin("test-plugin1", "test-path1"))
            .await
            .unwrap();

        let cloned = original.clone();
        assert_eq!(
            cloned.get_by_id("test-plugin1").await.unwrap(),
            Some(make_plugin("test-plugin1", "test-path1"))
        );

        // The clone owns its own map: later writes to the original are not shared.
        original
            .add(make_plugin("test-plugin2", "test-path2"))
            .await
            .unwrap();
        assert_eq!(original.count().await.unwrap(), 2);
        assert_eq!(cloned.count().await.unwrap(), 1);
    }

    #[tokio::test]
    async fn test_has_entries() {
        let plugin_repository = Arc::new(InMemoryPluginRepository::new());
        assert!(!plugin_repository.has_entries().await.unwrap());
        plugin_repository
            .add(make_plugin("test-plugin", "test-path"))
            .await
            .unwrap();

        assert!(plugin_repository.has_entries().await.unwrap());
        plugin_repository.drop_all_entries().await.unwrap();
        assert!(!plugin_repository.has_entries().await.unwrap());
    }

    #[tokio::test]
    async fn test_drop_all_entries() {
        let plugin_repository = Arc::new(InMemoryPluginRepository::new());
        plugin_repository
            .add(make_plugin("test-plugin", "test-path"))
            .await
            .unwrap();

        assert!(plugin_repository.has_entries().await.unwrap());
        plugin_repository.drop_all_entries().await.unwrap();
        assert!(!plugin_repository.has_entries().await.unwrap());
        assert_eq!(plugin_repository.count().await.unwrap(), 0);
    }
}