1use async_trait::async_trait;
24use core::fmt;
25use once_cell::sync::Lazy;
26use serde::Serialize;
27use std::collections::HashMap;
28use std::hash::Hash;
29use std::sync::Arc;
30use std::time::{Duration, Instant};
31use thiserror::Error;
32use tokio::sync::RwLock;
33use tracing::debug;
34use vaultrs::{
35 auth::approle::login,
36 client::{VaultClient, VaultClientSettingsBuilder},
37 kv2, transit,
38};
39use zeroize::{Zeroize, ZeroizeOnDrop};
40
41#[derive(Error, Debug, Serialize)]
42pub enum VaultError {
43 #[error("Vault client error: {0}")]
44 ClientError(String),
45
46 #[error("Secret not found: {0}")]
47 SecretNotFound(String),
48
49 #[error("Authentication failed: {0}")]
50 AuthenticationFailed(String),
51
52 #[error("Configuration error: {0}")]
53 ConfigError(String),
54
55 #[error("Signing error: {0}")]
56 SigningError(String),
57}
58
59#[derive(Clone, Debug, PartialEq, Eq, Hash, Zeroize, ZeroizeOnDrop)]
61struct VaultCacheKey {
62 address: String,
63 role_id: String,
64 namespace: Option<String>,
65}
66
67impl fmt::Display for VaultCacheKey {
68 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
69 write!(
70 f,
71 "{}|{}|{}",
72 self.address,
73 self.role_id,
74 self.namespace.as_deref().unwrap_or("")
75 )
76 }
77}
78
79struct TokenCache {
80 client: Arc<VaultClient>,
81 expiry: Instant,
82}
83
84static TOKEN_CACHE: Lazy<RwLock<HashMap<VaultCacheKey, TokenCache>>> =
86 Lazy::new(|| RwLock::new(HashMap::new()));
87
88#[cfg(test)]
89use mockall::automock;
90
91use crate::models::SecretString;
92use crate::utils::base64_encode;
93
94#[derive(Clone, Debug)]
95pub struct VaultConfig {
96 pub address: String,
97 pub namespace: Option<String>,
98 pub role_id: SecretString,
99 pub secret_id: SecretString,
100 pub mount_path: String,
101 pub token_ttl: Option<u64>,
103}
104
105impl VaultConfig {
106 pub fn new(
107 address: String,
108 role_id: SecretString,
109 secret_id: SecretString,
110 namespace: Option<String>,
111 mount_path: String,
112 token_ttl: Option<u64>,
113 ) -> Self {
114 Self {
115 address,
116 role_id,
117 secret_id,
118 namespace,
119 mount_path,
120 token_ttl,
121 }
122 }
123
124 fn cache_key(&self) -> VaultCacheKey {
125 VaultCacheKey {
126 address: self.address.clone(),
127 role_id: self.role_id.to_str().to_string(),
128 namespace: self.namespace.clone(),
129 }
130 }
131}
132
/// Vault operations used by callers, abstracted behind a trait so they can be
/// mocked in tests (`automock` generates `MockVaultServiceTrait` under `cfg(test)`).
#[async_trait]
#[cfg_attr(test, automock)]
pub trait VaultServiceTrait: Send + Sync {
    /// Retrieves the secret named `key_name` as a string.
    async fn retrieve_secret(&self, key_name: &str) -> Result<String, VaultError>;
    /// Signs `message` with the key named `key_name` and returns the signature string.
    async fn sign(&self, key_name: &str, message: &[u8]) -> Result<String, VaultError>;
}
139
140#[derive(Clone, Debug)]
141pub struct VaultService {
142 pub config: VaultConfig,
143}
144
145impl VaultService {
146 pub fn new(config: VaultConfig) -> Self {
147 Self { config }
148 }
149
150 async fn get_client(&self) -> Result<Arc<VaultClient>, VaultError> {
152 let cache_key = self.config.cache_key();
153
154 {
156 let cache = TOKEN_CACHE.read().await;
157 if let Some(cached) = cache.get(&cache_key) {
158 if Instant::now() < cached.expiry {
159 return Ok(Arc::clone(&cached.client));
160 }
161 }
162 }
163
164 let mut cache = TOKEN_CACHE.write().await;
166 if let Some(cached) = cache.get(&cache_key) {
168 if Instant::now() < cached.expiry {
169 return Ok(Arc::clone(&cached.client));
170 }
171 }
172
173 let client = self.create_authenticated_client().await?;
175
176 let ttl = Duration::from_secs(self.config.token_ttl.unwrap_or(45 * 60));
178
179 cache.insert(
181 cache_key,
182 TokenCache {
183 client: client.clone(),
184 expiry: Instant::now() + ttl,
185 },
186 );
187
188 Ok(client)
189 }
190
191 async fn create_authenticated_client(&self) -> Result<Arc<VaultClient>, VaultError> {
193 let mut auth_settings_builder = VaultClientSettingsBuilder::default();
194 let address = &self.config.address;
195 auth_settings_builder.address(address).verify(true);
196
197 if let Some(namespace) = &self.config.namespace {
198 auth_settings_builder.namespace(Some(namespace.clone()));
199 }
200
201 let auth_settings = auth_settings_builder.build().map_err(|e| {
202 VaultError::ConfigError(format!("Failed to build Vault client settings: {e}"))
203 })?;
204
205 let client = VaultClient::new(auth_settings)
206 .map_err(|e| VaultError::ConfigError(format!("Failed to create Vault client: {e}")))?;
207
208 let token = login(
209 &client,
210 "approle",
211 &self.config.role_id.to_str(),
212 &self.config.secret_id.to_str(),
213 )
214 .await
215 .map_err(|e| VaultError::AuthenticationFailed(e.to_string()))?;
216
217 let mut transit_settings_builder = VaultClientSettingsBuilder::default();
218
219 transit_settings_builder
220 .address(self.config.address.clone())
221 .token(token.client_token.clone())
222 .verify(true);
223
224 if let Some(namespace) = &self.config.namespace {
225 transit_settings_builder.namespace(Some(namespace.clone()));
226 }
227
228 let transit_settings = transit_settings_builder.build().map_err(|e| {
229 VaultError::ConfigError(format!("Failed to build Vault client settings: {e}"))
230 })?;
231
232 let client = Arc::new(VaultClient::new(transit_settings).map_err(|e| {
233 VaultError::ConfigError(format!("Failed to create authenticated Vault client: {e}"))
234 })?);
235
236 Ok(client)
237 }
238}
239
240#[async_trait]
241impl VaultServiceTrait for VaultService {
242 async fn retrieve_secret(&self, key_name: &str) -> Result<String, VaultError> {
243 let client = self.get_client().await?;
244
245 let secret: serde_json::Value = kv2::read(&*client, &self.config.mount_path, key_name)
246 .await
247 .map_err(|e| VaultError::ClientError(e.to_string()))?;
248
249 let value = secret["value"]
250 .as_str()
251 .ok_or_else(|| {
252 VaultError::SecretNotFound(format!("Secret value invalid for key: {key_name}"))
253 })?
254 .to_string();
255
256 Ok(value)
257 }
258
259 async fn sign(&self, key_name: &str, message: &[u8]) -> Result<String, VaultError> {
260 let client = self.get_client().await?;
261
262 let vault_signature = transit::data::sign(
263 &*client,
264 &self.config.mount_path,
265 key_name,
266 &base64_encode(message),
267 None,
268 )
269 .await
270 .map_err(|e| VaultError::SigningError(format!("Failed to sign with Vault: {e}")))?;
271
272 let vault_signature_str = &vault_signature.signature;
273
274 debug!(vault_signature_str = %vault_signature_str, "vault signature string");
275
276 Ok(vault_signature_str.clone())
277 }
278}
279
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Credentials accepted by the mocked AppRole login, and the token it issues.
    const ROLE_ID: &str = "test-role-id";
    const SECRET_ID: &str = "test-secret-id";
    const TOKEN: &str = "test-token";

    /// Config for a mock server at `address`: no namespace, mount path
    /// `test-mount`, 60-second token TTL.
    fn mock_config(address: String, role_id: &str, secret_id: &str) -> VaultConfig {
        VaultConfig::new(
            address,
            SecretString::new(role_id),
            SecretString::new(secret_id),
            None,
            "test-mount".to_string(),
            Some(60),
        )
    }

    /// Config for the cache-key tests; only the fields under test vary.
    fn cache_key_config(
        address: &str,
        namespace: Option<&str>,
        role_id: &str,
        mount_path: &str,
    ) -> VaultConfig {
        VaultConfig {
            address: address.to_string(),
            namespace: namespace.map(String::from),
            role_id: SecretString::new(role_id),
            secret_id: SecretString::new("secret1"),
            mount_path: mount_path.to_string(),
            token_ttl: None,
        }
    }

    #[test]
    fn test_vault_config_new() {
        let config = VaultConfig::new(
            "https://vault.example.com".to_string(),
            SecretString::new("test-role-id"),
            SecretString::new("test-secret-id"),
            Some("test-namespace".to_string()),
            "test-mount-path".to_string(),
            Some(60),
        );

        assert_eq!(config.address, "https://vault.example.com");
        assert_eq!(config.role_id.to_str().as_str(), "test-role-id");
        assert_eq!(config.secret_id.to_str().as_str(), "test-secret-id");
        assert_eq!(config.namespace, Some("test-namespace".to_string()));
        assert_eq!(config.mount_path, "test-mount-path");
        assert_eq!(config.token_ttl, Some(60));
    }

    #[test]
    fn test_vault_cache_key() {
        let vault1 = "https://vault1.example.com";
        let base = cache_key_config(vault1, Some("namespace1"), "role1", "transit");

        assert_eq!(base.cache_key(), base.cache_key());

        // The mount path does not take part in the key.
        let other_mount = cache_key_config(vault1, Some("namespace1"), "role1", "different-mount");
        assert_eq!(base.cache_key(), other_mount.cache_key());

        // Address, role id and namespace each distinguish keys.
        let other_address = cache_key_config(
            "https://vault2.example.com",
            Some("namespace1"),
            "role1",
            "transit",
        );
        let other_role = cache_key_config(vault1, Some("namespace1"), "role2", "transit");
        let other_namespace = cache_key_config(vault1, Some("namespace2"), "role1", "transit");
        let no_namespace = cache_key_config(vault1, None, "role1", "transit");

        assert_ne!(base.cache_key(), other_address.cache_key());
        assert_ne!(base.cache_key(), other_role.cache_key());
        assert_ne!(base.cache_key(), other_namespace.cache_key());
        assert_ne!(base.cache_key(), no_namespace.cache_key());
    }

    #[test]
    fn test_vault_cache_key_display() {
        let key_with_namespace = VaultCacheKey {
            address: "https://vault.example.com".to_string(),
            role_id: "role-123".to_string(),
            namespace: Some("my-namespace".to_string()),
        };

        let key_without_namespace = VaultCacheKey {
            address: "https://vault.example.com".to_string(),
            role_id: "role-123".to_string(),
            namespace: None,
        };

        assert_eq!(
            key_with_namespace.to_string(),
            "https://vault.example.com|role-123|my-namespace"
        );
        assert_eq!(
            key_without_namespace.to_string(),
            "https://vault.example.com|role-123|"
        );
    }

    /// Registers an AppRole login that only matches `role_id`/`secret_id` and
    /// answers with `token` as the client token.
    async fn setup_mock_approle_login(
        mock_server: &mut mockito::ServerGuard,
        role_id: &str,
        secret_id: &str,
        token: &str,
    ) -> mockito::Mock {
        let body = json!({
            "request_id": "test-request-id",
            "lease_id": "",
            "renewable": false,
            "lease_duration": 0,
            "data": null,
            "wrap_info": null,
            "warnings": null,
            "auth": {
                "client_token": token,
                "accessor": "test-accessor",
                "policies": ["default"],
                "token_policies": ["default"],
                "metadata": {
                    "role_name": "test-role"
                },
                "lease_duration": 3600,
                "renewable": true,
                "entity_id": "test-entity-id",
                "token_type": "service",
                "orphan": true
            }
        });

        mock_server
            .mock("POST", "/v1/auth/approle/login")
            .match_body(mockito::Matcher::Json(json!({
                "role_id": role_id,
                "secret_id": secret_id
            })))
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(body.to_string())
            .create_async()
            .await
    }

    /// Successful KV v2 read response whose secret data is `{"value": value}`.
    fn kv2_secret_body(value: &str) -> String {
        json!({
            "request_id": "test-request-id",
            "lease_id": "",
            "renewable": false,
            "lease_duration": 0,
            "data": {
                "data": {
                    "value": value
                },
                "metadata": {
                    "created_time": "2024-01-01T00:00:00Z",
                    "deletion_time": "",
                    "destroyed": false,
                    "version": 1
                }
            },
            "wrap_info": null,
            "warnings": null,
            "auth": null
        })
        .to_string()
    }

    /// Registers a KV v2 read of `test-mount/my-secret` (authenticated with
    /// `TOKEN`) answering with `status` and `body`.
    async fn mock_kv2_read(
        mock_server: &mut mockito::ServerGuard,
        status: usize,
        body: String,
    ) -> mockito::Mock {
        mock_server
            .mock(
                "GET",
                mockito::Matcher::Regex(r"/v1/test-mount/data/my-secret.*".to_string()),
            )
            .match_header("X-Vault-Token", TOKEN)
            .with_status(status)
            .with_header("content-type", "application/json")
            .with_body(body)
            .create_async()
            .await
    }

    /// Registers a transit sign with `my-signing-key` that expects `message`
    /// base64-encoded and answers with `status` and `body`.
    async fn mock_transit_sign(
        mock_server: &mut mockito::ServerGuard,
        message: &[u8],
        status: usize,
        body: String,
    ) -> mockito::Mock {
        mock_server
            .mock("POST", "/v1/test-mount/sign/my-signing-key")
            .match_header("X-Vault-Token", TOKEN)
            .match_body(mockito::Matcher::Json(json!({
                "input": base64_encode(message)
            })))
            .with_status(status)
            .with_header("content-type", "application/json")
            .with_body(body)
            .create_async()
            .await
    }

    #[tokio::test]
    async fn test_vault_service_auth_failure() {
        let mut mock_server = mockito::Server::new_async().await;
        let _login_mock = setup_mock_approle_login(&mut mock_server, ROLE_ID, SECRET_ID, TOKEN).await;
        let _secret_mock =
            mock_kv2_read(&mut mock_server, 200, kv2_secret_body("super-secret-value")).await;

        // Credentials the login mock does not accept.
        let vault_service = VaultService::new(mock_config(
            mock_server.url(),
            "test-role-id-fake",
            "test-secret-id-fake",
        ));

        let err = vault_service.retrieve_secret("my-secret").await.unwrap_err();
        assert!(matches!(err, VaultError::AuthenticationFailed(_)));
        assert!(err.to_string().contains("An error occurred with the request"));
    }

    #[tokio::test]
    async fn test_vault_service_retrieve_secret_success() {
        let mut mock_server = mockito::Server::new_async().await;
        let _login_mock = setup_mock_approle_login(&mut mock_server, ROLE_ID, SECRET_ID, TOKEN).await;
        let _secret_mock =
            mock_kv2_read(&mut mock_server, 200, kv2_secret_body("super-secret-value")).await;

        let vault_service = VaultService::new(mock_config(mock_server.url(), ROLE_ID, SECRET_ID));

        let secret = vault_service.retrieve_secret("my-secret").await.unwrap();
        assert_eq!(secret, "super-secret-value");
    }

    #[tokio::test]
    async fn test_vault_service_sign_success() {
        let mut mock_server = mockito::Server::new_async().await;
        let _login_mock = setup_mock_approle_login(&mut mock_server, ROLE_ID, SECRET_ID, TOKEN).await;

        let message = b"hello world";
        let body = json!({
            "request_id": "test-request-id",
            "lease_id": "",
            "renewable": false,
            "lease_duration": 0,
            "data": {
                "signature": "vault:v1:fake-signature",
                "key_version": 1
            },
            "wrap_info": null,
            "warnings": null,
            "auth": null
        })
        .to_string();
        let _sign_mock = mock_transit_sign(&mut mock_server, message, 200, body).await;

        let vault_service = VaultService::new(mock_config(mock_server.url(), ROLE_ID, SECRET_ID));

        let signature = vault_service.sign("my-signing-key", message).await.unwrap();
        assert_eq!(signature, "vault:v1:fake-signature");
    }

    #[tokio::test]
    async fn test_vault_service_retrieve_secret_failure() {
        let mut mock_server = mockito::Server::new_async().await;
        let _login_mock = setup_mock_approle_login(&mut mock_server, ROLE_ID, SECRET_ID, TOKEN).await;
        let body = json!({ "errors": ["secret not found:"] }).to_string();
        let _secret_mock = mock_kv2_read(&mut mock_server, 404, body).await;

        let vault_service = VaultService::new(mock_config(mock_server.url(), ROLE_ID, SECRET_ID));

        let err = vault_service.retrieve_secret("my-secret").await.unwrap_err();
        assert!(matches!(err, VaultError::ClientError(_)));
        assert!(err
            .to_string()
            .contains("The Vault server returned an error (status code 404)"));
    }

    #[tokio::test]
    async fn test_vault_service_sign_failure() {
        let mut mock_server = mockito::Server::new_async().await;
        let _login_mock = setup_mock_approle_login(&mut mock_server, ROLE_ID, SECRET_ID, TOKEN).await;

        let message = b"hello world";
        let body = json!({
            "errors": ["1 error occurred:\n\t* signing key not found"]
        })
        .to_string();
        let _sign_mock = mock_transit_sign(&mut mock_server, message, 400, body).await;

        let vault_service = VaultService::new(mock_config(mock_server.url(), ROLE_ID, SECRET_ID));

        let err = vault_service.sign("my-signing-key", message).await.unwrap_err();
        assert!(matches!(err, VaultError::SigningError(_)));
        assert!(err.to_string().contains("Failed to sign with Vault"));
    }
}