openzeppelin_relayer/utils/
encryption.rs1use aes_gcm::{
7 aead::{rand_core::RngCore, Aead, KeyInit, OsRng},
8 Aes256Gcm, Key, Nonce,
9};
10use serde::{Deserialize, Serialize};
11use std::env;
12use thiserror::Error;
13use zeroize::Zeroize;
14
15use crate::{
16 models::SecretString,
17 utils::{base64_decode, base64_encode},
18};
19
20#[derive(Error, Debug, Clone)]
21pub enum EncryptionError {
22 #[error("Encryption failed: {0}")]
23 EncryptionFailed(String),
24 #[error("Decryption failed: {0}")]
25 DecryptionFailed(String),
26 #[error("Key derivation failed: {0}")]
27 KeyDerivationFailed(String),
28 #[error("Invalid encrypted data format: {0}")]
29 InvalidFormat(String),
30 #[error("Missing encryption key environment variable: {0}")]
31 MissingKey(String),
32 #[error("Invalid key length: expected 32 bytes, got {0}")]
33 InvalidKeyLength(usize),
34}
35
36#[derive(Debug, Clone, Serialize, Deserialize)]
38pub struct EncryptedData {
39 pub nonce: String,
41 pub ciphertext: String,
43 pub version: u8,
45}
46
47#[derive(Clone)]
49pub struct FieldEncryption {
50 cipher: Aes256Gcm,
51}
52
53impl FieldEncryption {
54 pub fn new() -> Result<Self, EncryptionError> {
60 let key = Self::load_key_from_env()?;
61 let cipher = Aes256Gcm::new(&key);
62 Ok(Self { cipher })
63 }
64
65 pub fn new_with_key(key: &[u8; 32]) -> Result<Self, EncryptionError> {
67 let key = Key::<Aes256Gcm>::from(*key);
68 let cipher = Aes256Gcm::new(&key);
69 Ok(Self { cipher })
70 }
71
72 fn load_key_from_env() -> Result<Key<Aes256Gcm>, EncryptionError> {
74 let key = env::var("STORAGE_ENCRYPTION_KEY")
75 .map(|v| SecretString::new(&v))
76 .map_err(|_| {
77 EncryptionError::MissingKey("STORAGE_ENCRYPTION_KEY must be set".to_string())
78 })?;
79
80 key.as_str(|key_b64| {
81 let mut key_bytes = base64_decode(key_b64)
82 .map_err(|e| EncryptionError::KeyDerivationFailed(e.to_string()))?;
83 if key_bytes.len() != 32 {
84 key_bytes.zeroize(); return Err(EncryptionError::InvalidKeyLength(key_bytes.len()));
86 }
87
88 let key_array: [u8; 32] = key_bytes
89 .as_slice()
90 .try_into()
91 .map_err(|_| EncryptionError::InvalidKeyLength(key_bytes.len()))?;
92 Ok(Key::<Aes256Gcm>::from(key_array))
93 })
94 }
95
96 pub fn encrypt(&self, plaintext: &[u8]) -> Result<EncryptedData, EncryptionError> {
98 let mut nonce_bytes = [0u8; 12];
100 OsRng.fill_bytes(&mut nonce_bytes);
101 let nonce = &Nonce::from(nonce_bytes);
102
103 let ciphertext = self
105 .cipher
106 .encrypt(nonce, plaintext)
107 .map_err(|e| EncryptionError::EncryptionFailed(e.to_string()))?;
108
109 Ok(EncryptedData {
110 nonce: base64_encode(&nonce_bytes),
111 ciphertext: base64_encode(&ciphertext),
112 version: 1,
113 })
114 }
115
116 pub fn decrypt(&self, encrypted_data: &EncryptedData) -> Result<Vec<u8>, EncryptionError> {
118 if encrypted_data.version != 1 {
119 return Err(EncryptionError::InvalidFormat(format!(
120 "Unsupported encryption version: {}",
121 encrypted_data.version
122 )));
123 }
124
125 let nonce_bytes = base64_decode(&encrypted_data.nonce)
127 .map_err(|e| EncryptionError::InvalidFormat(format!("Invalid nonce: {e}")))?;
128
129 let ciphertext_bytes = base64_decode(&encrypted_data.ciphertext)
130 .map_err(|e| EncryptionError::InvalidFormat(format!("Invalid ciphertext: {e}")))?;
131
132 if nonce_bytes.len() != 12 {
133 return Err(EncryptionError::InvalidFormat(format!(
134 "Invalid nonce length: expected 12, got {}",
135 nonce_bytes.len()
136 )));
137 }
138
139 let nonce_array: [u8; 12] = nonce_bytes
140 .as_slice()
141 .try_into()
142 .map_err(|_| EncryptionError::InvalidFormat("Invalid nonce length".to_string()))?;
143 let nonce = &Nonce::from(nonce_array);
144
145 let plaintext = self
147 .cipher
148 .decrypt(nonce, ciphertext_bytes.as_ref())
149 .map_err(|e| EncryptionError::DecryptionFailed(e.to_string()))?;
150
151 Ok(plaintext)
152 }
153
154 pub fn encrypt_string(&self, plaintext: &str) -> Result<String, EncryptionError> {
156 let encrypted_data = self.encrypt(plaintext.as_bytes())?;
157 let json_data = serde_json::to_string(&encrypted_data)
158 .map_err(|e| EncryptionError::EncryptionFailed(format!("Serialization failed: {e}")))?;
159
160 Ok(base64_encode(json_data.as_bytes()))
162 }
163
164 pub fn decrypt_string(&self, encrypted_base64: &str) -> Result<String, EncryptionError> {
166 let json_bytes = base64_decode(encrypted_base64)
168 .map_err(|e| EncryptionError::InvalidFormat(format!("Invalid base64: {e}")))?;
169
170 let encrypted_json = String::from_utf8(json_bytes).map_err(|e| {
171 EncryptionError::InvalidFormat(format!("Invalid UTF-8 in decoded data: {e}"))
172 })?;
173
174 let encrypted_data: EncryptedData = serde_json::from_str(&encrypted_json)
175 .map_err(|e| EncryptionError::InvalidFormat(format!("Invalid JSON structure: {e}")))?;
176
177 let plaintext_bytes = self.decrypt(&encrypted_data)?;
178 String::from_utf8(plaintext_bytes).map_err(|e| {
179 EncryptionError::DecryptionFailed(format!("Invalid UTF-8 in plaintext: {e}"))
180 })
181 }
182
183 pub fn generate_key() -> String {
185 let mut key = [0u8; 32];
186 OsRng.fill_bytes(&mut key);
187 let key_b64 = base64_encode(&key);
188
189 let mut key_zeroize = key;
191 key_zeroize.zeroize();
192
193 key_b64
194 }
195
196 pub fn is_configured() -> bool {
198 env::var("STORAGE_ENCRYPTION_KEY").is_ok()
199 }
200}
201
202static ENCRYPTION_INSTANCE: std::sync::OnceLock<Result<FieldEncryption, EncryptionError>> =
204 std::sync::OnceLock::new();
205
206pub fn get_encryption() -> Result<&'static FieldEncryption, &'static EncryptionError> {
208 ENCRYPTION_INSTANCE
209 .get_or_init(FieldEncryption::new)
210 .as_ref()
211}
212
213pub fn encrypt_sensitive_field(data: &str) -> Result<String, EncryptionError> {
215 if FieldEncryption::is_configured() {
216 match get_encryption() {
217 Ok(encryption) => encryption.encrypt_string(data),
218 Err(e) => Err(e.clone()),
219 }
220 } else {
221 let json_data = serde_json::to_string(data)
224 .map_err(|e| EncryptionError::EncryptionFailed(format!("JSON encoding failed: {e}")))?;
225 Ok(base64_encode(json_data.as_bytes()))
226 }
227}
228
229pub fn decrypt_sensitive_field(data: &str) -> Result<String, EncryptionError> {
231 let json_bytes = base64_decode(data)
233 .map_err(|e| EncryptionError::InvalidFormat(format!("Invalid base64: {e}")))?;
234
235 let json_str = String::from_utf8(json_bytes)
236 .map_err(|e| EncryptionError::InvalidFormat(format!("Invalid UTF-8: {e}")))?;
237
238 if FieldEncryption::is_configured() {
240 if let Ok(encryption) = get_encryption() {
241 if let Ok(encrypted_data) = serde_json::from_str::<EncryptedData>(&json_str) {
243 let plaintext_bytes = encryption.decrypt(&encrypted_data)?;
245 return String::from_utf8(plaintext_bytes).map_err(|e| {
246 EncryptionError::DecryptionFailed(format!("Invalid UTF-8 in plaintext: {e}"))
247 });
248 }
249 }
250 }
251
252 serde_json::from_str(&json_str)
255 .map_err(|e| EncryptionError::DecryptionFailed(format!("Invalid JSON string: {e}")))
256}
257
258pub fn generate_encryption_key() -> String {
260 FieldEncryption::generate_key()
261}
262
#[cfg(test)]
mod tests {
    use super::*;
    use std::env;
    use std::sync::{Mutex, MutexGuard};

    /// Environment variable the module reads its key from.
    const KEY_VAR: &str = "STORAGE_ENCRYPTION_KEY";

    /// Serializes every test that reads or mutates `KEY_VAR`.
    ///
    /// The process environment is shared by all test threads, so unsynchronized
    /// `set_var`/`remove_var` calls make these tests race with each other.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Holds [`ENV_LOCK`] for the duration of a test.
    ///
    /// On drop it restores the previous value of `KEY_VAR`, including when the
    /// test panics.
    struct EnvGuard {
        previous: Option<String>,
        _lock: MutexGuard<'static, ()>,
    }

    impl EnvGuard {
        fn acquire() -> Self {
            // A panicking test poisons the mutex; the guarded data is `()`, so
            // recovering the guard is always fine.
            let lock = ENV_LOCK
                .lock()
                .unwrap_or_else(|poisoned| poisoned.into_inner());
            Self {
                previous: env::var(KEY_VAR).ok(),
                _lock: lock,
            }
        }
    }

    impl Drop for EnvGuard {
        fn drop(&mut self) {
            // Runs before the `_lock` field is dropped, so no other env test
            // observes the intermediate state.
            match self.previous.take() {
                Some(value) => env::set_var(KEY_VAR, value),
                None => env::remove_var(KEY_VAR),
            }
        }
    }

    #[test]
    fn test_encrypt_decrypt_data() {
        let key = [0u8; 32];
        let encryption = FieldEncryption::new_with_key(&key).unwrap();

        let plaintext = b"This is a secret message!";
        let encrypted = encryption.encrypt(plaintext).unwrap();
        let decrypted = encryption.decrypt(&encrypted).unwrap();

        assert_eq!(plaintext, decrypted.as_slice());
    }

    #[test]
    fn test_encrypt_decrypt_string() {
        let key = [1u8; 32];
        let encryption = FieldEncryption::new_with_key(&key).unwrap();

        let plaintext = "Sensitive API key: sk-1234567890abcdef";
        let encrypted = encryption.encrypt_string(plaintext).unwrap();
        let decrypted = encryption.decrypt_string(&encrypted).unwrap();

        assert_eq!(plaintext, decrypted);
    }

    #[test]
    fn test_different_keys_produce_different_results() {
        let encryption1 = FieldEncryption::new_with_key(&[1u8; 32]).unwrap();
        let encryption2 = FieldEncryption::new_with_key(&[2u8; 32]).unwrap();

        let plaintext = "secret";
        let encrypted1 = encryption1.encrypt_string(plaintext).unwrap();
        let encrypted2 = encryption2.encrypt_string(plaintext).unwrap();

        assert_ne!(encrypted1, encrypted2);

        assert_eq!(encryption1.decrypt_string(&encrypted1).unwrap(), plaintext);
        assert_eq!(encryption2.decrypt_string(&encrypted2).unwrap(), plaintext);

        // A value encrypted under one key must not decrypt under another.
        assert!(encryption1.decrypt_string(&encrypted2).is_err());
        assert!(encryption2.decrypt_string(&encrypted1).is_err());
    }

    #[test]
    fn test_nonce_uniqueness() {
        let encryption = FieldEncryption::new_with_key(&[3u8; 32]).unwrap();

        let plaintext = "same message";
        let encrypted1 = encryption.encrypt_string(plaintext).unwrap();
        let encrypted2 = encryption.encrypt_string(plaintext).unwrap();

        // Fresh nonces make repeated encryptions of the same plaintext differ.
        assert_ne!(encrypted1, encrypted2);

        assert_eq!(encryption.decrypt_string(&encrypted1).unwrap(), plaintext);
        assert_eq!(encryption.decrypt_string(&encrypted2).unwrap(), plaintext);
    }

    #[test]
    fn test_invalid_encrypted_data() {
        let encryption = FieldEncryption::new_with_key(&[4u8; 32]).unwrap();

        assert!(encryption.decrypt_string("invalid base64!").is_err());

        assert!(encryption
            .decrypt_string(&base64_encode(b"not json"))
            .is_err());

        let invalid_json_b64 = base64_encode(b"{\"wrong\": \"structure\"}");
        assert!(encryption.decrypt_string(&invalid_json_b64).is_err());

        assert!(encryption
            .decrypt_string(&base64_encode(
                b"{\"nonce\":\"test\",\"ciphertext\":\"test\",\"version\":1}"
            ))
            .is_err());
    }

    #[test]
    fn test_generate_key() {
        let key1 = FieldEncryption::generate_key();
        let key2 = FieldEncryption::generate_key();

        assert_ne!(key1, key2);

        assert_eq!(base64_decode(&key1).unwrap().len(), 32);
        assert_eq!(base64_decode(&key2).unwrap().len(), 32);
    }

    #[test]
    fn test_env_key_loading() {
        let _guard = EnvGuard::acquire();

        let test_key = FieldEncryption::generate_key();
        env::set_var(KEY_VAR, &test_key);

        let encryption = FieldEncryption::new().unwrap();
        let plaintext = "test message";
        let encrypted = encryption.encrypt_string(plaintext).unwrap();
        let decrypted = encryption.decrypt_string(&encrypted).unwrap();
        assert_eq!(plaintext, decrypted);

        env::remove_var(KEY_VAR);
        assert!(matches!(
            FieldEncryption::new(),
            Err(EncryptionError::MissingKey(_))
        ));
        // `_guard` restores the caller's original value.
    }

    #[test]
    fn test_env_key_with_wrong_length_is_rejected() {
        let _guard = EnvGuard::acquire();

        env::set_var(KEY_VAR, base64_encode(&[5u8; 16]));
        assert!(matches!(
            FieldEncryption::new(),
            Err(EncryptionError::InvalidKeyLength(_))
        ));
    }

    #[test]
    fn test_env_key_with_invalid_base64_is_rejected() {
        let _guard = EnvGuard::acquire();

        env::set_var(KEY_VAR, "not base64!");
        assert!(matches!(
            FieldEncryption::new(),
            Err(EncryptionError::KeyDerivationFailed(_))
        ));
    }

    #[test]
    fn test_high_level_encryption_functions() {
        // Both calls read the env var; hold the lock so it cannot change between them.
        let _guard = EnvGuard::acquire();

        let plaintext = "sensitive data";
        let encoded = encrypt_sensitive_field(plaintext).unwrap();
        let decoded = decrypt_sensitive_field(&encoded).unwrap();
        assert_eq!(plaintext, decoded);

        assert!(base64_decode(&encoded).is_ok());
    }

    #[test]
    fn test_fallback_when_encryption_disabled() {
        let _guard = EnvGuard::acquire();
        env::remove_var(KEY_VAR);

        let plaintext = "fallback test";

        let encoded = encrypt_sensitive_field(plaintext).unwrap();
        let decoded = decrypt_sensitive_field(&encoded).unwrap();
        assert_eq!(plaintext, decoded);

        // Without a key the value is only base64(JSON string), not encrypted.
        let expected_json = serde_json::to_string(plaintext).unwrap();
        let expected_b64 = base64_encode(expected_json.as_bytes());
        assert_eq!(encoded, expected_b64);
    }

    #[test]
    fn test_encrypted_field_without_key_is_an_error() {
        let _guard = EnvGuard::acquire();

        let encryption = FieldEncryption::new_with_key(&[8u8; 32]).unwrap();
        let encrypted = encryption.encrypt_string("needs a key").unwrap();

        env::remove_var(KEY_VAR);
        assert!(decrypt_sensitive_field(&encrypted).is_err());
    }

    #[test]
    fn test_core_encryption_methods() {
        let encryption = FieldEncryption::new_with_key(&[9u8; 32]).unwrap();
        let plaintext = "core encryption test";

        let encrypted = encryption.encrypt_string(plaintext).unwrap();
        let decrypted = encryption.decrypt_string(&encrypted).unwrap();
        assert_eq!(plaintext, decrypted);

        assert!(base64_decode(&encrypted).is_ok());
        assert!(!encrypted.contains("nonce"));
        assert!(!encrypted.contains("ciphertext"));
        assert!(!encrypted.contains('{'));
    }

    #[test]
    fn test_base64_encoding_hides_structure() {
        let encryption = FieldEncryption::new_with_key(&[7u8; 32]).unwrap();

        let plaintext = "secret message";
        let encrypted = encryption.encrypt_string(plaintext).unwrap();

        assert!(base64_decode(&encrypted).is_ok());

        assert!(!encrypted.contains("nonce"));
        assert!(!encrypted.contains("ciphertext"));
        assert!(!encrypted.contains("version"));
        assert!(!encrypted.contains('{'));
        assert!(!encrypted.contains('}'));

        let decrypted = encryption.decrypt_string(&encrypted).unwrap();
        assert_eq!(plaintext, decrypted);
    }
}