openzeppelin_relayer/services/plugins/
socket.rs

//! This module is responsible for creating the socket that plugins use to talk to the relayer server.
//! It receives requests from the plugin, forwards them to the relayer API and sends back the responses.
//! It also intercepts the logs, errors and return values.
4//!
5//! The socket connection is created using the `UnixListener`.
6//!
7//! 1. Creates a socket connection using the `UnixListener`.
8//! 2. Each request payload is stringified by the client and added as a new line to the socket.
9//! 3. The server reads the requests from the socket and processes them.
//! 4. The server sends each response back to the client in the same format: one JSON line written to the socket.
11//! 5. When the client sends the socket shutdown signal, the server closes the socket connection.
12//!
13//! Example:
14//! 1. Create a new socket connection using `/tmp/socket.sock`
//! 2. The client sends a request (writes to `/tmp/socket.sock`):
16//! ```json
17//! {
18//!   "request_id": "123",
19//!   "relayer_id": "relayer1",
20//!   "method": "sendTransaction",
21//!   "payload": {
22//!     "to": "0x1234567890123456789012345678901234567890",
23//!     "value": "1000000000000000000"
24//!   }
25//! }
26//! ```
//! 3. The server processes the request, calls the relayer API and sends back the response (writes to `/tmp/socket.sock`):
28//! ```json
29//! {
30//!   "request_id": "123",
31//!   "result": {
32//!     "id": "123",
33//!     "status": "success"
34//!   }
35//! }
36//! ```
37//! 4. Client reads the response (reads from `/tmp/socket.sock`):
38//! ```json
39//! {
40//!   "request_id": "123",
41//!   "result": {
42//!     "id": "123",
43//!     "status": "success"
44//!   }
45//! }
46//! ```
47//! 5. Once the client finishes the execution, it sends a shutdown signal to the server.
48//! 6. The server closes the socket connection.
49//!
50
51use crate::jobs::JobProducerTrait;
52use crate::models::{
53    NetworkRepoModel, NotificationRepoModel, RelayerRepoModel, SignerRepoModel, ThinDataAppState,
54    TransactionRepoModel,
55};
56use crate::repositories::{
57    ApiKeyRepositoryTrait, NetworkRepository, PluginRepositoryTrait, RelayerRepository, Repository,
58    TransactionCounterTrait, TransactionRepository,
59};
60use std::sync::Arc;
61use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
62use tokio::net::{UnixListener, UnixStream};
63use tokio::sync::oneshot;
64use tracing::debug;
65
66use super::{
67    relayer_api::{RelayerApiTrait, Request},
68    PluginError,
69};
70
71pub struct SocketService {
72    socket_path: String,
73    listener: UnixListener,
74}
75
76impl SocketService {
77    /// Creates a new socket service.
78    ///
79    /// # Arguments
80    ///
81    /// * `socket_path` - The path to the socket file.
82    pub fn new(socket_path: &str) -> Result<Self, PluginError> {
83        // Remove existing socket file if it exists
84        let _ = std::fs::remove_file(socket_path);
85
86        let listener =
87            UnixListener::bind(socket_path).map_err(|e| PluginError::SocketError(e.to_string()))?;
88
89        Ok(Self {
90            socket_path: socket_path.to_string(),
91            listener,
92        })
93    }
94
95    pub fn socket_path(&self) -> &str {
96        &self.socket_path
97    }
98
99    /// Listens for incoming connections and processes the requests.
100    ///
101    /// # Arguments
102    ///
103    /// * `shutdown_rx` - A receiver for the shutdown signal.
104    /// * `state` - The application state.
105    /// * `relayer_api` - The relayer API.
106    ///
107    /// # Returns
108    ///
109    /// A vector of traces.
110    #[allow(clippy::type_complexity)]
111    pub async fn listen<RA, J, RR, TR, NR, NFR, SR, TCR, PR, AKR>(
112        self,
113        shutdown_rx: oneshot::Receiver<()>,
114        state: Arc<ThinDataAppState<J, RR, TR, NR, NFR, SR, TCR, PR, AKR>>,
115        relayer_api: Arc<RA>,
116    ) -> Result<Vec<serde_json::Value>, PluginError>
117    where
118        RA: RelayerApiTrait<J, RR, TR, NR, NFR, SR, TCR, PR, AKR> + 'static + Send + Sync,
119        J: JobProducerTrait + Send + Sync + 'static,
120        RR: RelayerRepository + Repository<RelayerRepoModel, String> + Send + Sync + 'static,
121        TR: TransactionRepository
122            + Repository<TransactionRepoModel, String>
123            + Send
124            + Sync
125            + 'static,
126        NR: NetworkRepository + Repository<NetworkRepoModel, String> + Send + Sync + 'static,
127        NFR: Repository<NotificationRepoModel, String> + Send + Sync + 'static,
128        SR: Repository<SignerRepoModel, String> + Send + Sync + 'static,
129        TCR: TransactionCounterTrait + Send + Sync + 'static,
130        PR: PluginRepositoryTrait + Send + Sync + 'static,
131        AKR: ApiKeyRepositoryTrait + Send + Sync + 'static,
132    {
133        let mut shutdown = shutdown_rx;
134
135        let mut traces = Vec::new();
136
137        loop {
138            let state = Arc::clone(&state);
139            let relayer_api = Arc::clone(&relayer_api);
140            tokio::select! {
141                Ok((stream, _)) = self.listener.accept() => {
142                    let result = tokio::spawn(Self::handle_connection::<RA, J, RR, TR, NR, NFR, SR, TCR, PR, AKR>(stream, state, relayer_api))
143                        .await
144                        .map_err(|e| PluginError::SocketError(e.to_string()))?;
145
146                    match result {
147                        Ok(trace) => traces.extend(trace),
148                        Err(e) => return Err(e),
149                    }
150                }
151                _ = &mut shutdown => {
152                    debug!("Shutdown signal received. Closing listener.");
153                    break;
154                }
155            }
156        }
157
158        Ok(traces)
159    }
160
161    /// Handles a new connection.
162    ///
163    /// # Arguments
164    ///
165    /// * `stream` - The stream to the client.
166    /// * `state` - The application state.
167    /// * `relayer_api` - The relayer API.
168    ///
169    /// # Returns
170    ///
171    /// A vector of traces.
172    #[allow(clippy::type_complexity)]
173    async fn handle_connection<RA, J, RR, TR, NR, NFR, SR, TCR, PR, AKR>(
174        stream: UnixStream,
175        state: Arc<ThinDataAppState<J, RR, TR, NR, NFR, SR, TCR, PR, AKR>>,
176        relayer_api: Arc<RA>,
177    ) -> Result<Vec<serde_json::Value>, PluginError>
178    where
179        RA: RelayerApiTrait<J, RR, TR, NR, NFR, SR, TCR, PR, AKR> + 'static + Send + Sync,
180        J: JobProducerTrait + 'static,
181        RR: RelayerRepository + Repository<RelayerRepoModel, String> + Send + Sync + 'static,
182        TR: TransactionRepository
183            + Repository<TransactionRepoModel, String>
184            + Send
185            + Sync
186            + 'static,
187        NR: NetworkRepository + Repository<NetworkRepoModel, String> + Send + Sync + 'static,
188        NFR: Repository<NotificationRepoModel, String> + Send + Sync + 'static,
189        SR: Repository<SignerRepoModel, String> + Send + Sync + 'static,
190        TCR: TransactionCounterTrait + Send + Sync + 'static,
191        PR: PluginRepositoryTrait + Send + Sync + 'static,
192        AKR: ApiKeyRepositoryTrait + Send + Sync + 'static,
193    {
194        let (r, mut w) = stream.into_split();
195        let mut reader = BufReader::new(r).lines();
196        let mut traces = Vec::new();
197
198        while let Ok(Some(line)) = reader.next_line().await {
199            let trace: serde_json::Value = serde_json::from_str(&line)
200                .map_err(|e| PluginError::PluginError(format!("Failed to parse trace: {e}")))?;
201            traces.push(trace);
202
203            let request: Request =
204                serde_json::from_str(&line).map_err(|e| PluginError::PluginError(e.to_string()))?;
205
206            let response = relayer_api.handle_request(request, &state).await;
207
208            let response_str = serde_json::to_string(&response)
209                .map_err(|e| PluginError::PluginError(e.to_string()))?
210                + "\n";
211
212            let _ = w.write_all(response_str.as_bytes()).await;
213        }
214
215        Ok(traces)
216    }
217}
218
#[cfg(test)]
mod tests {
    use crate::{
        services::plugins::{MockRelayerApiTrait, PluginMethod, Response},
        utils::mocks::mockutils::{create_mock_app_state, create_mock_evm_transaction_request},
    };
    use actix_web::web;
    use std::time::Duration;

    use super::*;

    use tempfile::tempdir;
    use tokio::{
        io::{AsyncBufReadExt, BufReader},
        time::timeout,
    };

    /// Upper bound for awaiting the listener task, so a regression fails the test instead of
    /// hanging it.
    const LISTEN_TIMEOUT: Duration = Duration::from_secs(1);

    /// Shutting down with no connections must end `listen` with `Ok` and no traces.
    #[tokio::test]
    async fn test_socket_service_listen_and_shutdown() {
        let temp_dir = tempdir().unwrap();
        let socket_path = temp_dir.path().join("test.sock");

        let mock_relayer = MockRelayerApiTrait::default();

        let service = SocketService::new(socket_path.to_str().unwrap()).unwrap();

        let state = create_mock_app_state(None, None, None, None, None, None).await;
        let (shutdown_tx, shutdown_rx) = oneshot::channel();

        let listen_handle = tokio::spawn(async move {
            service
                .listen(
                    shutdown_rx,
                    Arc::new(web::ThinData(state)),
                    Arc::new(mock_relayer),
                )
                .await
        });

        shutdown_tx.send(()).unwrap();

        let joined = timeout(LISTEN_TIMEOUT, listen_handle)
            .await
            .expect("Listen handle timed out");
        // Check both layers: the task must not panic AND `listen` itself must return `Ok`.
        let traces = joined
            .expect("Listen task panicked")
            .expect("Listen handle returned error");
        assert!(traces.is_empty(), "No connection was made, so no traces are expected");
    }

    /// A single request is answered with the mocked response and recorded as a trace.
    #[tokio::test]
    async fn test_socket_service_handle_connection() {
        let temp_dir = tempdir().unwrap();
        let socket_path = temp_dir.path().join("test.sock");

        let mut mock_relayer = MockRelayerApiTrait::default();

        mock_relayer.expect_handle_request().returning(|_, _| {
            Box::pin(async move {
                Response {
                    request_id: "test".to_string(),
                    result: Some(serde_json::json!("test")),
                    error: None,
                }
            })
        });

        // The socket is bound here, synchronously, so the client below can connect right away
        // without waiting for the listener task to start.
        let service = SocketService::new(socket_path.to_str().unwrap()).unwrap();

        let state = create_mock_app_state(None, None, None, None, None, None).await;
        let (shutdown_tx, shutdown_rx) = oneshot::channel();

        let listen_handle = tokio::spawn(async move {
            service
                .listen(
                    shutdown_rx,
                    Arc::new(web::ThinData(state)),
                    Arc::new(mock_relayer),
                )
                .await
        });

        let mut client = UnixStream::connect(socket_path.to_str().unwrap())
            .await
            .unwrap();

        let request = Request {
            request_id: "test".to_string(),
            relayer_id: "test".to_string(),
            method: PluginMethod::SendTransaction,
            payload: serde_json::json!(create_mock_evm_transaction_request()),
            http_request_id: None,
        };

        let request_json = serde_json::to_string(&request).unwrap() + "\n";

        client.write_all(request_json.as_bytes()).await.unwrap();

        let mut reader = BufReader::new(&mut client);
        let mut response_str = String::new();
        let read_result = timeout(
            Duration::from_millis(1000),
            reader.read_line(&mut response_str),
        )
        .await;

        assert!(
            read_result.is_ok(),
            "Reading response timed out: {:?}",
            read_result
        );
        let bytes_read = read_result.unwrap().unwrap();
        assert!(bytes_read > 0, "No data received");
        shutdown_tx.send(()).unwrap();

        let response: Response = serde_json::from_str(&response_str).unwrap();

        assert!(response.error.is_none(), "Error should be none");
        assert!(response.result.is_some(), "Result should be some");
        assert_eq!(
            response.request_id, request.request_id,
            "Request id mismatch"
        );

        // Closing the write side ends the connection handler; the pending shutdown signal then
        // stops the listener.
        client.shutdown().await.unwrap();

        let traces = timeout(LISTEN_TIMEOUT, listen_handle)
            .await
            .expect("Listen handle timed out")
            .expect("Listen task panicked")
            .expect("Listen handle returned error");

        assert_eq!(traces.len(), 1);
        let expected: serde_json::Value = serde_json::from_str(&request_json).unwrap();
        assert_eq!(expected, traces[0], "Request json mismatch with trace");
    }
}