openzeppelin_relayer/services/vault/
mod.rs

1//! # Vault Service Module
2//!
3//! This module provides integration with HashiCorp Vault for secure secret management
4//! and cryptographic operations.
5//!
6//! ## Features
7//!
8//! - Token-based authentication using AppRole method
//! - Automatic token caching, with a fresh AppRole login once a cached token expires
10//! - Secret retrieval from KV2 secrets engine
11//! - Message signing via Vault's Transit engine
12//! - Namespace support for Vault Enterprise
13//!
14//! ## Architecture
15//!
16//! ```text
17//! VaultService (implements VaultServiceTrait)
18//!   ├── Authentication (AppRole)
19//!   ├── Token Caching
20//!   ├── KV2 Secret Operations
21//!   └── Transit Signing Operations
22//! ```
23use async_trait::async_trait;
24use core::fmt;
25use once_cell::sync::Lazy;
26use serde::Serialize;
27use std::collections::HashMap;
28use std::hash::Hash;
29use std::sync::Arc;
30use std::time::{Duration, Instant};
31use thiserror::Error;
32use tokio::sync::RwLock;
33use tracing::debug;
34use vaultrs::{
35    auth::approle::login,
36    client::{VaultClient, VaultClientSettingsBuilder},
37    kv2, transit,
38};
39use zeroize::{Zeroize, ZeroizeOnDrop};
40
41#[derive(Error, Debug, Serialize)]
42pub enum VaultError {
43    #[error("Vault client error: {0}")]
44    ClientError(String),
45
46    #[error("Secret not found: {0}")]
47    SecretNotFound(String),
48
49    #[error("Authentication failed: {0}")]
50    AuthenticationFailed(String),
51
52    #[error("Configuration error: {0}")]
53    ConfigError(String),
54
55    #[error("Signing error: {0}")]
56    SigningError(String),
57}
58
59// Token cache key to uniquely identify a vault configuration
60#[derive(Clone, Debug, PartialEq, Eq, Hash, Zeroize, ZeroizeOnDrop)]
61struct VaultCacheKey {
62    address: String,
63    role_id: String,
64    namespace: Option<String>,
65}
66
67impl fmt::Display for VaultCacheKey {
68    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
69        write!(
70            f,
71            "{}|{}|{}",
72            self.address,
73            self.role_id,
74            self.namespace.as_deref().unwrap_or("")
75        )
76    }
77}
78
79struct TokenCache {
80    client: Arc<VaultClient>,
81    expiry: Instant,
82}
83
84// Global token cache - HashMap keyed by VaultCacheKey
85static TOKEN_CACHE: Lazy<RwLock<HashMap<VaultCacheKey, TokenCache>>> =
86    Lazy::new(|| RwLock::new(HashMap::new()));
87
88#[cfg(test)]
89use mockall::automock;
90
91use crate::models::SecretString;
92use crate::utils::base64_encode;
93
94#[derive(Clone, Debug)]
95pub struct VaultConfig {
96    pub address: String,
97    pub namespace: Option<String>,
98    pub role_id: SecretString,
99    pub secret_id: SecretString,
100    pub mount_path: String,
101    // Optional token TTL in seconds, defaults to 45 minutes if not set
102    pub token_ttl: Option<u64>,
103}
104
105impl VaultConfig {
106    pub fn new(
107        address: String,
108        role_id: SecretString,
109        secret_id: SecretString,
110        namespace: Option<String>,
111        mount_path: String,
112        token_ttl: Option<u64>,
113    ) -> Self {
114        Self {
115            address,
116            role_id,
117            secret_id,
118            namespace,
119            mount_path,
120            token_ttl,
121        }
122    }
123
124    fn cache_key(&self) -> VaultCacheKey {
125        VaultCacheKey {
126            address: self.address.clone(),
127            role_id: self.role_id.to_str().to_string(),
128            namespace: self.namespace.clone(),
129        }
130    }
131}
132
133#[async_trait]
134#[cfg_attr(test, automock)]
135pub trait VaultServiceTrait: Send + Sync {
136    async fn retrieve_secret(&self, key_name: &str) -> Result<String, VaultError>;
137    async fn sign(&self, key_name: &str, message: &[u8]) -> Result<String, VaultError>;
138}
139
140#[derive(Clone, Debug)]
141pub struct VaultService {
142    pub config: VaultConfig,
143}
144
145impl VaultService {
146    pub fn new(config: VaultConfig) -> Self {
147        Self { config }
148    }
149
150    // Get a cached client or create a new one if cache is empty/expired
151    async fn get_client(&self) -> Result<Arc<VaultClient>, VaultError> {
152        let cache_key = self.config.cache_key();
153
154        // Try to read from cache first
155        {
156            let cache = TOKEN_CACHE.read().await;
157            if let Some(cached) = cache.get(&cache_key) {
158                if Instant::now() < cached.expiry {
159                    return Ok(Arc::clone(&cached.client));
160                }
161            }
162        }
163
164        // Cache miss or expired token, need to acquire write lock and refresh
165        let mut cache = TOKEN_CACHE.write().await;
166        // Double-check after acquiring write lock
167        if let Some(cached) = cache.get(&cache_key) {
168            if Instant::now() < cached.expiry {
169                return Ok(Arc::clone(&cached.client));
170            }
171        }
172
173        // Create and authenticate a new client
174        let client = self.create_authenticated_client().await?;
175
176        // Determine TTL (defaults to 45 minutes if not specified)
177        let ttl = Duration::from_secs(self.config.token_ttl.unwrap_or(45 * 60));
178
179        // Update the cache
180        cache.insert(
181            cache_key,
182            TokenCache {
183                client: client.clone(),
184                expiry: Instant::now() + ttl,
185            },
186        );
187
188        Ok(client)
189    }
190
191    // Create and authenticate a new vault client
192    async fn create_authenticated_client(&self) -> Result<Arc<VaultClient>, VaultError> {
193        let mut auth_settings_builder = VaultClientSettingsBuilder::default();
194        let address = &self.config.address;
195        auth_settings_builder.address(address).verify(true);
196
197        if let Some(namespace) = &self.config.namespace {
198            auth_settings_builder.namespace(Some(namespace.clone()));
199        }
200
201        let auth_settings = auth_settings_builder.build().map_err(|e| {
202            VaultError::ConfigError(format!("Failed to build Vault client settings: {e}"))
203        })?;
204
205        let client = VaultClient::new(auth_settings)
206            .map_err(|e| VaultError::ConfigError(format!("Failed to create Vault client: {e}")))?;
207
208        let token = login(
209            &client,
210            "approle",
211            &self.config.role_id.to_str(),
212            &self.config.secret_id.to_str(),
213        )
214        .await
215        .map_err(|e| VaultError::AuthenticationFailed(e.to_string()))?;
216
217        let mut transit_settings_builder = VaultClientSettingsBuilder::default();
218
219        transit_settings_builder
220            .address(self.config.address.clone())
221            .token(token.client_token.clone())
222            .verify(true);
223
224        if let Some(namespace) = &self.config.namespace {
225            transit_settings_builder.namespace(Some(namespace.clone()));
226        }
227
228        let transit_settings = transit_settings_builder.build().map_err(|e| {
229            VaultError::ConfigError(format!("Failed to build Vault client settings: {e}"))
230        })?;
231
232        let client = Arc::new(VaultClient::new(transit_settings).map_err(|e| {
233            VaultError::ConfigError(format!("Failed to create authenticated Vault client: {e}"))
234        })?);
235
236        Ok(client)
237    }
238}
239
240#[async_trait]
241impl VaultServiceTrait for VaultService {
242    async fn retrieve_secret(&self, key_name: &str) -> Result<String, VaultError> {
243        let client = self.get_client().await?;
244
245        let secret: serde_json::Value = kv2::read(&*client, &self.config.mount_path, key_name)
246            .await
247            .map_err(|e| VaultError::ClientError(e.to_string()))?;
248
249        let value = secret["value"]
250            .as_str()
251            .ok_or_else(|| {
252                VaultError::SecretNotFound(format!("Secret value invalid for key: {key_name}"))
253            })?
254            .to_string();
255
256        Ok(value)
257    }
258
259    async fn sign(&self, key_name: &str, message: &[u8]) -> Result<String, VaultError> {
260        let client = self.get_client().await?;
261
262        let vault_signature = transit::data::sign(
263            &*client,
264            &self.config.mount_path,
265            key_name,
266            &base64_encode(message),
267            None,
268        )
269        .await
270        .map_err(|e| VaultError::SigningError(format!("Failed to sign with Vault: {e}")))?;
271
272        let vault_signature_str = &vault_signature.signature;
273
274        debug!(vault_signature_str = %vault_signature_str, "vault signature string");
275
276        Ok(vault_signature_str.clone())
277    }
278}
279
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Mount path used by every test that talks to the mock server.
    const TEST_MOUNT: &str = "test-mount";

    /// Config for the mock server at `url` with the given AppRole credentials,
    /// no namespace, `TEST_MOUNT` and a 60-second token TTL.
    fn mock_config(url: String, role_id: &str, secret_id: &str) -> VaultConfig {
        VaultConfig::new(
            url,
            SecretString::new(role_id),
            SecretString::new(secret_id),
            None,
            TEST_MOUNT.to_string(),
            Some(60),
        )
    }

    /// Body of a successful KV2 read whose secret data is `{"value": "super-secret-value"}`.
    fn kv2_secret_body() -> String {
        serde_json::to_string(&json!({
            "request_id": "test-request-id",
            "lease_id": "",
            "renewable": false,
            "lease_duration": 0,
            "data": {
                "data": {
                    "value": "super-secret-value"
                },
                "metadata": {
                    "created_time": "2024-01-01T00:00:00Z",
                    "deletion_time": "",
                    "destroyed": false,
                    "version": 1
                }
            },
            "wrap_info": null,
            "warnings": null,
            "auth": null
        }))
        .unwrap()
    }

    #[test]
    fn test_vault_config_new() {
        let config = VaultConfig::new(
            "https://vault.example.com".to_string(),
            SecretString::new("test-role-id"),
            SecretString::new("test-secret-id"),
            Some("test-namespace".to_string()),
            "test-mount-path".to_string(),
            Some(60),
        );

        assert_eq!(config.address, "https://vault.example.com");
        assert_eq!(config.role_id.to_str().as_str(), "test-role-id");
        assert_eq!(config.secret_id.to_str().as_str(), "test-secret-id");
        assert_eq!(config.namespace, Some("test-namespace".to_string()));
        assert_eq!(config.mount_path, "test-mount-path");
        assert_eq!(config.token_ttl, Some(60));
    }

    #[test]
    fn test_vault_cache_key() {
        let config1 = VaultConfig {
            address: "https://vault1.example.com".to_string(),
            namespace: Some("namespace1".to_string()),
            role_id: SecretString::new("role1"),
            secret_id: SecretString::new("secret1"),
            mount_path: "transit".to_string(),
            token_ttl: None,
        };

        // Differs only in mount path, which must not affect the cache key.
        let config2 = VaultConfig {
            mount_path: "different-mount".to_string(),
            ..config1.clone()
        };

        // Differs in address, which must produce a distinct cache key.
        let config3 = VaultConfig {
            address: "https://vault2.example.com".to_string(),
            ..config1.clone()
        };

        assert_eq!(config1.cache_key(), config1.cache_key());
        assert_eq!(config1.cache_key(), config2.cache_key());
        assert_ne!(config1.cache_key(), config3.cache_key());
    }

    #[test]
    fn test_vault_cache_key_display() {
        let key_with_namespace = VaultCacheKey {
            address: "https://vault.example.com".to_string(),
            role_id: "role-123".to_string(),
            namespace: Some("my-namespace".to_string()),
        };

        let key_without_namespace = VaultCacheKey {
            address: "https://vault.example.com".to_string(),
            role_id: "role-123".to_string(),
            namespace: None,
        };

        assert_eq!(
            key_with_namespace.to_string(),
            "https://vault.example.com|role-123|my-namespace"
        );

        assert_eq!(
            key_without_namespace.to_string(),
            "https://vault.example.com|role-123|"
        );
    }

    /// Registers a mock AppRole login that accepts exactly `role_id`/`secret_id`
    /// and answers with `token` as the client token.
    async fn setup_mock_approle_login(
        mock_server: &mut mockito::ServerGuard,
        role_id: &str,
        secret_id: &str,
        token: &str,
    ) -> mockito::Mock {
        mock_server
            .mock("POST", "/v1/auth/approle/login")
            .match_body(mockito::Matcher::Json(json!({
                "role_id": role_id,
                "secret_id": secret_id
            })))
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(
                serde_json::to_string(&json!({
                    "request_id": "test-request-id",
                    "lease_id": "",
                    "renewable": false,
                    "lease_duration": 0,
                    "data": null,
                    "wrap_info": null,
                    "warnings": null,
                    "auth": {
                        "client_token": token,
                        "accessor": "test-accessor",
                        "policies": ["default"],
                        "token_policies": ["default"],
                        "metadata": {
                            "role_name": "test-role"
                        },
                        "lease_duration": 3600,
                        "renewable": true,
                        "entity_id": "test-entity-id",
                        "token_type": "service",
                        "orphan": true
                    }
                }))
                .unwrap(),
            )
            .create_async()
            .await
    }

    #[tokio::test]
    async fn test_vault_service_auth_failure() {
        let mut mock_server = mockito::Server::new_async().await;

        let _login_mock = setup_mock_approle_login(
            &mut mock_server,
            "test-role-id",
            "test-secret-id",
            "test-token",
        )
        .await;

        let _secret_mock = mock_server
            .mock("GET", "/v1/test-mount/data/my-secret")
            .match_header("X-Vault-Token", "test-token")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(kv2_secret_body())
            .create_async()
            .await;

        // Credentials that the login mock does not accept.
        let config = mock_config(
            mock_server.url(),
            "test-role-id-fake",
            "test-secret-id-fake",
        );
        let vault_service = VaultService::new(config);

        let err = vault_service
            .retrieve_secret("my-secret")
            .await
            .expect_err("login with unknown credentials must fail");

        assert!(matches!(err, VaultError::AuthenticationFailed(_)));
        assert!(err.to_string().contains("An error occurred with the request"));
    }

    #[tokio::test]
    async fn test_vault_service_retrieve_secret_success() {
        let mut mock_server = mockito::Server::new_async().await;

        let _login_mock = setup_mock_approle_login(
            &mut mock_server,
            "test-role-id",
            "test-secret-id",
            "test-token",
        )
        .await;

        let _secret_mock = mock_server
            .mock(
                "GET",
                mockito::Matcher::Regex(r"/v1/test-mount/data/my-secret.*".to_string()),
            )
            .match_header("X-Vault-Token", "test-token")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(kv2_secret_body())
            .create_async()
            .await;

        let config = mock_config(mock_server.url(), "test-role-id", "test-secret-id");
        let vault_service = VaultService::new(config);

        let secret = vault_service.retrieve_secret("my-secret").await.unwrap();

        assert_eq!(secret, "super-secret-value");
    }

    #[tokio::test]
    async fn test_vault_service_sign_success() {
        let mut mock_server = mockito::Server::new_async().await;

        let _login_mock = setup_mock_approle_login(
            &mut mock_server,
            "test-role-id",
            "test-secret-id",
            "test-token",
        )
        .await;

        let message = b"hello world";
        let encoded_message = base64_encode(message);

        let _sign_mock = mock_server
            .mock("POST", "/v1/test-mount/sign/my-signing-key")
            .match_header("X-Vault-Token", "test-token")
            .match_body(mockito::Matcher::Json(json!({
                "input": encoded_message
            })))
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(
                serde_json::to_string(&json!({
                    "request_id": "test-request-id",
                    "lease_id": "",
                    "renewable": false,
                    "lease_duration": 0,
                    "data": {
                        "signature": "vault:v1:fake-signature",
                        "key_version": 1
                    },
                    "wrap_info": null,
                    "warnings": null,
                    "auth": null
                }))
                .unwrap(),
            )
            .create_async()
            .await;

        let config = mock_config(mock_server.url(), "test-role-id", "test-secret-id");
        let vault_service = VaultService::new(config);

        let signature = vault_service.sign("my-signing-key", message).await.unwrap();

        assert_eq!(signature, "vault:v1:fake-signature");
    }

    #[tokio::test]
    async fn test_vault_service_retrieve_secret_failure() {
        let mut mock_server = mockito::Server::new_async().await;

        let _login_mock = setup_mock_approle_login(
            &mut mock_server,
            "test-role-id",
            "test-secret-id",
            "test-token",
        )
        .await;

        let _secret_mock = mock_server
            .mock(
                "GET",
                mockito::Matcher::Regex(r"/v1/test-mount/data/my-secret.*".to_string()),
            )
            .match_header("X-Vault-Token", "test-token")
            .with_status(404)
            .with_header("content-type", "application/json")
            .with_body(
                serde_json::to_string(&json!({
                    "errors": ["secret not found:"]
                }))
                .unwrap(),
            )
            .create_async()
            .await;

        let config = mock_config(mock_server.url(), "test-role-id", "test-secret-id");
        let vault_service = VaultService::new(config);

        let err = vault_service
            .retrieve_secret("my-secret")
            .await
            .expect_err("a 404 from the KV2 endpoint must surface as an error");

        assert!(matches!(err, VaultError::ClientError(_)));
        assert!(err
            .to_string()
            .contains("The Vault server returned an error (status code 404)"));
    }

    #[tokio::test]
    async fn test_vault_service_sign_failure() {
        let mut mock_server = mockito::Server::new_async().await;

        let _login_mock = setup_mock_approle_login(
            &mut mock_server,
            "test-role-id",
            "test-secret-id",
            "test-token",
        )
        .await;

        let message = b"hello world";
        let encoded_message = base64_encode(message);

        let _sign_mock = mock_server
            .mock("POST", "/v1/test-mount/sign/my-signing-key")
            .match_header("X-Vault-Token", "test-token")
            .match_body(mockito::Matcher::Json(json!({
                "input": encoded_message
            })))
            .with_status(400)
            .with_header("content-type", "application/json")
            .with_body(
                serde_json::to_string(&json!({
                    "errors": ["1 error occurred:\n\t* signing key not found"]
                }))
                .unwrap(),
            )
            .create_async()
            .await;

        let config = mock_config(mock_server.url(), "test-role-id", "test-secret-id");
        let vault_service = VaultService::new(config);

        let err = vault_service
            .sign("my-signing-key", message)
            .await
            .expect_err("a 400 from the Transit endpoint must surface as an error");

        assert!(matches!(err, VaultError::SigningError(_)));
        assert!(err.to_string().contains("Failed to sign with Vault"));
    }
}