Skip to content

Commit

Permalink
[kms] Adjust Decrypt key validity checking
Browse files Browse the repository at this point in the history
  • Loading branch information
dzbarsky committed Dec 1, 2023
1 parent 2d0c2f0 commit b462db8
Showing 1 changed file with 28 additions and 11 deletions.
39 changes: 28 additions & 11 deletions services/kms/kms.go
Original file line number Diff line number Diff line change
Expand Up @@ -647,6 +647,9 @@ func (k *KMS) GenerateDataKeyPair(input GenerateDataKeyPairInput) (*GenerateData
default:
return nil, ValidationException("Unknown value for KeyPair Spec")
}
if err != nil {
return nil, KMSInternalException(err.Error())
}

serializedPublicKey, err := x509.MarshalPKIXPublicKey(pkey.Public())
if err != nil {
Expand Down Expand Up @@ -839,30 +842,29 @@ func (k *KMS) Decrypt(input DecryptInput) (*DecryptOutput, *awserrors.Error) {
input.EncryptionAlgorithm = "SYMMETRIC_DEFAULT"
}

keyArn := input.KeyId
keyId := input.KeyId

ciphertext := input.CiphertextBlob
if len(ciphertext) == 0 {
return nil, InvalidCiphertextException("")
}

k.mu.Lock()
defer k.mu.Unlock()

encryptionKey := k.lockedGetKey(keyArn)

if keyArn == "" || encryptionKey.IsAES() {
// AES can pack keyId into the ciphertext
// This logic is the opposite of Key.Encrypt
if keyId == "" {
// Passing KeyId is optional for symmetric encyption.
// Run the opposite of Key.Encrypt to unpack the key that was used
data := ciphertext
keyArnLen, data := uint8(data[0]), data[1:]
if len(data) < 4+int(keyArnLen) {
return nil, InvalidCiphertextException("")
}
keyArn, ciphertext = string(data[:keyArnLen]), data[keyArnLen:]
keyId = string(data[:keyArnLen])
}

encryptionKey = k.lockedGetKey(keyArn)
k.mu.Lock()
defer k.mu.Unlock()

encryptionKey := k.lockedGetKey(keyId)

if encryptionKey == nil {
return nil, NotFoundException("")
}
Expand All @@ -871,6 +873,21 @@ func (k *KMS) Decrypt(input DecryptInput) (*DecryptOutput, *awserrors.Error) {
return nil, DisabledException("")
}

if encryptionKey.IsAES() {
if input.KeyId != "" && input.KeyId != encryptionKey.Id() {
// TODO(zbarsky): we should flag this
// return nil, IncorrectKeyException("")
k.logger.Info("WRONG aes key", "expected", input.KeyId, "actual", encryptionKey.Id())
}
// AES can pack keyId into the ciphertext
// This logic is the opposite of Key.Encrypt
keyArnLen := uint8(ciphertext[0])
if len(ciphertext) < 5+int(keyArnLen) {
return nil, InvalidCiphertextException("")
}
ciphertext = ciphertext[keyArnLen+1:]
}

plaintext, err := encryptionKey.Decrypt(ciphertext, input.EncryptionAlgorithm, input.EncryptionContext)
if err != nil {
if errors.Is(err, key.ErrBadAlgorithm) {
Expand Down

0 comments on commit b462db8

Please sign in to comment.