openzeppelin_relayer/services/aws_kms/
mod.rs

//! # AWS KMS Service Module
//!
//! This module provides integration with AWS KMS for secure key management
//! and cryptographic operations such as public key retrieval and message signing.
//!
//! Currently only EVM is supported.
//!
//! ## Features
//!
//! - Authentication through the default AWS credential provider chain
//! - Public key retrieval from KMS
//! - Message signing via KMS
//!
//! ## Architecture
//!
//! ```text
//! AwsKmsService (implements AwsKmsEvmService)
//!   ├── Authentication (via AwsKmsClient)
//!   ├── Public Key Retrieval (via AwsKmsClient)
//!   └── Message Signing (via AwsKmsClient)
//! ```
//! which is built on top of
//! ```text
//! AwsKmsClient (implements AwsKmsK256)
//!   ├── Authentication (via shared credentials)
//!   ├── Public Key Retrieval in DER Encoding
//!   └── Message Digest Signing in DER Encoding
//! ```
//! `AwsKmsK256` is mocked with `mockall` for unit testing, and the mock is
//! injected into `AwsKmsService`.
//!
32
33use alloy::primitives::keccak256;
34use async_trait::async_trait;
35use aws_config::{meta::region::RegionProviderChain, BehaviorVersion, Region};
36use aws_sdk_kms::{
37    primitives::Blob,
38    types::{MessageType, SigningAlgorithmSpec},
39    Client,
40};
41use once_cell::sync::Lazy;
42use serde::Serialize;
43use std::collections::HashMap;
44use tokio::sync::RwLock;
45
46use crate::{
47    models::{Address, AwsKmsSignerConfig},
48    services::signer::evm::utils::recover_evm_signature_from_der,
49    utils::{self, derive_ethereum_address_from_der},
50};
51
52#[cfg(test)]
53use mockall::{automock, mock};
54
55#[derive(Clone, Debug, thiserror::Error, Serialize)]
56pub enum AwsKmsError {
57    #[error("AWS KMS response parse error: {0}")]
58    ParseError(String),
59    #[error("AWS KMS config error: {0}")]
60    ConfigError(String),
61    #[error("AWS KMS get error: {0}")]
62    GetError(String),
63    #[error("AWS KMS signing error: {0}")]
64    SignError(String),
65    #[error("AWS KMS permissions error: {0}")]
66    PermissionError(String),
67    #[error("AWS KMS public key error: {0}")]
68    RecoveryError(#[from] utils::Secp256k1Error),
69    #[error("AWS KMS conversion error: {0}")]
70    ConvertError(String),
71    #[error("AWS KMS Other error: {0}")]
72    Other(String),
73}
74
75pub type AwsKmsResult<T> = Result<T, AwsKmsError>;
76
77#[async_trait]
78#[cfg_attr(test, automock)]
79pub trait AwsKmsEvmService: Send + Sync {
80    /// Returns the EVM address derived from the configured public key.
81    async fn get_evm_address(&self) -> AwsKmsResult<Address>;
82    /// Signs a payload using the EVM signing scheme (hashes before signing).
83    ///
84    /// This method applies keccak256 hashing before signing.
85    ///
86    /// **Use for:**
87    /// - Raw transaction data (TxLegacy, TxEip1559)
88    /// - EIP-191 personal messages
89    ///
90    /// **Note:** For EIP-712 typed data, use `sign_hash_evm()` to avoid double-hashing.
91    async fn sign_payload_evm(&self, payload: &[u8]) -> AwsKmsResult<Vec<u8>>;
92
93    /// Signs a pre-computed hash using the EVM signing scheme (no hashing).
94    ///
95    /// This method signs the hash directly without applying keccak256.
96    ///
97    /// **Use for:**
98    /// - EIP-712 typed data (already hashed)
99    /// - Pre-computed message digests
100    ///
101    /// **Note:** For raw data, use `sign_payload_evm()` instead.
102    async fn sign_hash_evm(&self, hash: &[u8; 32]) -> AwsKmsResult<Vec<u8>>;
103}
104
105#[async_trait]
106#[cfg_attr(test, automock)]
107pub trait AwsKmsK256: Send + Sync {
108    /// Fetches the DER-encoded public key from AWS KMS.
109    async fn get_der_public_key<'a, 'b>(&'a self, key_id: &'b str) -> AwsKmsResult<Vec<u8>>;
110    /// Signs a digest using EcdsaSha256 spec. Returns DER-encoded signature
111    async fn sign_digest<'a, 'b>(
112        &'a self,
113        key_id: &'b str,
114        digest: [u8; 32],
115    ) -> AwsKmsResult<Vec<u8>>;
116}
117
// Test-only mock of `AwsKmsClient`, generated by `mockall` as `MockAwsKmsClient`.
// `Clone` is mocked as well because `AwsKmsService<T>` requires `T: Clone`.
#[cfg(test)]
mock! {
    pub AwsKmsClient { }

    impl Clone for AwsKmsClient {
        fn clone(&self) -> Self;
    }

    #[async_trait]
    impl AwsKmsK256 for AwsKmsClient {
        async fn get_der_public_key<'a, 'b>(&'a self, key_id: &'b str) -> AwsKmsResult<Vec<u8>>;
        async fn sign_digest<'a, 'b>(
            &'a self,
            key_id: &'b str,
            digest: [u8; 32],
        ) -> AwsKmsResult<Vec<u8>>;
    }
}
136
137// Global cache - HashMap keyed by kms_key_id
138static KMS_DER_PK_CACHE: Lazy<RwLock<HashMap<String, Vec<u8>>>> =
139    Lazy::new(|| RwLock::new(HashMap::new()));
140
141#[derive(Debug, Clone)]
142pub struct AwsKmsClient {
143    inner: Client,
144}
145
146#[async_trait]
147impl AwsKmsK256 for AwsKmsClient {
148    async fn get_der_public_key<'a, 'b>(&'a self, key_id: &'b str) -> AwsKmsResult<Vec<u8>> {
149        // Try cache first with minimal lock time
150        let cached = {
151            let cache_read = KMS_DER_PK_CACHE.read().await;
152            cache_read.get(key_id).cloned()
153        };
154        if let Some(cached) = cached {
155            return Ok(cached);
156        }
157
158        // Fetch from AWS KMS
159        let get_output = self
160            .inner
161            .get_public_key()
162            .key_id(key_id)
163            .send()
164            .await
165            .map_err(|e| AwsKmsError::GetError(e.to_string()))?;
166
167        let der_pk_blob = get_output
168            .public_key
169            .ok_or(AwsKmsError::GetError(
170                "No public key blob found".to_string(),
171            ))?
172            .into_inner();
173
174        // Cache the result
175        let mut cache_write = KMS_DER_PK_CACHE.write().await;
176        cache_write.insert(key_id.to_string(), der_pk_blob.clone());
177        drop(cache_write);
178
179        Ok(der_pk_blob)
180    }
181
182    async fn sign_digest<'a, 'b>(
183        &'a self,
184        key_id: &'b str,
185        digest: [u8; 32],
186    ) -> AwsKmsResult<Vec<u8>> {
187        // Sign the digest with the AWS KMS
188        let sign_result = self
189            .inner
190            .sign()
191            .key_id(key_id)
192            .signing_algorithm(SigningAlgorithmSpec::EcdsaSha256)
193            .message_type(MessageType::Digest)
194            .message(Blob::new(digest))
195            .send()
196            .await;
197
198        // Process the result, extract DER signature
199        let der_signature = sign_result
200            .map_err(|e| AwsKmsError::PermissionError(e.to_string()))?
201            .signature
202            .ok_or(AwsKmsError::SignError(
203                "Signature not found in response".to_string(),
204            ))?
205            .into_inner();
206
207        Ok(der_signature)
208    }
209}
210
211#[derive(Debug, Clone)]
212pub struct AwsKmsService<T: AwsKmsK256 + Clone = AwsKmsClient> {
213    pub kms_key_id: String,
214    client: T,
215}
216
217impl AwsKmsService<AwsKmsClient> {
218    pub async fn new(config: AwsKmsSignerConfig) -> AwsKmsResult<Self> {
219        let region_provider =
220            RegionProviderChain::first_try(config.region.map(Region::new)).or_default_provider();
221
222        let auth_config = aws_config::defaults(BehaviorVersion::latest())
223            .region(region_provider)
224            .load()
225            .await;
226        let client = AwsKmsClient {
227            inner: Client::new(&auth_config),
228        };
229
230        Ok(Self {
231            kms_key_id: config.key_id,
232            client,
233        })
234    }
235}
236
#[cfg(test)]
impl<T: AwsKmsK256 + Clone> AwsKmsService<T> {
    /// Test-only constructor that injects an arbitrary KMS backend.
    ///
    /// Only `config.key_id` is used; the rest of `config` is ignored.
    pub fn new_for_testing(client: T, config: AwsKmsSignerConfig) -> Self {
        let kms_key_id = config.key_id;
        Self { kms_key_id, client }
    }
}
246
247impl<T: AwsKmsK256 + Clone> AwsKmsService<T> {
248    /// Common signing logic for EVM signatures.
249    ///
250    /// This internal helper eliminates duplication between `sign_payload_evm` and `sign_hash_evm`.
251    ///
252    /// # Parameters
253    /// * `digest` - The 32-byte hash to sign
254    /// * `original_bytes` - The original message bytes for recovery verification (if applicable)
255    /// * `use_prehash_recovery` - If true, recovers using hash directly; if false, uses original bytes
256    async fn sign_and_recover_evm(
257        &self,
258        digest: [u8; 32],
259        original_bytes: &[u8],
260        use_prehash_recovery: bool,
261    ) -> AwsKmsResult<Vec<u8>> {
262        // Sign the digest with AWS KMS
263        let der_signature = self.client.sign_digest(&self.kms_key_id, digest).await?;
264
265        // Get public key
266        let der_pk = self.client.get_der_public_key(&self.kms_key_id).await?;
267
268        // Use shared signature recovery logic
269        recover_evm_signature_from_der(
270            &der_signature,
271            &der_pk,
272            digest,
273            original_bytes,
274            use_prehash_recovery,
275        )
276        .map_err(|e| AwsKmsError::ParseError(e.to_string()))
277    }
278
279    /// Signs a payload using the EVM signing scheme (hashes before signing).
280    ///
281    /// This method applies keccak256 hashing before signing.
282    ///
283    /// **Use for:**
284    /// - Raw transaction data (TxLegacy, TxEip1559)
285    /// - EIP-191 personal messages
286    ///
287    /// **Note:** For EIP-712 typed data, use `sign_hash_evm()` to avoid double-hashing.
288    pub async fn sign_payload_evm(&self, bytes: &[u8]) -> AwsKmsResult<Vec<u8>> {
289        let digest = keccak256(bytes).0;
290        self.sign_and_recover_evm(digest, bytes, false).await
291    }
292
293    /// Signs a pre-computed hash using the EVM signing scheme (no hashing).
294    ///
295    /// This method signs the hash directly without applying keccak256.
296    ///
297    /// **Use for:**
298    /// - EIP-712 typed data (already hashed)
299    /// - Pre-computed message digests
300    ///
301    /// **Note:** For raw data, use `sign_payload_evm()` instead.
302    pub async fn sign_hash_evm(&self, hash: &[u8; 32]) -> AwsKmsResult<Vec<u8>> {
303        self.sign_and_recover_evm(*hash, hash, true).await
304    }
305}
306
307#[async_trait]
308impl<T: AwsKmsK256 + Clone> AwsKmsEvmService for AwsKmsService<T> {
309    async fn get_evm_address(&self) -> AwsKmsResult<Address> {
310        let der = self.client.get_der_public_key(&self.kms_key_id).await?;
311        let eth_address = derive_ethereum_address_from_der(&der)
312            .map_err(|e| AwsKmsError::ParseError(e.to_string()))?;
313        Ok(Address::Evm(eth_address))
314    }
315
316    async fn sign_payload_evm(&self, message: &[u8]) -> AwsKmsResult<Vec<u8>> {
317        let digest = keccak256(message).0;
318        self.sign_and_recover_evm(digest, message, false).await
319    }
320
321    async fn sign_hash_evm(&self, hash: &[u8; 32]) -> AwsKmsResult<Vec<u8>> {
322        // Delegates to the implementation method on AwsKmsService
323        self.sign_and_recover_evm(*hash, hash, true).await
324    }
325}
326
#[cfg(test)]
pub mod tests {
    use super::*;

    use alloy::primitives::utils::eip191_message;
    use k256::{
        ecdsa::SigningKey,
        elliptic_curve::rand_core::OsRng,
        pkcs8::{der::Encode, EncodePublicKey},
    };
    use mockall::predicate::{eq, ne};

    /// The only key id the mock client recognises.
    const VALID_KEY_ID: &str = "test-key-id";

    /// Builds a mock KMS client backed by a freshly generated secp256k1 key.
    ///
    /// Requests for `"test-key-id"` succeed using that key; any other key id
    /// yields `GetError` (public key) or `SignError` (signing). A clone of the
    /// signing key is returned so tests can compute expected values.
    pub fn setup_mock_kms_client() -> (MockAwsKmsClient, SigningKey) {
        let mut client = MockAwsKmsClient::new();
        let signing_key = SigningKey::random(&mut OsRng);
        let der_public_key = signing_key
            .verifying_key()
            .to_public_key_der()
            .unwrap()
            .to_der()
            .unwrap();

        client
            .expect_get_der_public_key()
            .with(eq(VALID_KEY_ID))
            .return_const(Ok(der_public_key));
        client
            .expect_get_der_public_key()
            .with(ne(VALID_KEY_ID))
            .return_const(Err(AwsKmsError::GetError("Key does not exist".to_string())));

        client
            .expect_sign_digest()
            .withf(|key_id, _| key_id != VALID_KEY_ID)
            .return_const(Err(AwsKmsError::SignError(
                "Key does not exist".to_string(),
            )));

        let key = signing_key.clone();
        client
            .expect_sign_digest()
            .withf(|key_id, _| key_id == VALID_KEY_ID)
            .returning(move |_, digest| {
                let (signature, _) = signing_key
                    .sign_prehash_recoverable(&digest)
                    .map_err(|e| AwsKmsError::SignError(e.to_string()))?;
                Ok(signature.to_der().as_bytes().to_vec())
            });

        client.expect_clone().return_once(MockAwsKmsClient::new);

        (client, key)
    }

    /// Creates a service wired to the mock client and configured with `key_id`.
    fn service_for(key_id: &str) -> (AwsKmsService<MockAwsKmsClient>, SigningKey) {
        let (mock_client, key) = setup_mock_kms_client();
        let service = AwsKmsService::new_for_testing(
            mock_client,
            AwsKmsSignerConfig {
                region: Some("us-east-1".to_string()),
                key_id: key_id.to_string(),
            },
        );
        (service, key)
    }

    #[tokio::test]
    async fn test_get_public_key() {
        let (kms, key) = service_for(VALID_KEY_ID);

        // Match exhaustively so that a non-EVM address or an error fails the
        // test instead of silently skipping the comparison.
        match kms.get_evm_address().await {
            Ok(Address::Evm(evm_address)) => {
                let expected_address = derive_ethereum_address_from_der(
                    key.verifying_key().to_public_key_der().unwrap().as_bytes(),
                )
                .unwrap();
                assert_eq!(expected_address, evm_address);
            }
            Ok(_) => panic!("expected an EVM address"),
            Err(err) => panic!("get_evm_address failed: {err}"),
        }
    }

    #[tokio::test]
    async fn test_get_public_key_fail() {
        let (kms, _) = service_for("invalid-key-id");

        let result = kms.get_evm_address().await;
        assert!(matches!(result, Err(AwsKmsError::GetError(_))));
    }

    #[tokio::test]
    async fn test_sign_digest() {
        let (kms, _) = service_for(VALID_KEY_ID);

        let message_eip = eip191_message(b"Hello World!");
        let result = kms.sign_payload_evm(&message_eip).await;

        // We just assert for Ok, since the pubkey recovery indicates the validity of signature
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_sign_digest_fail() {
        let (kms, _) = service_for("invalid-key-id");

        let message_eip = eip191_message(b"Hello World!");
        let result = kms.sign_payload_evm(&message_eip).await;
        assert!(matches!(result, Err(AwsKmsError::SignError(_))));
    }
}
458}