openzeppelin_relayer/utils/encryption.rs

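//! Field-level encryption utilities for sensitive data protection.
//!
//! This module provides secure encryption and decryption of sensitive fields using AES-256-GCM.
//! It is designed to be used transparently in the repository layer to protect data at rest.
//!
//! Note: when `STORAGE_ENCRYPTION_KEY` is not set, [`encrypt_sensitive_field`] falls back to
//! base64-encoding the value as a JSON string *without* encrypting it.
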
6use aes_gcm::{
7    aead::{rand_core::RngCore, Aead, KeyInit, OsRng},
8    Aes256Gcm, Key, Nonce,
9};
10use serde::{Deserialize, Serialize};
11use std::env;
12use thiserror::Error;
13use zeroize::Zeroize;
14
15use crate::{
16    models::SecretString,
17    utils::{base64_decode, base64_encode},
18};
19
/// Errors produced by field-level encryption, decryption and key loading.
// `Clone` lets callers turn the cached `&'static` error returned by `get_encryption`
// into an owned value.
#[derive(Error, Debug, Clone)]
pub enum EncryptionError {
    /// AEAD encryption failed, or the encrypted envelope / fallback value could not be
    /// serialized to JSON.
    #[error("Encryption failed: {0}")]
    EncryptionFailed(String),
    /// Authentication/decryption failed (e.g. wrong key or modified ciphertext), the
    /// recovered plaintext was not valid UTF-8, or a fallback value was not a JSON string.
    #[error("Decryption failed: {0}")]
    DecryptionFailed(String),
    /// The `STORAGE_ENCRYPTION_KEY` value could not be base64-decoded.
    #[error("Key derivation failed: {0}")]
    KeyDerivationFailed(String),
    /// Stored data is malformed: invalid base64, UTF-8 or JSON, an unsupported envelope
    /// version, or a nonce of the wrong length.
    #[error("Invalid encrypted data format: {0}")]
    InvalidFormat(String),
    /// `STORAGE_ENCRYPTION_KEY` is not set (or is not valid Unicode); the payload is a
    /// human-readable description.
    #[error("Missing encryption key environment variable: {0}")]
    MissingKey(String),
    /// The decoded key is not exactly 32 bytes (AES-256 needs a 256-bit key).
    #[error("Invalid key length: expected 32 bytes, got {0}")]
    InvalidKeyLength(usize),
}
35
/// Encrypted data container that holds the nonce and ciphertext.
///
/// This is the envelope that `FieldEncryption::encrypt_string` serializes to JSON (and then
/// base64-encodes). `decrypt_string` and `decrypt_sensitive_field` parse it back. The field
/// names are therefore part of the stored format and must stay stable.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EncryptedData {
    /// Base64-encoded nonce (12 bytes for GCM)
    pub nonce: String,
    /// Base64-encoded ciphertext with authentication tag
    pub ciphertext: String,
    /// Version for future compatibility; only version `1` is currently produced by
    /// `encrypt` and accepted by `decrypt`.
    pub version: u8,
}
46
/// Main encryption service for field-level encryption.
///
/// Wraps an AES-256-GCM cipher whose key is fixed at construction time. Every call to
/// `encrypt` draws a fresh random nonce.
#[derive(Clone)]
pub struct FieldEncryption {
    // Keyed AES-256-GCM cipher; cloning the struct duplicates this cipher state.
    cipher: Aes256Gcm,
}
52
53impl FieldEncryption {
54    /// Creates a new FieldEncryption instance using a key from environment variables
55    ///
56    /// # Environment Variables
57    /// - `STORAGE_ENCRYPTION_KEY`: Base64-encoded 32-byte encryption key
58    /// ```
59    pub fn new() -> Result<Self, EncryptionError> {
60        let key = Self::load_key_from_env()?;
61        let cipher = Aes256Gcm::new(&key);
62        Ok(Self { cipher })
63    }
64
65    /// Creates a new FieldEncryption instance with a provided key (for testing)
66    pub fn new_with_key(key: &[u8; 32]) -> Result<Self, EncryptionError> {
67        let key = Key::<Aes256Gcm>::from(*key);
68        let cipher = Aes256Gcm::new(&key);
69        Ok(Self { cipher })
70    }
71
72    /// Loads encryption key from environment variables
73    fn load_key_from_env() -> Result<Key<Aes256Gcm>, EncryptionError> {
74        let key = env::var("STORAGE_ENCRYPTION_KEY")
75            .map(|v| SecretString::new(&v))
76            .map_err(|_| {
77                EncryptionError::MissingKey("STORAGE_ENCRYPTION_KEY must be set".to_string())
78            })?;
79
80        key.as_str(|key_b64| {
81            let mut key_bytes = base64_decode(key_b64)
82                .map_err(|e| EncryptionError::KeyDerivationFailed(e.to_string()))?;
83            if key_bytes.len() != 32 {
84                key_bytes.zeroize(); // Explicit cleanup on error path
85                return Err(EncryptionError::InvalidKeyLength(key_bytes.len()));
86            }
87
88            let key_array: [u8; 32] = key_bytes
89                .as_slice()
90                .try_into()
91                .map_err(|_| EncryptionError::InvalidKeyLength(key_bytes.len()))?;
92            Ok(Key::<Aes256Gcm>::from(key_array))
93        })
94    }
95
96    /// Encrypts plaintext data and returns an EncryptedData structure
97    pub fn encrypt(&self, plaintext: &[u8]) -> Result<EncryptedData, EncryptionError> {
98        // Generate random 12-byte nonce for GCM
99        let mut nonce_bytes = [0u8; 12];
100        OsRng.fill_bytes(&mut nonce_bytes);
101        let nonce = &Nonce::from(nonce_bytes);
102
103        // Encrypt the data
104        let ciphertext = self
105            .cipher
106            .encrypt(nonce, plaintext)
107            .map_err(|e| EncryptionError::EncryptionFailed(e.to_string()))?;
108
109        Ok(EncryptedData {
110            nonce: base64_encode(&nonce_bytes),
111            ciphertext: base64_encode(&ciphertext),
112            version: 1,
113        })
114    }
115
116    /// Decrypts an EncryptedData structure and returns the plaintext
117    pub fn decrypt(&self, encrypted_data: &EncryptedData) -> Result<Vec<u8>, EncryptionError> {
118        if encrypted_data.version != 1 {
119            return Err(EncryptionError::InvalidFormat(format!(
120                "Unsupported encryption version: {}",
121                encrypted_data.version
122            )));
123        }
124
125        // Decode nonce and ciphertext
126        let nonce_bytes = base64_decode(&encrypted_data.nonce)
127            .map_err(|e| EncryptionError::InvalidFormat(format!("Invalid nonce: {e}")))?;
128
129        let ciphertext_bytes = base64_decode(&encrypted_data.ciphertext)
130            .map_err(|e| EncryptionError::InvalidFormat(format!("Invalid ciphertext: {e}")))?;
131
132        if nonce_bytes.len() != 12 {
133            return Err(EncryptionError::InvalidFormat(format!(
134                "Invalid nonce length: expected 12, got {}",
135                nonce_bytes.len()
136            )));
137        }
138
139        let nonce_array: [u8; 12] = nonce_bytes
140            .as_slice()
141            .try_into()
142            .map_err(|_| EncryptionError::InvalidFormat("Invalid nonce length".to_string()))?;
143        let nonce = &Nonce::from(nonce_array);
144
145        // Decrypt the data
146        let plaintext = self
147            .cipher
148            .decrypt(nonce, ciphertext_bytes.as_ref())
149            .map_err(|e| EncryptionError::DecryptionFailed(e.to_string()))?;
150
151        Ok(plaintext)
152    }
153
154    /// Encrypts a string and returns base64-encoded encrypted data (opaque format)
155    pub fn encrypt_string(&self, plaintext: &str) -> Result<String, EncryptionError> {
156        let encrypted_data = self.encrypt(plaintext.as_bytes())?;
157        let json_data = serde_json::to_string(&encrypted_data)
158            .map_err(|e| EncryptionError::EncryptionFailed(format!("Serialization failed: {e}")))?;
159
160        // Base64 encode the entire JSON to make it opaque
161        Ok(base64_encode(json_data.as_bytes()))
162    }
163
164    /// Decrypts a base64-encoded encrypted string
165    pub fn decrypt_string(&self, encrypted_base64: &str) -> Result<String, EncryptionError> {
166        // Decode from base64 to get the JSON
167        let json_bytes = base64_decode(encrypted_base64)
168            .map_err(|e| EncryptionError::InvalidFormat(format!("Invalid base64: {e}")))?;
169
170        let encrypted_json = String::from_utf8(json_bytes).map_err(|e| {
171            EncryptionError::InvalidFormat(format!("Invalid UTF-8 in decoded data: {e}"))
172        })?;
173
174        let encrypted_data: EncryptedData = serde_json::from_str(&encrypted_json)
175            .map_err(|e| EncryptionError::InvalidFormat(format!("Invalid JSON structure: {e}")))?;
176
177        let plaintext_bytes = self.decrypt(&encrypted_data)?;
178        String::from_utf8(plaintext_bytes).map_err(|e| {
179            EncryptionError::DecryptionFailed(format!("Invalid UTF-8 in plaintext: {e}"))
180        })
181    }
182
183    /// Utility function to generate a new encryption key for setup
184    pub fn generate_key() -> String {
185        let mut key = [0u8; 32];
186        OsRng.fill_bytes(&mut key);
187        let key_b64 = base64_encode(&key);
188
189        // Zero out the key from memory
190        let mut key_zeroize = key;
191        key_zeroize.zeroize();
192
193        key_b64
194    }
195
196    /// Checks if encryption is properly configured
197    pub fn is_configured() -> bool {
198        env::var("STORAGE_ENCRYPTION_KEY").is_ok()
199    }
200}
201
202/// Global encryption instance (lazy-initialized)
203static ENCRYPTION_INSTANCE: std::sync::OnceLock<Result<FieldEncryption, EncryptionError>> =
204    std::sync::OnceLock::new();
205
206/// Gets the global encryption instance
207pub fn get_encryption() -> Result<&'static FieldEncryption, &'static EncryptionError> {
208    ENCRYPTION_INSTANCE
209        .get_or_init(FieldEncryption::new)
210        .as_ref()
211}
212
213/// Encrypts sensitive data if encryption is configured, otherwise returns base64-encoded plaintext
214pub fn encrypt_sensitive_field(data: &str) -> Result<String, EncryptionError> {
215    if FieldEncryption::is_configured() {
216        match get_encryption() {
217            Ok(encryption) => encryption.encrypt_string(data),
218            Err(e) => Err(e.clone()),
219        }
220    } else {
221        // For development/testing when encryption is not configured,
222        // base64-encode the JSON string for consistency
223        let json_data = serde_json::to_string(data)
224            .map_err(|e| EncryptionError::EncryptionFailed(format!("JSON encoding failed: {e}")))?;
225        Ok(base64_encode(json_data.as_bytes()))
226    }
227}
228
229/// Decrypts sensitive data from base64 format
230pub fn decrypt_sensitive_field(data: &str) -> Result<String, EncryptionError> {
231    // Always try to decode base64 first
232    let json_bytes = base64_decode(data)
233        .map_err(|e| EncryptionError::InvalidFormat(format!("Invalid base64: {e}")))?;
234
235    let json_str = String::from_utf8(json_bytes)
236        .map_err(|e| EncryptionError::InvalidFormat(format!("Invalid UTF-8: {e}")))?;
237
238    // Try to parse as encrypted data first (if encryption is configured)
239    if FieldEncryption::is_configured() {
240        if let Ok(encryption) = get_encryption() {
241            // Check if this looks like encrypted data by trying to parse as EncryptedData
242            if let Ok(encrypted_data) = serde_json::from_str::<EncryptedData>(&json_str) {
243                // This is encrypted data, decrypt it
244                let plaintext_bytes = encryption.decrypt(&encrypted_data)?;
245                return String::from_utf8(plaintext_bytes).map_err(|e| {
246                    EncryptionError::DecryptionFailed(format!("Invalid UTF-8 in plaintext: {e}"))
247                });
248            }
249        }
250    }
251
252    // If we get here, either encryption is not configured, or this is fallback data
253    // Try to parse as JSON string (fallback format)
254    serde_json::from_str(&json_str)
255        .map_err(|e| EncryptionError::DecryptionFailed(format!("Invalid JSON string: {e}")))
256}
257
258/// Utility function to generate a new encryption key
259pub fn generate_encryption_key() -> String {
260    FieldEncryption::generate_key()
261}
262
#[cfg(test)]
mod tests {
    use super::*;
    use std::env;
    use std::sync::{Mutex, MutexGuard};

    /// Serializes every test that reads or mutates `STORAGE_ENCRYPTION_KEY`.
    ///
    /// The test harness runs tests on parallel threads, and environment variables are
    /// process-global. Without this lock, one test could remove the key while another is
    /// between encrypting and decrypting, which makes the suite flaky.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Acquires [`ENV_LOCK`], ignoring poisoning so that one failing test does not cascade
    /// into failures of every other environment-dependent test.
    fn lock_env() -> MutexGuard<'static, ()> {
        ENV_LOCK.lock().unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    /// Snapshot of `STORAGE_ENCRYPTION_KEY` that is restored when dropped, including
    /// when the test panics. If the variable was originally unset, it is removed again.
    struct EnvKeyGuard {
        original: Option<String>,
    }

    impl EnvKeyGuard {
        fn capture() -> Self {
            Self {
                original: env::var("STORAGE_ENCRYPTION_KEY").ok(),
            }
        }
    }

    impl Drop for EnvKeyGuard {
        fn drop(&mut self) {
            match &self.original {
                Some(key) => env::set_var("STORAGE_ENCRYPTION_KEY", key),
                None => env::remove_var("STORAGE_ENCRYPTION_KEY"),
            }
        }
    }

    #[test]
    fn test_encrypt_decrypt_data() {
        let key = [0u8; 32]; // Test key
        let encryption = FieldEncryption::new_with_key(&key).unwrap();

        let plaintext = b"This is a secret message!";
        let encrypted = encryption.encrypt(plaintext).unwrap();
        let decrypted = encryption.decrypt(&encrypted).unwrap();

        assert_eq!(plaintext, decrypted.as_slice());
    }

    #[test]
    fn test_encrypt_decrypt_string() {
        let key = [1u8; 32]; // Different test key
        let encryption = FieldEncryption::new_with_key(&key).unwrap();

        let plaintext = "Sensitive API key: sk-1234567890abcdef";
        let encrypted = encryption.encrypt_string(plaintext).unwrap();
        let decrypted = encryption.decrypt_string(&encrypted).unwrap();

        assert_eq!(plaintext, decrypted);
    }

    #[test]
    fn test_different_keys_produce_different_results() {
        let key1 = [1u8; 32];
        let key2 = [2u8; 32];
        let encryption1 = FieldEncryption::new_with_key(&key1).unwrap();
        let encryption2 = FieldEncryption::new_with_key(&key2).unwrap();

        let plaintext = "secret";
        let encrypted1 = encryption1.encrypt_string(plaintext).unwrap();
        let encrypted2 = encryption2.encrypt_string(plaintext).unwrap();

        assert_ne!(encrypted1, encrypted2);

        // Each should decrypt with their own key
        assert_eq!(encryption1.decrypt_string(&encrypted1).unwrap(), plaintext);
        assert_eq!(encryption2.decrypt_string(&encrypted2).unwrap(), plaintext);

        // But not with the other key
        assert!(encryption1.decrypt_string(&encrypted2).is_err());
        assert!(encryption2.decrypt_string(&encrypted1).is_err());
    }

    #[test]
    fn test_nonce_uniqueness() {
        let key = [3u8; 32];
        let encryption = FieldEncryption::new_with_key(&key).unwrap();

        let plaintext = "same message";
        let encrypted1 = encryption.encrypt_string(plaintext).unwrap();
        let encrypted2 = encryption.encrypt_string(plaintext).unwrap();

        // Same plaintext should produce different ciphertext due to random nonces
        assert_ne!(encrypted1, encrypted2);

        // Both should decrypt to the same plaintext
        assert_eq!(encryption.decrypt_string(&encrypted1).unwrap(), plaintext);
        assert_eq!(encryption.decrypt_string(&encrypted2).unwrap(), plaintext);
    }

    #[test]
    fn test_invalid_encrypted_data() {
        let key = [4u8; 32];
        let encryption = FieldEncryption::new_with_key(&key).unwrap();

        // Test with invalid base64
        assert!(encryption.decrypt_string("invalid base64!").is_err());

        // Test with valid base64 but invalid JSON inside
        assert!(encryption
            .decrypt_string(&base64_encode(b"not json"))
            .is_err());

        // Test with valid base64 but wrong JSON structure inside
        let invalid_json_b64 = base64_encode(b"{\"wrong\": \"structure\"}");
        assert!(encryption.decrypt_string(&invalid_json_b64).is_err());

        // Well-formed envelope JSON whose nonce/ciphertext are bogus must still be rejected
        assert!(encryption
            .decrypt_string(&base64_encode(
                b"{\"nonce\":\"test\",\"ciphertext\":\"test\",\"version\":1}"
            ))
            .is_err());
    }

    #[test]
    fn test_generate_key() {
        let key1 = FieldEncryption::generate_key();
        let key2 = FieldEncryption::generate_key();

        // Keys should be different
        assert_ne!(key1, key2);

        // Keys should be valid base64
        assert!(base64_decode(&key1).is_ok());
        assert!(base64_decode(&key2).is_ok());

        // Decoded keys should be 32 bytes
        assert_eq!(base64_decode(&key1).unwrap().len(), 32);
        assert_eq!(base64_decode(&key2).unwrap().len(), 32);
    }

    #[test]
    fn test_env_key_loading() {
        // Declared lock-first so `_restore` drops (restoring the env) while the lock is held.
        let _env_lock = lock_env();
        let _restore = EnvKeyGuard::capture();

        // Test base64 key
        let test_key = FieldEncryption::generate_key();
        env::set_var("STORAGE_ENCRYPTION_KEY", &test_key);

        let encryption = FieldEncryption::new().unwrap();
        let plaintext = "test message";
        let encrypted = encryption.encrypt_string(plaintext).unwrap();
        let decrypted = encryption.decrypt_string(&encrypted).unwrap();
        assert_eq!(plaintext, decrypted);

        // Test missing key
        env::remove_var("STORAGE_ENCRYPTION_KEY");
        assert!(FieldEncryption::new().is_err());
    }

    #[test]
    fn test_env_key_with_wrong_length_is_rejected() {
        let _env_lock = lock_env();
        let _restore = EnvKeyGuard::capture();

        // Valid base64, but only 16 bytes once decoded
        env::set_var("STORAGE_ENCRYPTION_KEY", base64_encode(&[0u8; 16]));
        assert!(matches!(
            FieldEncryption::new(),
            Err(EncryptionError::InvalidKeyLength(_))
        ));
    }

    #[test]
    fn test_high_level_encryption_functions() {
        // Reads the env var via `is_configured`, so it must not overlap env-mutating tests.
        let _env_lock = lock_env();

        let plaintext = "sensitive data";

        // Test that the high-level encrypt/decrypt functions work together
        let encoded = encrypt_sensitive_field(plaintext).unwrap();
        let decoded = decrypt_sensitive_field(&encoded).unwrap();
        assert_eq!(plaintext, decoded);

        // All outputs should be base64-encoded (whether encrypted or fallback)
        assert!(base64_decode(&encoded).is_ok());

        // Just verify it works - don't make assumptions about internal format
        // since global encryption state may vary between test runs
    }

    #[test]
    fn test_fallback_when_encryption_disabled() {
        let _env_lock = lock_env();
        let _restore = EnvKeyGuard::capture();

        // Temporarily clear encryption key to test fallback
        env::remove_var("STORAGE_ENCRYPTION_KEY");

        let plaintext = "fallback test";

        // Should use fallback mode (base64-encoded JSON)
        let encoded = encrypt_sensitive_field(plaintext).unwrap();
        let decoded = decrypt_sensitive_field(&encoded).unwrap();
        assert_eq!(plaintext, decoded);

        // Should be base64-encoded JSON
        let expected_json = serde_json::to_string(plaintext).unwrap();
        let expected_b64 = base64_encode(expected_json.as_bytes());
        assert_eq!(encoded, expected_b64);
    }

    #[test]
    fn test_core_encryption_methods() {
        let key = [9u8; 32];
        let encryption = FieldEncryption::new_with_key(&key).unwrap();
        let plaintext = "core encryption test";

        // Test core encryption methods directly
        let encrypted = encryption.encrypt_string(plaintext).unwrap();
        let decrypted = encryption.decrypt_string(&encrypted).unwrap();
        assert_eq!(plaintext, decrypted);

        // Should be base64-encoded
        assert!(base64_decode(&encrypted).is_ok());
        // Should not contain readable structure
        assert!(!encrypted.contains("nonce"));
        assert!(!encrypted.contains("ciphertext"));
        assert!(!encrypted.contains("{"));
    }

    #[test]
    fn test_base64_encoding_hides_structure() {
        let key = [7u8; 32];
        let encryption = FieldEncryption::new_with_key(&key).unwrap();

        let plaintext = "secret message";
        let encrypted = encryption.encrypt_string(plaintext).unwrap();

        // Should be valid base64
        assert!(base64_decode(&encrypted).is_ok());

        // Should not contain readable JSON structure
        assert!(!encrypted.contains("nonce"));
        assert!(!encrypted.contains("ciphertext"));
        assert!(!encrypted.contains("version"));
        assert!(!encrypted.contains("{"));
        assert!(!encrypted.contains("}"));

        // Should decrypt correctly
        let decrypted = encryption.decrypt_string(&encrypted).unwrap();
        assert_eq!(plaintext, decrypted);
    }
}