openzeppelin_relayer/services/plugins/
runner.rs

//! This module orchestrates the execution of a plugin script.
//!
//! 1. Starts the socket listener used by the plugin script to reach the relayer - socket.rs
//! 2. Executes the plugin script, bounded by a timeout - script_executor.rs
//! 3. Sends the shutdown signal to the socket listener - socket.rs
//! 4. Waits for the socket listener to finish and collects its traces - socket.rs
//! 5. Returns the output of the script with the traces attached - script_executor.rs
//!
9use std::{sync::Arc, time::Duration};
10
11use crate::services::plugins::{RelayerApi, ScriptExecutor, ScriptResult, SocketService};
12use crate::{
13    jobs::JobProducerTrait,
14    models::{
15        NetworkRepoModel, NotificationRepoModel, RelayerRepoModel, SignerRepoModel,
16        ThinDataAppState, TransactionRepoModel,
17    },
18    repositories::{
19        ApiKeyRepositoryTrait, NetworkRepository, PluginRepositoryTrait, RelayerRepository,
20        Repository, TransactionCounterTrait, TransactionRepository,
21    },
22};
23
24use super::PluginError;
25use async_trait::async_trait;
26use tokio::{sync::oneshot, time::timeout};
27
28#[cfg(test)]
29use mockall::automock;
30
31#[cfg_attr(test, automock)]
32#[async_trait]
33pub trait PluginRunnerTrait {
34    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
35    async fn run<J, RR, TR, NR, NFR, SR, TCR, PR, AKR>(
36        &self,
37        plugin_id: String,
38        socket_path: &str,
39        script_path: String,
40        timeout_duration: Duration,
41        script_params: String,
42        http_request_id: Option<String>,
43        state: Arc<ThinDataAppState<J, RR, TR, NR, NFR, SR, TCR, PR, AKR>>,
44    ) -> Result<ScriptResult, PluginError>
45    where
46        J: JobProducerTrait + Send + Sync + 'static,
47        RR: RelayerRepository + Repository<RelayerRepoModel, String> + Send + Sync + 'static,
48        TR: TransactionRepository
49            + Repository<TransactionRepoModel, String>
50            + Send
51            + Sync
52            + 'static,
53        NR: NetworkRepository + Repository<NetworkRepoModel, String> + Send + Sync + 'static,
54        NFR: Repository<NotificationRepoModel, String> + Send + Sync + 'static,
55        SR: Repository<SignerRepoModel, String> + Send + Sync + 'static,
56        TCR: TransactionCounterTrait + Send + Sync + 'static,
57        PR: PluginRepositoryTrait + Send + Sync + 'static,
58        AKR: ApiKeyRepositoryTrait + Send + Sync + 'static;
59}
60
/// Stateless runner type; it carries no fields, so every instance is interchangeable.
///
/// The [`PluginRunnerTrait`] implementation for this type provides the actual
/// plugin execution flow.
#[derive(Default)]
pub struct PluginRunner;
63
64#[async_trait]
65impl PluginRunnerTrait for PluginRunner {
66    async fn run<J, RR, TR, NR, NFR, SR, TCR, PR, AKR>(
67        &self,
68        plugin_id: String,
69        socket_path: &str,
70        script_path: String,
71        timeout_duration: Duration,
72        script_params: String,
73        http_request_id: Option<String>,
74        state: Arc<ThinDataAppState<J, RR, TR, NR, NFR, SR, TCR, PR, AKR>>,
75    ) -> Result<ScriptResult, PluginError>
76    where
77        J: JobProducerTrait + Send + Sync + 'static,
78        RR: RelayerRepository + Repository<RelayerRepoModel, String> + Send + Sync + 'static,
79        TR: TransactionRepository
80            + Repository<TransactionRepoModel, String>
81            + Send
82            + Sync
83            + 'static,
84        NR: NetworkRepository + Repository<NetworkRepoModel, String> + Send + Sync + 'static,
85        NFR: Repository<NotificationRepoModel, String> + Send + Sync + 'static,
86        SR: Repository<SignerRepoModel, String> + Send + Sync + 'static,
87        TCR: TransactionCounterTrait + Send + Sync + 'static,
88        PR: PluginRepositoryTrait + Send + Sync + 'static,
89        AKR: ApiKeyRepositoryTrait + Send + Sync + 'static,
90    {
91        let socket_service = SocketService::new(socket_path)?;
92        let socket_path_clone = socket_service.socket_path().to_string();
93
94        let (shutdown_tx, shutdown_rx) = oneshot::channel();
95
96        let server_handle = tokio::spawn(async move {
97            let relayer_api = Arc::new(RelayerApi);
98            socket_service.listen(shutdown_rx, state, relayer_api).await
99        });
100
101        let exec_outcome = match timeout(
102            timeout_duration,
103            ScriptExecutor::execute_typescript(
104                plugin_id,
105                script_path,
106                socket_path_clone,
107                script_params,
108                http_request_id,
109            ),
110        )
111        .await
112        {
113            Ok(result) => result,
114            Err(_) => {
115                // ensures the socket gets closed.
116                let _ = shutdown_tx.send(());
117                return Err(PluginError::ScriptTimeout(timeout_duration.as_secs()));
118            }
119        };
120
121        let _ = shutdown_tx.send(());
122
123        let server_handle = server_handle
124            .await
125            .map_err(|e| PluginError::SocketError(e.to_string()))?;
126
127        let traces = match server_handle {
128            Ok(traces) => traces,
129            Err(e) => return Err(PluginError::SocketError(e.to_string())),
130        };
131
132        match exec_outcome {
133            Ok(mut script_result) => {
134                // attach traces on success
135                script_result.trace = traces;
136                Ok(script_result)
137            }
138            Err(err) => Err(err.with_traces(traces)),
139        }
140    }
141}
142
#[cfg(test)]
mod tests {
    use actix_web::web;
    use std::fs;

    use crate::{
        jobs::MockJobProducerTrait,
        repositories::{
            ApiKeyRepositoryStorage, NetworkRepositoryStorage, NotificationRepositoryStorage,
            PluginRepositoryStorage, RelayerRepositoryStorage, SignerRepositoryStorage,
            TransactionCounterRepositoryStorage, TransactionRepositoryStorage,
        },
        services::plugins::LogLevel,
        utils::mocks::mockutils::create_mock_app_state,
    };
    use tempfile::tempdir;

    use super::*;

    /// Contents of the `tsconfig.json` written next to every test script.
    static TS_CONFIG: &str = r#"
        {
            "compilerOptions": {
              "target": "es2016",
              "module": "commonjs",
              "esModuleInterop": true,
              "forceConsistentCasingInFileNames": true,
              "strict": true,
              "skipLibCheck": true
            }
          }
    "#;

    /// Returns `true` when the run failed with a socket error mentioning
    /// "Operation not permitted"; the tests treat that as an environment that
    /// forbids sockets and skip instead of failing.
    fn blocked_by_sandbox(outcome: &Result<ScriptResult, PluginError>) -> bool {
        matches!(
            outcome,
            Err(PluginError::SocketError(msg)) if msg.contains("Operation not permitted")
        )
    }

    /// A script that logs to stdout and stderr and returns a value should yield
    /// both log entries, in order, plus the return value.
    #[tokio::test]
    async fn test_run() {
        let workdir = tempdir().unwrap();
        let tsconfig_file = workdir.path().join("tsconfig.json");
        let script_file = workdir.path().join("test_run.ts");
        let socket_file = workdir.path().join("test_run.sock");

        let script_source = r#"
            export async function handler(api: any, params: any) {
                console.log('test');
                console.error('test-error');
                return 'test-result';
            }
        "#;
        fs::write(&script_file, script_source).unwrap();
        fs::write(&tsconfig_file, TS_CONFIG).unwrap();

        let state = create_mock_app_state(None, None, None, None, None, None).await;

        let runner = PluginRunner;
        let socket_arg = socket_file.display().to_string();
        let script_arg = script_file.display().to_string();
        let outcome = runner
            .run::<MockJobProducerTrait, RelayerRepositoryStorage, TransactionRepositoryStorage, NetworkRepositoryStorage, NotificationRepositoryStorage, SignerRepositoryStorage, TransactionCounterRepositoryStorage, PluginRepositoryStorage, ApiKeyRepositoryStorage>(
                "test-plugin".to_string(),
                &socket_arg,
                script_arg,
                Duration::from_secs(10),
                "{ \"test\": \"test\" }".to_string(),
                None,
                Arc::new(web::ThinData(state)),
            )
            .await;

        if blocked_by_sandbox(&outcome) {
            eprintln!("skipping test_run due to sandbox socket restrictions");
            return;
        }

        let script_result = outcome.expect("runner should complete without error");
        assert_eq!(script_result.logs[0].level, LogLevel::Log);
        assert_eq!(script_result.logs[0].message, "test");
        assert_eq!(script_result.logs[1].level, LogLevel::Error);
        assert_eq!(script_result.logs[1].message, "test-error");
        assert_eq!(script_result.return_value, "test-result");
    }

    /// A script that sleeps longer than the allowed duration must surface the
    /// timeout error.
    #[tokio::test]
    async fn test_run_timeout() {
        let workdir = tempdir().unwrap();
        let tsconfig_file = workdir.path().join("tsconfig.json");
        let script_file = workdir.path().join("test_simple_timeout.ts");
        let socket_file = workdir.path().join("test_simple_timeout.sock");

        // The script sleeps for 200ms before printing anything.
        let script_source = r#"
            function sleep(ms) {
                return new Promise(resolve => setTimeout(resolve, ms));
            }

            async function main() {
                await sleep(200); // 200ms
                console.log(JSON.stringify({ level: 'result', message: 'Should not reach here' }));
            }

            main();
        "#;

        fs::write(&script_file, script_source).unwrap();
        fs::write(&tsconfig_file, TS_CONFIG).unwrap();

        let state = create_mock_app_state(None, None, None, None, None, None).await;
        let runner = PluginRunner;

        // Allow only 100ms, half of what the script sleeps for.
        let socket_arg = socket_file.display().to_string();
        let script_arg = script_file.display().to_string();
        let outcome = runner
            .run::<MockJobProducerTrait, RelayerRepositoryStorage, TransactionRepositoryStorage, NetworkRepositoryStorage, NotificationRepositoryStorage, SignerRepositoryStorage, TransactionCounterRepositoryStorage, PluginRepositoryStorage, ApiKeyRepositoryStorage>(
                "test-plugin".to_string(),
                &socket_arg,
                script_arg,
                Duration::from_millis(100),
                "{}".to_string(),
                None,
                Arc::new(web::ThinData(state)),
            )
            .await;

        if blocked_by_sandbox(&outcome) {
            eprintln!("skipping test_run_timeout due to sandbox socket restrictions");
            return;
        }

        let timeout_err = outcome.expect_err("runner should timeout");
        assert!(timeout_err
            .to_string()
            .contains("Script execution timed out after"));
    }
}