// openzeppelin_relayer/observability/middleware.rs

1use crate::observability::request_id::set_request_id;
2use actix_web::{
3    dev::{forward_ready, Service, ServiceRequest, ServiceResponse, Transform},
4    http::header::{HeaderName, HeaderValue},
5    Error, HttpMessage,
6};
7use futures::future::LocalBoxFuture;
8use std::future::{ready, Ready};
9use tracing_actix_web::RequestId as ActixRequestId;

/// Actix-web middleware that tags every HTTP request with a request ID.
///
/// For each request the ID is taken from the incoming `x-request-id` header
/// when present and usable. Otherwise it falls back to a `tracing_actix_web`
/// request ID found in the request extensions, and finally to a freshly
/// generated UUID. The chosen ID is recorded with `set_request_id` and, when
/// the inner service yields a response (rather than an `Err`), attached to the
/// response's `x-request-id` header.
///
/// Register it with `App::new().wrap(RequestIdMiddleware)`.
#[derive(Debug, Clone, Copy, Default)]
pub struct RequestIdMiddleware;
13
14impl<S, B> Transform<S, ServiceRequest> for RequestIdMiddleware
15where
16    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
17    S::Future: 'static,
18    B: 'static,
19{
20    type Response = ServiceResponse<B>;
21    type Error = Error;
22    type InitError = ();
23    type Transform = RequestIdMiddlewareService<S>;
24    type Future = Ready<Result<Self::Transform, Self::InitError>>;
25
26    fn new_transform(&self, service: S) -> Self::Future {
27        ready(Ok(RequestIdMiddlewareService { service }))
28    }
29}

/// Service produced by the request-ID middleware's `new_transform`.
///
/// It holds the next service in the chain and forwards each request to it.
pub struct RequestIdMiddlewareService<S> {
    /// The wrapped (inner) service that actually handles the request.
    service: S,
}
34
35impl<S, B> Service<ServiceRequest> for RequestIdMiddlewareService<S>
36where
37    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
38    S::Future: 'static,
39    B: 'static,
40{
41    type Response = ServiceResponse<B>;
42    type Error = Error;
43    type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
44
45    forward_ready!(service);
46
47    fn call(&self, req: ServiceRequest) -> Self::Future {
48        // Priority order: incoming header -> ActixRequestId -> new UUID
49        let rid = req
50            .headers()
51            .get("x-request-id")
52            .and_then(|v| v.to_str().ok())
53            .map(|s| s.to_string())
54            .or_else(|| {
55                req.extensions()
56                    .get::<ActixRequestId>()
57                    .map(|r| r.to_string())
58            })
59            .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
60
61        set_request_id(rid.clone());
62
63        let fut = self.service.call(req);
64
65        Box::pin(async move {
66            let mut res = fut.await?;
67            // Use safe conversion to avoid panic on invalid header values
68            if let Ok(header_value) = HeaderValue::from_str(&rid) {
69                res.headers_mut()
70                    .insert(HeaderName::from_static("x-request-id"), header_value);
71            }
72            Ok(res)
73        })
74    }
75}

#[cfg(test)]
mod tests {
    use super::*;
    use actix_web::{http::StatusCode, test, web, App, HttpResponse};
    use uuid::Uuid;

    /// Reads the `x-request-id` header off a response. Yields `None` when the
    /// header is absent or not valid visible ASCII.
    fn echoed_request_id<B>(resp: &actix_web::dev::ServiceResponse<B>) -> Option<&str> {
        resp.headers()
            .get("x-request-id")
            .and_then(|value| value.to_str().ok())
    }

    #[actix_rt::test]
    async fn echoes_incoming_x_request_id_header() {
        // A client-supplied ID must come back untouched.
        let svc = test::init_service(
            App::new()
                .wrap(RequestIdMiddleware)
                .route("/", web::get().to(|| async { HttpResponse::Ok().body("ok") })),
        )
        .await;

        let request = test::TestRequest::get()
            .uri("/")
            .insert_header(("x-request-id", "test-req-id-123"))
            .to_request();
        let response = test::call_service(&svc, request).await;

        assert!(response.status().is_success());
        assert_eq!(echoed_request_id(&response), Some("test-req-id-123"));
    }

    #[actix_rt::test]
    async fn generates_uuid_when_header_absent() {
        // With no incoming header, the middleware must mint an ID itself.
        let svc = test::init_service(
            App::new()
                .wrap(RequestIdMiddleware)
                .route("/", web::get().to(|| async { HttpResponse::Ok().finish() })),
        )
        .await;

        let request = test::TestRequest::get().uri("/").to_request();
        let response = test::call_service(&svc, request).await;
        assert!(response.status().is_success());

        let header = response
            .headers()
            .get("x-request-id")
            .expect("x-request-id header should be present");
        let text = header.to_str().expect("header should be valid ascii");

        // Only the UUID shape is checked; the version is deliberately not asserted.
        let parsed = Uuid::try_parse(text).expect("x-request-id should be a UUID");
        assert_ne!(parsed, Uuid::nil());
    }

    #[actix_rt::test]
    async fn preserves_header_on_internal_error() {
        // A handler-produced 500 response must still carry the request ID.
        let svc = test::init_service(App::new().wrap(RequestIdMiddleware).route(
            "/",
            web::get().to(|| async { HttpResponse::InternalServerError().finish() }),
        ))
        .await;

        let request = test::TestRequest::get()
            .uri("/")
            .insert_header(("x-request-id", "err-req-id-999"))
            .to_request();
        let response = test::call_service(&svc, request).await;

        assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);
        assert_eq!(echoed_request_id(&response), Some("err-req-id-999"));
    }
}