openzeppelin_relayer/services/aws_kms/
mod.rs1use alloy::primitives::keccak256;
34use async_trait::async_trait;
35use aws_config::{meta::region::RegionProviderChain, BehaviorVersion, Region};
36use aws_sdk_kms::{
37 primitives::Blob,
38 types::{MessageType, SigningAlgorithmSpec},
39 Client,
40};
41use once_cell::sync::Lazy;
42use serde::Serialize;
43use std::collections::HashMap;
44use tokio::sync::RwLock;
45
46use crate::{
47 models::{Address, AwsKmsSignerConfig},
48 services::signer::evm::utils::recover_evm_signature_from_der,
49 utils::{self, derive_ethereum_address_from_der},
50};
51
52#[cfg(test)]
53use mockall::{automock, mock};
54
55#[derive(Clone, Debug, thiserror::Error, Serialize)]
56pub enum AwsKmsError {
57 #[error("AWS KMS response parse error: {0}")]
58 ParseError(String),
59 #[error("AWS KMS config error: {0}")]
60 ConfigError(String),
61 #[error("AWS KMS get error: {0}")]
62 GetError(String),
63 #[error("AWS KMS signing error: {0}")]
64 SignError(String),
65 #[error("AWS KMS permissions error: {0}")]
66 PermissionError(String),
67 #[error("AWS KMS public key error: {0}")]
68 RecoveryError(#[from] utils::Secp256k1Error),
69 #[error("AWS KMS conversion error: {0}")]
70 ConvertError(String),
71 #[error("AWS KMS Other error: {0}")]
72 Other(String),
73}
74
75pub type AwsKmsResult<T> = Result<T, AwsKmsError>;
76
77#[async_trait]
78#[cfg_attr(test, automock)]
79pub trait AwsKmsEvmService: Send + Sync {
80 async fn get_evm_address(&self) -> AwsKmsResult<Address>;
82 async fn sign_payload_evm(&self, payload: &[u8]) -> AwsKmsResult<Vec<u8>>;
92
93 async fn sign_hash_evm(&self, hash: &[u8; 32]) -> AwsKmsResult<Vec<u8>>;
103}
104
105#[async_trait]
106#[cfg_attr(test, automock)]
107pub trait AwsKmsK256: Send + Sync {
108 async fn get_der_public_key<'a, 'b>(&'a self, key_id: &'b str) -> AwsKmsResult<Vec<u8>>;
110 async fn sign_digest<'a, 'b>(
112 &'a self,
113 key_id: &'b str,
114 digest: [u8; 32],
115 ) -> AwsKmsResult<Vec<u8>>;
116}
117
// Test-only mock of `AwsKmsClient` (generated as `MockAwsKmsClient`). It also mocks `Clone`,
// because `AwsKmsService` requires its client type to be `Clone`.
#[cfg(test)]
mock! {
    pub AwsKmsClient { }

    impl Clone for AwsKmsClient {
        fn clone(&self) -> Self;
    }

    #[async_trait]
    impl AwsKmsK256 for AwsKmsClient {
        async fn get_der_public_key<'a, 'b>(&'a self, key_id: &'b str) -> AwsKmsResult<Vec<u8>>;
        async fn sign_digest<'a, 'b>(
            &'a self,
            key_id: &'b str,
            digest: [u8; 32],
        ) -> AwsKmsResult<Vec<u8>>;
    }
}
136
137static KMS_DER_PK_CACHE: Lazy<RwLock<HashMap<String, Vec<u8>>>> =
139 Lazy::new(|| RwLock::new(HashMap::new()));
140
141#[derive(Debug, Clone)]
142pub struct AwsKmsClient {
143 inner: Client,
144}
145
146#[async_trait]
147impl AwsKmsK256 for AwsKmsClient {
148 async fn get_der_public_key<'a, 'b>(&'a self, key_id: &'b str) -> AwsKmsResult<Vec<u8>> {
149 let cached = {
151 let cache_read = KMS_DER_PK_CACHE.read().await;
152 cache_read.get(key_id).cloned()
153 };
154 if let Some(cached) = cached {
155 return Ok(cached);
156 }
157
158 let get_output = self
160 .inner
161 .get_public_key()
162 .key_id(key_id)
163 .send()
164 .await
165 .map_err(|e| AwsKmsError::GetError(e.to_string()))?;
166
167 let der_pk_blob = get_output
168 .public_key
169 .ok_or(AwsKmsError::GetError(
170 "No public key blob found".to_string(),
171 ))?
172 .into_inner();
173
174 let mut cache_write = KMS_DER_PK_CACHE.write().await;
176 cache_write.insert(key_id.to_string(), der_pk_blob.clone());
177 drop(cache_write);
178
179 Ok(der_pk_blob)
180 }
181
182 async fn sign_digest<'a, 'b>(
183 &'a self,
184 key_id: &'b str,
185 digest: [u8; 32],
186 ) -> AwsKmsResult<Vec<u8>> {
187 let sign_result = self
189 .inner
190 .sign()
191 .key_id(key_id)
192 .signing_algorithm(SigningAlgorithmSpec::EcdsaSha256)
193 .message_type(MessageType::Digest)
194 .message(Blob::new(digest))
195 .send()
196 .await;
197
198 let der_signature = sign_result
200 .map_err(|e| AwsKmsError::PermissionError(e.to_string()))?
201 .signature
202 .ok_or(AwsKmsError::SignError(
203 "Signature not found in response".to_string(),
204 ))?
205 .into_inner();
206
207 Ok(der_signature)
208 }
209}
210
211#[derive(Debug, Clone)]
212pub struct AwsKmsService<T: AwsKmsK256 + Clone = AwsKmsClient> {
213 pub kms_key_id: String,
214 client: T,
215}
216
217impl AwsKmsService<AwsKmsClient> {
218 pub async fn new(config: AwsKmsSignerConfig) -> AwsKmsResult<Self> {
219 let region_provider =
220 RegionProviderChain::first_try(config.region.map(Region::new)).or_default_provider();
221
222 let auth_config = aws_config::defaults(BehaviorVersion::latest())
223 .region(region_provider)
224 .load()
225 .await;
226 let client = AwsKmsClient {
227 inner: Client::new(&auth_config),
228 };
229
230 Ok(Self {
231 kms_key_id: config.key_id,
232 client,
233 })
234 }
235}
236
#[cfg(test)]
impl<T: AwsKmsK256 + Clone> AwsKmsService<T> {
    /// Test-only constructor that injects an arbitrary (typically mocked) client.
    ///
    /// Only `config.key_id` is used; `config.region` is ignored.
    pub fn new_for_testing(client: T, config: AwsKmsSignerConfig) -> Self {
        let AwsKmsSignerConfig { key_id, .. } = config;
        Self {
            kms_key_id: key_id,
            client,
        }
    }
}
246
247impl<T: AwsKmsK256 + Clone> AwsKmsService<T> {
248 async fn sign_and_recover_evm(
257 &self,
258 digest: [u8; 32],
259 original_bytes: &[u8],
260 use_prehash_recovery: bool,
261 ) -> AwsKmsResult<Vec<u8>> {
262 let der_signature = self.client.sign_digest(&self.kms_key_id, digest).await?;
264
265 let der_pk = self.client.get_der_public_key(&self.kms_key_id).await?;
267
268 recover_evm_signature_from_der(
270 &der_signature,
271 &der_pk,
272 digest,
273 original_bytes,
274 use_prehash_recovery,
275 )
276 .map_err(|e| AwsKmsError::ParseError(e.to_string()))
277 }
278
279 pub async fn sign_payload_evm(&self, bytes: &[u8]) -> AwsKmsResult<Vec<u8>> {
289 let digest = keccak256(bytes).0;
290 self.sign_and_recover_evm(digest, bytes, false).await
291 }
292
293 pub async fn sign_hash_evm(&self, hash: &[u8; 32]) -> AwsKmsResult<Vec<u8>> {
303 self.sign_and_recover_evm(*hash, hash, true).await
304 }
305}
306
307#[async_trait]
308impl<T: AwsKmsK256 + Clone> AwsKmsEvmService for AwsKmsService<T> {
309 async fn get_evm_address(&self) -> AwsKmsResult<Address> {
310 let der = self.client.get_der_public_key(&self.kms_key_id).await?;
311 let eth_address = derive_ethereum_address_from_der(&der)
312 .map_err(|e| AwsKmsError::ParseError(e.to_string()))?;
313 Ok(Address::Evm(eth_address))
314 }
315
316 async fn sign_payload_evm(&self, message: &[u8]) -> AwsKmsResult<Vec<u8>> {
317 let digest = keccak256(message).0;
318 self.sign_and_recover_evm(digest, message, false).await
319 }
320
321 async fn sign_hash_evm(&self, hash: &[u8; 32]) -> AwsKmsResult<Vec<u8>> {
322 self.sign_and_recover_evm(*hash, hash, true).await
324 }
325}
326
#[cfg(test)]
pub mod tests {
    use super::*;

    use alloy::primitives::utils::eip191_message;
    use k256::{
        ecdsa::SigningKey,
        elliptic_curve::rand_core::OsRng,
        pkcs8::{der::Encode, EncodePublicKey},
    };
    use mockall::predicate::{eq, ne};

    /// Builds a mock KMS client backed by a freshly generated secp256k1 key.
    ///
    /// For the key id `"test-key-id"`, the mock returns the key's DER public key and DER
    /// signatures produced locally. For any other key id, public-key lookups fail with
    /// [`AwsKmsError::GetError`] and signing fails with [`AwsKmsError::SignError`]. The returned
    /// [`SigningKey`] lets tests compute expected values.
    pub fn setup_mock_kms_client() -> (MockAwsKmsClient, SigningKey) {
        let mut client = MockAwsKmsClient::new();
        let signing_key = SigningKey::random(&mut OsRng);
        let s = signing_key
            .verifying_key()
            .to_public_key_der()
            .unwrap()
            .to_der()
            .unwrap();

        client
            .expect_get_der_public_key()
            .with(eq("test-key-id"))
            .return_const(Ok(s));
        client
            .expect_get_der_public_key()
            .with(ne("test-key-id"))
            .return_const(Err(AwsKmsError::GetError("Key does not exist".to_string())));

        client
            .expect_sign_digest()
            .withf(|key_id, _| key_id.ne("test-key-id"))
            .return_const(Err(AwsKmsError::SignError(
                "Key does not exist".to_string(),
            )));

        let key = signing_key.clone();
        client
            .expect_sign_digest()
            .withf(|key_id, _| key_id.eq("test-key-id"))
            .returning(move |_, digest| {
                let (signature, _) = signing_key
                    .sign_prehash_recoverable(&digest)
                    .map_err(|e| AwsKmsError::SignError(e.to_string()))?;
                let der_signature = signature.to_der().as_bytes().to_vec();
                Ok(der_signature)
            });

        client.expect_clone().return_once(MockAwsKmsClient::new);

        (client, key)
    }

    /// Signer config pointing at `key_id` in a fixed region (ignored by the mock).
    fn test_config(key_id: &str) -> AwsKmsSignerConfig {
        AwsKmsSignerConfig {
            region: Some("us-east-1".to_string()),
            key_id: key_id.to_string(),
        }
    }

    #[tokio::test]
    async fn test_get_public_key() {
        let (mock_client, key) = setup_mock_kms_client();
        let kms = AwsKmsService::new_for_testing(mock_client, test_config("test-key-id"));

        let expected_address = derive_ethereum_address_from_der(
            key.verifying_key().to_public_key_der().unwrap().as_bytes(),
        )
        .unwrap();

        // Every outcome other than the expected EVM address must fail the test.
        match kms.get_evm_address().await {
            Ok(Address::Evm(evm_address)) => assert_eq!(expected_address, evm_address),
            Ok(_) => panic!("expected an EVM address"),
            Err(err) => panic!("get_evm_address failed: {err}"),
        }
    }

    #[tokio::test]
    async fn test_get_public_key_fail() {
        let (mock_client, _) = setup_mock_kms_client();
        let kms = AwsKmsService::new_for_testing(mock_client, test_config("invalid-key-id"));

        let result = kms.get_evm_address().await;
        assert!(matches!(result, Err(AwsKmsError::GetError(_))));
    }

    #[tokio::test]
    async fn test_sign_digest() {
        let (mock_client, _) = setup_mock_kms_client();
        let kms = AwsKmsService::new_for_testing(mock_client, test_config("test-key-id"));

        let message_eip = eip191_message(b"Hello World!");
        let signature = kms
            .sign_payload_evm(&message_eip)
            .await
            .expect("sign_payload_evm should succeed for a known key");
        assert!(!signature.is_empty());
    }

    #[tokio::test]
    async fn test_sign_digest_fail() {
        let (mock_client, _) = setup_mock_kms_client();
        let kms = AwsKmsService::new_for_testing(mock_client, test_config("invalid-key-id"));

        let message_eip = eip191_message(b"Hello World!");
        let result = kms.sign_payload_evm(&message_eip).await;
        assert!(matches!(result, Err(AwsKmsError::SignError(_))));
    }

    #[tokio::test]
    async fn test_sign_hash() {
        let (mock_client, _) = setup_mock_kms_client();
        let kms = AwsKmsService::new_for_testing(mock_client, test_config("test-key-id"));

        let hash = keccak256(eip191_message(b"Hello World!")).0;
        let signature = kms
            .sign_hash_evm(&hash)
            .await
            .expect("sign_hash_evm should succeed for a known key");
        assert!(!signature.is_empty());
    }

    #[tokio::test]
    async fn test_sign_hash_fail() {
        let (mock_client, _) = setup_mock_kms_client();
        let kms = AwsKmsService::new_for_testing(mock_client, test_config("invalid-key-id"));

        let hash = keccak256(b"Hello World!").0;
        let result = kms.sign_hash_evm(&hash).await;
        // Signing happens before the public-key lookup, so the sign failure surfaces first.
        assert!(matches!(result, Err(AwsKmsError::SignError(_))));
    }
}