1use crate::jobs::JobProducerTrait;
52use crate::models::{
53 NetworkRepoModel, NotificationRepoModel, RelayerRepoModel, SignerRepoModel, ThinDataAppState,
54 TransactionRepoModel,
55};
56use crate::repositories::{
57 ApiKeyRepositoryTrait, NetworkRepository, PluginRepositoryTrait, RelayerRepository, Repository,
58 TransactionCounterTrait, TransactionRepository,
59};
60use std::sync::Arc;
61use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
62use tokio::net::{UnixListener, UnixStream};
63use tokio::sync::oneshot;
64use tracing::debug;
65
66use super::{
67 relayer_api::{RelayerApiTrait, Request},
68 PluginError,
69};
70
71pub struct SocketService {
72 socket_path: String,
73 listener: UnixListener,
74}
75
76impl SocketService {
77 pub fn new(socket_path: &str) -> Result<Self, PluginError> {
83 let _ = std::fs::remove_file(socket_path);
85
86 let listener =
87 UnixListener::bind(socket_path).map_err(|e| PluginError::SocketError(e.to_string()))?;
88
89 Ok(Self {
90 socket_path: socket_path.to_string(),
91 listener,
92 })
93 }
94
95 pub fn socket_path(&self) -> &str {
96 &self.socket_path
97 }
98
99 #[allow(clippy::type_complexity)]
111 pub async fn listen<RA, J, RR, TR, NR, NFR, SR, TCR, PR, AKR>(
112 self,
113 shutdown_rx: oneshot::Receiver<()>,
114 state: Arc<ThinDataAppState<J, RR, TR, NR, NFR, SR, TCR, PR, AKR>>,
115 relayer_api: Arc<RA>,
116 ) -> Result<Vec<serde_json::Value>, PluginError>
117 where
118 RA: RelayerApiTrait<J, RR, TR, NR, NFR, SR, TCR, PR, AKR> + 'static + Send + Sync,
119 J: JobProducerTrait + Send + Sync + 'static,
120 RR: RelayerRepository + Repository<RelayerRepoModel, String> + Send + Sync + 'static,
121 TR: TransactionRepository
122 + Repository<TransactionRepoModel, String>
123 + Send
124 + Sync
125 + 'static,
126 NR: NetworkRepository + Repository<NetworkRepoModel, String> + Send + Sync + 'static,
127 NFR: Repository<NotificationRepoModel, String> + Send + Sync + 'static,
128 SR: Repository<SignerRepoModel, String> + Send + Sync + 'static,
129 TCR: TransactionCounterTrait + Send + Sync + 'static,
130 PR: PluginRepositoryTrait + Send + Sync + 'static,
131 AKR: ApiKeyRepositoryTrait + Send + Sync + 'static,
132 {
133 let mut shutdown = shutdown_rx;
134
135 let mut traces = Vec::new();
136
137 loop {
138 let state = Arc::clone(&state);
139 let relayer_api = Arc::clone(&relayer_api);
140 tokio::select! {
141 Ok((stream, _)) = self.listener.accept() => {
142 let result = tokio::spawn(Self::handle_connection::<RA, J, RR, TR, NR, NFR, SR, TCR, PR, AKR>(stream, state, relayer_api))
143 .await
144 .map_err(|e| PluginError::SocketError(e.to_string()))?;
145
146 match result {
147 Ok(trace) => traces.extend(trace),
148 Err(e) => return Err(e),
149 }
150 }
151 _ = &mut shutdown => {
152 debug!("Shutdown signal received. Closing listener.");
153 break;
154 }
155 }
156 }
157
158 Ok(traces)
159 }
160
161 #[allow(clippy::type_complexity)]
173 async fn handle_connection<RA, J, RR, TR, NR, NFR, SR, TCR, PR, AKR>(
174 stream: UnixStream,
175 state: Arc<ThinDataAppState<J, RR, TR, NR, NFR, SR, TCR, PR, AKR>>,
176 relayer_api: Arc<RA>,
177 ) -> Result<Vec<serde_json::Value>, PluginError>
178 where
179 RA: RelayerApiTrait<J, RR, TR, NR, NFR, SR, TCR, PR, AKR> + 'static + Send + Sync,
180 J: JobProducerTrait + 'static,
181 RR: RelayerRepository + Repository<RelayerRepoModel, String> + Send + Sync + 'static,
182 TR: TransactionRepository
183 + Repository<TransactionRepoModel, String>
184 + Send
185 + Sync
186 + 'static,
187 NR: NetworkRepository + Repository<NetworkRepoModel, String> + Send + Sync + 'static,
188 NFR: Repository<NotificationRepoModel, String> + Send + Sync + 'static,
189 SR: Repository<SignerRepoModel, String> + Send + Sync + 'static,
190 TCR: TransactionCounterTrait + Send + Sync + 'static,
191 PR: PluginRepositoryTrait + Send + Sync + 'static,
192 AKR: ApiKeyRepositoryTrait + Send + Sync + 'static,
193 {
194 let (r, mut w) = stream.into_split();
195 let mut reader = BufReader::new(r).lines();
196 let mut traces = Vec::new();
197
198 while let Ok(Some(line)) = reader.next_line().await {
199 let trace: serde_json::Value = serde_json::from_str(&line)
200 .map_err(|e| PluginError::PluginError(format!("Failed to parse trace: {e}")))?;
201 traces.push(trace);
202
203 let request: Request =
204 serde_json::from_str(&line).map_err(|e| PluginError::PluginError(e.to_string()))?;
205
206 let response = relayer_api.handle_request(request, &state).await;
207
208 let response_str = serde_json::to_string(&response)
209 .map_err(|e| PluginError::PluginError(e.to_string()))?
210 + "\n";
211
212 let _ = w.write_all(response_str.as_bytes()).await;
213 }
214
215 Ok(traces)
216 }
217}
218
#[cfg(test)]
mod tests {
    use crate::{
        services::plugins::{MockRelayerApiTrait, PluginMethod, Response},
        utils::mocks::mockutils::{create_mock_app_state, create_mock_evm_transaction_request},
    };
    use actix_web::web;
    use std::time::Duration;

    use super::*;

    use tempfile::tempdir;
    use tokio::{
        io::{AsyncBufReadExt, BufReader},
        time::timeout,
    };

    /// Sending the shutdown signal with no client connected makes `listen`
    /// finish promptly (within 100 ms).
    #[tokio::test]
    async fn test_socket_service_listen_and_shutdown() {
        // Socket lives in a temp dir so parallel tests do not collide on the path.
        let temp_dir = tempdir().unwrap();
        let socket_path = temp_dir.path().join("test.sock");

        // No expectations set: no request is ever sent in this test.
        let mock_relayer = MockRelayerApiTrait::default();

        let service = SocketService::new(socket_path.to_str().unwrap()).unwrap();

        let state = create_mock_app_state(None, None, None, None, None, None).await;
        let (shutdown_tx, shutdown_rx) = oneshot::channel();

        let listen_handle = tokio::spawn(async move {
            service
                .listen(
                    shutdown_rx,
                    Arc::new(web::ThinData(state)),
                    Arc::new(mock_relayer),
                )
                .await
        });

        shutdown_tx.send(()).unwrap();

        let result = timeout(Duration::from_millis(100), listen_handle).await;
        assert!(result.is_ok(), "Listen handle timed out");
        // NOTE(review): this checks the JoinHandle result (the task did not panic),
        // not the `Result` returned by `listen` itself.
        assert!(result.unwrap().is_ok(), "Listen handle returned error");
    }

    /// End-to-end round trip: a client writes one JSON request line, receives the
    /// mocked response, and the raw request is returned by `listen` as a trace.
    #[tokio::test]
    async fn test_socket_service_handle_connection() {
        let temp_dir = tempdir().unwrap();
        let socket_path = temp_dir.path().join("test.sock");

        let mut mock_relayer = MockRelayerApiTrait::default();

        // Every request gets the same canned successful response.
        mock_relayer.expect_handle_request().returning(|_, _| {
            Box::pin(async move {
                Response {
                    request_id: "test".to_string(),
                    result: Some(serde_json::json!("test")),
                    error: None,
                }
            })
        });

        let service = SocketService::new(socket_path.to_str().unwrap()).unwrap();

        let state = create_mock_app_state(None, None, None, None, None, None).await;
        let (shutdown_tx, shutdown_rx) = oneshot::channel();

        let listen_handle = tokio::spawn(async move {
            service
                .listen(
                    shutdown_rx,
                    Arc::new(web::ThinData(state)),
                    Arc::new(mock_relayer),
                )
                .await
        });

        // Give the spawned listener task a chance to start before connecting.
        tokio::time::sleep(Duration::from_millis(50)).await;

        let mut client = UnixStream::connect(socket_path.to_str().unwrap())
            .await
            .unwrap();

        let request = Request {
            request_id: "test".to_string(),
            relayer_id: "test".to_string(),
            method: PluginMethod::SendTransaction,
            payload: serde_json::json!(create_mock_evm_transaction_request()),
            http_request_id: None,
        };

        // The server reads newline-delimited messages, so terminate with '\n'.
        let request_json = serde_json::to_string(&request).unwrap() + "\n";

        client.write_all(request_json.as_bytes()).await.unwrap();

        let mut reader = BufReader::new(&mut client);
        let mut response_str = String::new();
        let read_result = timeout(
            Duration::from_millis(1000),
            reader.read_line(&mut response_str),
        )
        .await;

        assert!(
            read_result.is_ok(),
            "Reading response timed out: {:?}",
            read_result
        );
        let bytes_read = read_result.unwrap().unwrap();
        assert!(bytes_read > 0, "No data received");
        // Shutdown is requested while the connection is still open. `listen` awaits
        // the current connection before checking shutdown again, so this
        // connection's trace is still collected.
        shutdown_tx.send(()).unwrap();

        let response: Response = serde_json::from_str(&response_str).unwrap();

        assert!(response.error.is_none(), "Error should be none");
        assert!(response.result.is_some(), "Result should be some");
        assert_eq!(
            response.request_id, request.request_id,
            "Request id mismatch"
        );

        // Closing the client's write side ends the server's read loop (EOF), which
        // lets `listen` return and pick up the pending shutdown signal.
        client.shutdown().await.unwrap();

        let traces = listen_handle.await.unwrap().unwrap();

        // Exactly one line was sent, and its trace must equal the request JSON.
        assert_eq!(traces.len(), 1);
        let expected: serde_json::Value = serde_json::from_str(&request_json).unwrap();
        let actual: serde_json::Value =
            serde_json::from_str(&serde_json::to_string(&traces[0]).unwrap()).unwrap();
        assert_eq!(expected, actual, "Request json mismatch with trace");
    }
}