Skip to content

Commit

Permalink
WIP ECDH
Browse files Browse the repository at this point in the history
  • Loading branch information
whyoleg committed Jul 27, 2024
1 parent c96af43 commit bb815af
Show file tree
Hide file tree
Showing 7 changed files with 60 additions and 49 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,13 @@ public interface ECDH : EC<ECDH.PublicKey, ECDH.PrivateKey, ECDH.KeyPair> {

@SubclassOptInRequired(CryptographyProviderApi::class)
public interface PublicKey : EC.PublicKey {
public fun sharedSecretDerivation(): SharedSecretDerivation<PrivateKey>
public fun secretDerivation(): SecretDerivation<PrivateKey>
public fun asyncSecretDerivation(): AsyncSecretDerivation<PrivateKey>
}

@SubclassOptInRequired(CryptographyProviderApi::class)
public interface PrivateKey : EC.PrivateKey {
public fun sharedSecretDerivation(): SharedSecretDerivation<PublicKey>
public fun secretDerivation(): SecretDerivation<PublicKey>
public fun asyncSecretDerivation(): AsyncSecretDerivation<PublicKey>
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
/*
* Copyright (c) 2024 Oleg Yukhnevich. Use of this source code is governed by the Apache 2.0 license.
*/

package dev.whyoleg.cryptography.operations

import dev.whyoleg.cryptography.*
import dev.whyoleg.cryptography.materials.key.*

@SubclassOptInRequired(CryptographyProviderApi::class)
public interface SecretDerivation<K : Key> {
public fun deriveSecret(other: K): ByteArray
}

@SubclassOptInRequired(CryptographyProviderApi::class)
public interface AsyncSecretDerivation<K : Key> {
public suspend fun deriveSecret(other: K): ByteArray
}

@CryptographyProviderApi
public fun <K : Key> SecretDerivation<K>.asAsync(): AsyncSecretDerivation<K> = object : AsyncSecretDerivation<K> {
override suspend fun deriveSecret(other: K): ByteArray = this@asAsync.deriveSecret(other)
}

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,10 @@ abstract class EcdhCompatibilityTest(
) { otherKeyPair, otherKeyReference, _ ->

val secrets = listOf(
keyPair.privateKey.sharedSecretDerivation().deriveSharedSecret(otherKeyPair.publicKey),
keyPair.publicKey.sharedSecretDerivation().deriveSharedSecret(otherKeyPair.privateKey),
otherKeyPair.privateKey.sharedSecretDerivation().deriveSharedSecret(keyPair.publicKey),
otherKeyPair.publicKey.sharedSecretDerivation().deriveSharedSecret(keyPair.privateKey),
keyPair.privateKey.secretDerivation().deriveSharedSecret(otherKeyPair.publicKey),
keyPair.publicKey.secretDerivation().deriveSharedSecret(otherKeyPair.privateKey),
otherKeyPair.privateKey.secretDerivation().deriveSharedSecret(keyPair.publicKey),
otherKeyPair.publicKey.secretDerivation().deriveSharedSecret(keyPair.privateKey),
)

repeat(secrets.size) { i ->
Expand Down Expand Up @@ -69,12 +69,12 @@ abstract class EcdhCompatibilityTest(
otherPrivateKeys.forEach { otherPrivateKey ->
assertContentEquals(
sharedSecret,
publicKey.sharedSecretDerivation().deriveSharedSecret(otherPrivateKey),
publicKey.secretDerivation().deriveSharedSecret(otherPrivateKey),
"Public + Other Private"
)
assertContentEquals(
sharedSecret,
otherPrivateKey.sharedSecretDerivation().deriveSharedSecret(publicKey),
otherPrivateKey.secretDerivation().deriveSharedSecret(publicKey),
"Other Private + Public"
)
}
Expand All @@ -83,12 +83,12 @@ abstract class EcdhCompatibilityTest(
otherPublicKeys.forEach { otherPublicKey ->
assertContentEquals(
sharedSecret,
otherPublicKey.sharedSecretDerivation().deriveSharedSecret(privateKey),
otherPublicKey.secretDerivation().deriveSharedSecret(privateKey),
"Other Public + Private"
)
assertContentEquals(
sharedSecret,
privateKey.sharedSecretDerivation().deriveSharedSecret(otherPublicKey),
privateKey.secretDerivation().deriveSharedSecret(otherPublicKey),
"Private + Other Public"
)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,23 +22,26 @@ internal class JdkEcdh(state: JdkCryptographyState) : JdkEc<ECDH.PublicKey, ECDH
private val state: JdkCryptographyState,
val key: JPublicKey,
) : ECDH.PublicKey, BaseEcPublicKey(key) {
override fun sharedSecretDerivation(): SharedSecretDerivation<ECDH.PrivateKey> = EcdhPublicKeySecretDerivation(state, key)
override fun secretDerivation(): SecretDerivation<ECDH.PrivateKey> = EcdhPublicKeySecretDerivation(state, key)
override fun asyncSecretDerivation(): AsyncSecretDerivation<ECDH.PrivateKey> {
TODO("Not yet implemented")
}
}

private class EcdhPrivateKey(
private val state: JdkCryptographyState,
val key: JPrivateKey,
) : ECDH.PrivateKey, BaseEcPrivateKey(key) {
override fun sharedSecretDerivation(): SharedSecretDerivation<ECDH.PublicKey> = EcdhPrivateKeySecretDerivation(state, key)
override fun secretDerivation(): SecretDerivation<ECDH.PublicKey> = EcdhPrivateKeySecretDerivation(state, key)
}

private class EcdhPublicKeySecretDerivation(
private val state: JdkCryptographyState,
private val publicKey: JPublicKey,
) : SharedSecretDerivation<ECDH.PrivateKey> {
) : SecretDerivation<ECDH.PrivateKey> {
private val keyAgreement = state.keyAgreement("ECDH")

override fun deriveSharedSecretBlocking(other: ECDH.PrivateKey): ByteArray {
override fun deriveSecret(other: ECDH.PrivateKey): ByteArray {
check(other is EcdhPrivateKey) { "Only ${EcdhPrivateKey::class} supported" }

return keyAgreement.use {
Expand All @@ -48,16 +51,16 @@ internal class JdkEcdh(state: JdkCryptographyState) : JdkEc<ECDH.PublicKey, ECDH
}
}

override suspend fun deriveSharedSecret(other: ECDH.PrivateKey): ByteArray = deriveSharedSecretBlocking(other)
override suspend fun deriveSharedSecret(other: ECDH.PrivateKey): ByteArray = deriveSecret(other)
}

private class EcdhPrivateKeySecretDerivation(
private val state: JdkCryptographyState,
private val privateKey: JPrivateKey,
) : SharedSecretDerivation<ECDH.PublicKey> {
) : SecretDerivation<ECDH.PublicKey> {
private val keyAgreement = state.keyAgreement("ECDH")

override fun deriveSharedSecretBlocking(other: ECDH.PublicKey): ByteArray {
override fun deriveSecret(other: ECDH.PublicKey): ByteArray {
check(other is EcdhPublicKey) { "Only ${EcdhPublicKey::class} supported" }

return keyAgreement.use {
Expand All @@ -67,6 +70,6 @@ internal class JdkEcdh(state: JdkCryptographyState) : JdkEc<ECDH.PublicKey, ECDH
}
}

override suspend fun deriveSharedSecret(other: ECDH.PublicKey): ByteArray = deriveSharedSecretBlocking(other)
override suspend fun deriveSharedSecret(other: ECDH.PublicKey): ByteArray = deriveSecret(other)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ internal object Openssl3Ecdh : ECDH {

private class EcPrivateKey(
key: CPointer<EVP_PKEY>,
) : ECDH.PrivateKey, Openssl3PrivateKeyEncodable<EC.PrivateKey.Format>(key), SharedSecretDerivation<ECDH.PublicKey> {
) : ECDH.PrivateKey, Openssl3PrivateKeyEncodable<EC.PrivateKey.Format>(key), SecretDerivation<ECDH.PublicKey> {
override fun outputType(format: EC.PrivateKey.Format): String = when (format) {
EC.PrivateKey.Format.DER, EC.PrivateKey.Format.DER.SEC1 -> "DER"
EC.PrivateKey.Format.PEM, EC.PrivateKey.Format.PEM.SEC1 -> "PEM"
Expand All @@ -94,20 +94,18 @@ internal object Openssl3Ecdh : ECDH {
else -> super.outputStruct(format)
}

override fun sharedSecretDerivation(): SharedSecretDerivation<ECDH.PublicKey> = this
override fun secretDerivation(): SecretDerivation<ECDH.PublicKey> = this
override fun asyncSecretDerivation(): AsyncSecretDerivation<ECDH.PublicKey> = asAsync()

override fun deriveSharedSecretBlocking(other: ECDH.PublicKey): ByteArray {
override fun deriveSecret(other: ECDH.PublicKey): ByteArray {
check(other is EcPublicKey)

return deriveSharedSecret(publicKey = other.key, privateKey = key)
}

override suspend fun deriveSharedSecret(other: ECDH.PublicKey): ByteArray = deriveSharedSecretBlocking(other)
}

private class EcPublicKey(
key: CPointer<EVP_PKEY>,
) : ECDH.PublicKey, Openssl3PublicKeyEncodable<EC.PublicKey.Format>(key), SharedSecretDerivation<ECDH.PrivateKey> {
) : ECDH.PublicKey, Openssl3PublicKeyEncodable<EC.PublicKey.Format>(key), SecretDerivation<ECDH.PrivateKey> {
override fun outputType(format: EC.PublicKey.Format): String = when (format) {
EC.PublicKey.Format.DER -> "DER"
EC.PublicKey.Format.PEM -> "PEM"
Expand All @@ -120,15 +118,14 @@ internal object Openssl3Ecdh : ECDH {
else -> super.encodeToBlocking(format)
}

override fun sharedSecretDerivation(): SharedSecretDerivation<ECDH.PrivateKey> = this
override fun secretDerivation(): SecretDerivation<ECDH.PrivateKey> = this
override fun asyncSecretDerivation(): AsyncSecretDerivation<ECDH.PrivateKey> = asAsync()

override fun deriveSharedSecretBlocking(other: ECDH.PrivateKey): ByteArray {
override fun deriveSecret(other: ECDH.PrivateKey): ByteArray {
check(other is EcPrivateKey)

return deriveSharedSecret(publicKey = key, privateKey = other.key)
}

override suspend fun deriveSharedSecret(other: ECDH.PrivateKey): ByteArray = deriveSharedSecretBlocking(other)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ internal object WebCryptoEcdh : WebCryptoEc<ECDH.PublicKey, ECDH.PrivateKey, ECD

private class EcdhPublicKey(
publicKey: CryptoKey,
) : EcPublicKey(publicKey), ECDH.PublicKey, SharedSecretDerivation<ECDH.PrivateKey> {
override fun sharedSecretDerivation(): SharedSecretDerivation<ECDH.PrivateKey> = this
) : EcPublicKey(publicKey), ECDH.PublicKey, SecretDerivation<ECDH.PrivateKey> {
override fun secretDerivation(): SecretDerivation<ECDH.PrivateKey> = this

override suspend fun deriveSharedSecret(other: ECDH.PrivateKey): ByteArray {
check(other is EcdhPrivateKey)
Expand All @@ -34,13 +34,13 @@ internal object WebCryptoEcdh : WebCryptoEc<ECDH.PublicKey, ECDH.PrivateKey, ECD
)
}

override fun deriveSharedSecretBlocking(other: ECDH.PrivateKey): ByteArray = nonBlocking()
override fun deriveSecret(other: ECDH.PrivateKey): ByteArray = nonBlocking()
}

private class EcdhPrivateKey(
privateKey: CryptoKey,
) : EcPrivateKey(privateKey), ECDH.PrivateKey, SharedSecretDerivation<ECDH.PublicKey> {
override fun sharedSecretDerivation(): SharedSecretDerivation<ECDH.PublicKey> = this
) : EcPrivateKey(privateKey), ECDH.PrivateKey, SecretDerivation<ECDH.PublicKey> {
override fun secretDerivation(): SecretDerivation<ECDH.PublicKey> = this
override suspend fun deriveSharedSecret(other: ECDH.PublicKey): ByteArray {
check(other is EcdhPublicKey)
return WebCrypto.deriveBits(
Expand All @@ -50,6 +50,6 @@ internal object WebCryptoEcdh : WebCryptoEc<ECDH.PublicKey, ECDH.PrivateKey, ECD
)
}

override fun deriveSharedSecretBlocking(other: ECDH.PublicKey): ByteArray = nonBlocking()
override fun deriveSecret(other: ECDH.PublicKey): ByteArray = nonBlocking()
}
}

0 comments on commit bb815af

Please sign in to comment.