// openzeppelin_relayer/observability/middleware.rs
use crate::observability::request_id::set_request_id;
2use actix_web::{
3 dev::{forward_ready, Service, ServiceRequest, ServiceResponse, Transform},
4 http::header::{HeaderName, HeaderValue},
5 Error, HttpMessage,
6};
7use futures::future::LocalBoxFuture;
8use std::future::{ready, Ready};
9use tracing_actix_web::RequestId as ActixRequestId;
10
/// Middleware factory that attaches an `x-request-id` to every request.
///
/// Register it with `App::wrap`. The type carries no state, so it is cheap to
/// copy and can be built with [`Default`].
#[derive(Debug, Clone, Copy, Default)]
pub struct RequestIdMiddleware;
13
14impl<S, B> Transform<S, ServiceRequest> for RequestIdMiddleware
15where
16 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
17 S::Future: 'static,
18 B: 'static,
19{
20 type Response = ServiceResponse<B>;
21 type Error = Error;
22 type InitError = ();
23 type Transform = RequestIdMiddlewareService<S>;
24 type Future = Ready<Result<Self::Transform, Self::InitError>>;
25
26 fn new_transform(&self, service: S) -> Self::Future {
27 ready(Ok(RequestIdMiddlewareService { service }))
28 }
29}
30
/// Service wrapper produced by `RequestIdMiddleware::new_transform`.
///
/// Holds the next service in the chain and delegates every request to it.
#[derive(Debug)]
pub struct RequestIdMiddlewareService<S> {
    /// The wrapped (inner) service that actually handles the request.
    service: S,
}
34
35impl<S, B> Service<ServiceRequest> for RequestIdMiddlewareService<S>
36where
37 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
38 S::Future: 'static,
39 B: 'static,
40{
41 type Response = ServiceResponse<B>;
42 type Error = Error;
43 type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
44
45 forward_ready!(service);
46
47 fn call(&self, req: ServiceRequest) -> Self::Future {
48 let rid = req
50 .headers()
51 .get("x-request-id")
52 .and_then(|v| v.to_str().ok())
53 .map(|s| s.to_string())
54 .or_else(|| {
55 req.extensions()
56 .get::<ActixRequestId>()
57 .map(|r| r.to_string())
58 })
59 .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
60
61 set_request_id(rid.clone());
62
63 let fut = self.service.call(req);
64
65 Box::pin(async move {
66 let mut res = fut.await?;
67 if let Ok(header_value) = HeaderValue::from_str(&rid) {
69 res.headers_mut()
70 .insert(HeaderName::from_static("x-request-id"), header_value);
71 }
72 Ok(res)
73 })
74 }
75}
76
#[cfg(test)]
mod tests {
    use super::*;
    use actix_web::{http::StatusCode, test, web, App, HttpResponse};
    use uuid::Uuid;

    /// Handler answering `200 OK` with the body `ok`.
    async fn ok_with_body() -> HttpResponse {
        HttpResponse::Ok().body("ok")
    }

    /// Handler answering `200 OK` with an empty body.
    async fn ok_empty() -> HttpResponse {
        HttpResponse::Ok().finish()
    }

    /// Handler answering `500 Internal Server Error` with an empty body.
    async fn internal_error() -> HttpResponse {
        HttpResponse::InternalServerError().finish()
    }

    /// Reads the `x-request-id` response header as text, if present and readable.
    fn request_id_of<B>(resp: &ServiceResponse<B>) -> Option<&str> {
        resp.headers()
            .get("x-request-id")
            .and_then(|value| value.to_str().ok())
    }

    /// A caller-supplied `x-request-id` must come back unchanged on the response.
    #[actix_rt::test]
    async fn echoes_incoming_x_request_id_header() {
        let route = web::get().to(ok_with_body);
        let app = test::init_service(App::new().wrap(RequestIdMiddleware).route("/", route)).await;

        let req = test::TestRequest::get()
            .uri("/")
            .insert_header(("x-request-id", "test-req-id-123"))
            .to_request();
        let resp = test::call_service(&app, req).await;

        assert!(resp.status().is_success());
        assert_eq!(request_id_of(&resp), Some("test-req-id-123"));
    }

    /// Without an incoming header the response must carry a non-nil UUID.
    #[actix_rt::test]
    async fn generates_uuid_when_header_absent() {
        let route = web::get().to(ok_empty);
        let app = test::init_service(App::new().wrap(RequestIdMiddleware).route("/", route)).await;

        let req = test::TestRequest::get().uri("/").to_request();
        let resp = test::call_service(&app, req).await;
        assert!(resp.status().is_success());

        let raw = resp
            .headers()
            .get("x-request-id")
            .expect("x-request-id header should be present");
        let text = raw.to_str().expect("header should be valid ascii");

        let parsed = Uuid::try_parse(text).expect("x-request-id should be a UUID");
        assert_ne!(parsed, Uuid::nil());
    }

    /// A handler that answers with a 500 response still gets the ID echoed.
    #[actix_rt::test]
    async fn preserves_header_on_internal_error() {
        let route = web::get().to(internal_error);
        let app = test::init_service(App::new().wrap(RequestIdMiddleware).route("/", route)).await;

        let req = test::TestRequest::get()
            .uri("/")
            .insert_header(("x-request-id", "err-req-id-999"))
            .to_request();
        let resp = test::call_service(&app, req).await;

        assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR);
        assert_eq!(request_id_of(&resp), Some("err-req-id-999"));
    }
}