diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 448b70ce..d89bfaaf 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -180,10 +180,26 @@ jobs: os: ubuntu-latest python-version: '3.11' opt-deps: ['brotli', 'zstd'] - - name: py3.12with brotli and zstandard + - name: py3.12 with brotli and zstandard os: ubuntu-latest python-version: '3.12' opt-deps: ['brotli', 'zstd'] + - name: py3.9 with kyber-py + os: ubuntu-latest + python-version: "3.9" + opt-deps: ["kyber_py"] + - name: py3.10 with kyber-py + os: ubuntu-latest + python-version: "3.10" + opt-deps: ["kyber_py"] + - name: py3.11 with kyber-py + os: ubuntu-latest + python-version: "3.11" + opt-deps: ["kyber_py"] + - name: py3.12 with kyber-py + os: ubuntu-latest + python-version: "3.12" + opt-deps: ["kyber_py"] # finally test with multiple dependencies installed at the same time - name: py2.7 with m2crypto, pycrypto, gmpy, gmpy2, and brotli os: ubuntu-20.04 @@ -204,22 +220,22 @@ jobs: - name: py3.9 with m2crypto, gmpy, gmpy2, brotli, and zstandard os: ubuntu-latest python-version: 3.9 - opt-deps: ['m2crypto', 'gmpy', 'gmpy2', 'brotli', 'zstd'] + opt-deps: ['m2crypto', 'gmpy', 'gmpy2', 'brotli', 'zstd', 'kyber_py'] - name: py3.10 with m2crypto, gmpy, gmpy2, brotli, and zstandard os: ubuntu-latest python-version: '3.10' - opt-deps: ['m2crypto', 'gmpy', 'gmpy2', 'brotli', 'zstd'] + opt-deps: ['m2crypto', 'gmpy', 'gmpy2', 'brotli', 'zstd', 'kyber_py'] - name: py3.11 with m2crypto, gmpy, gmpy2, brotli, and zstandard os: ubuntu-latest python-version: '3.11' # gmpy doesn't build with 3.11 - opt-deps: ['m2crypto', 'gmpy2', 'brotli', 'zstd'] + opt-deps: ['m2crypto', 'gmpy2', 'brotli', 'zstd', 'kyber_py'] - name: py3.12 with m2crypto, gmpy, gmpy2, brotli, and zstandard os: ubuntu-latest python-version: '3.12' # gmpy doesn't build with 3.12 # coverage to codeclimate can be submitted just once - opt-deps: ['m2crypto', 'gmpy2', 'codeclimate', 'brotli', 'zstd'] + opt-deps: ['m2crypto', 'gmpy2', 'codeclimate', 'brotli', 'zstd', 'kyber_py'] steps: - uses: actions/checkout@v2 if: ${{ !matrix.container }} @@ -346,6 +362,9 @@ jobs: - name: Install zstandard for py3.8 and after if: ${{ contains(matrix.opt-deps, 'zstd') }} run: pip install zstandard + - name: Install kyber_py + if: ${{ contains(matrix.opt-deps, 'kyber_py') }} + run: pip install "https://github.com/GiacomoPope/kyber-py/archive/b187189a514b3327578928c1d4c901d34592678e.zip" - name: Install build dependencies (2.6) if: ${{ matrix.python-version == '2.6' }} run: | diff --git a/tlslite/constants.py b/tlslite/constants.py index 63aa61f1..49617655 100644 --- a/tlslite/constants.py +++ b/tlslite/constants.py @@ -438,7 +438,13 @@ class GroupName(TLSEnum): brainpoolP512r1tls13 = 33 allEC.extend(list(range(31, 34))) - all = allEC + allFF + # draft-kwiatkowski-tls-ecdhe-mlkem + secp256r1mlkem768 = 0x11EB + x25519mlkem768 = 0x11EC + secp384r1mlkem1024 = 0x11ED + allKEM = [0x11EB, 0x11EC, 0x11ED] + + all = allEC + allFF + allKEM @classmethod def toRepr(cls, value, blacklist=None): diff --git a/tlslite/handshakesettings.py b/tlslite/handshakesettings.py index 3a8755ac..1059f5ad 100644 --- a/tlslite/handshakesettings.py +++ b/tlslite/handshakesettings.py @@ -10,7 +10,7 @@ from .constants import CertificateType from .utils import cryptomath from .utils import cipherfactory -from .utils.compat import ecdsaAllCurves, int_types +from .utils.compat import ecdsaAllCurves, int_types, ML_KEM_AVAILABLE from .utils.compression import compression_algo_impls CIPHER_NAMES = ["chacha20-poly1305", @@ -34,9 +34,13 @@ ALL_RSA_SIGNATURE_HASHES = RSA_SIGNATURE_HASHES + ["md5"] SIGNATURE_SCHEMES = ["Ed25519", "Ed448"] RSA_SCHEMES = ["pss", "pkcs1"] +CURVE_NAMES = [] +if ML_KEM_AVAILABLE: + CURVE_NAMES += ["secp256r1mlkem768", "x25519mlkem768", + "secp384r1mlkem1024"] # while secp521r1 is the most secure, it's also much slower than the others # so place it as the last one -CURVE_NAMES = ["x25519", "x448", "secp384r1", "secp256r1", +CURVE_NAMES += ["x25519", "x448", "secp384r1", "secp256r1", "secp521r1"] ALL_CURVE_NAMES = CURVE_NAMES + ["secp256k1", "brainpoolP512r1", "brainpoolP384r1", "brainpoolP256r1"] @@ -57,7 +61,8 @@ TLS13_PERMITTED_GROUPS = ["secp256r1", "secp384r1", "secp521r1", "x25519", "x448", "ffdhe2048", "ffdhe3072", "ffdhe4096", "ffdhe6144", - "ffdhe8192"] + "ffdhe8192", "secp256r1mlkem768", "x25519mlkem768", + "secp384r1mlkem1024"] KNOWN_VERSIONS = ((3, 0), (3, 1), (3, 2), (3, 3), (3, 4)) TICKET_CIPHERS = ["chacha20-poly1305", "aes256gcm", "aes128gcm", "aes128ccm", "aes128ccm_8", "aes256ccm", "aes256ccm_8"] @@ -395,7 +400,11 @@ def _init_key_settings(self): self.dhParams = None self.dhGroups = list(ALL_DH_GROUP_NAMES) self.defaultCurve = "secp256r1" - self.keyShares = ["secp256r1", "x25519"] + if ML_KEM_AVAILABLE: + self.keyShares = ["x25519mlkem768"] + else: + self.keyShares = [] + self.keyShares += ["secp256r1", "x25519"] self.padding_cb = None self.use_heartbeat_extension = True self.heartbeat_response_callback = None diff --git a/tlslite/keyexchange.py b/tlslite/keyexchange.py index 2242aad3..e315cbbf 100644 --- a/tlslite/keyexchange.py +++ b/tlslite/keyexchange.py @@ -21,9 +21,13 @@ from .utils import tlshashlib as hashlib from .utils.x25519 import x25519, x448, X25519_G, X448_G, X25519_ORDER_SIZE, \ X448_ORDER_SIZE -from .utils.compat import int_types +from .utils.compat import int_types, ML_KEM_AVAILABLE from .utils.codec import DecodeError +if ML_KEM_AVAILABLE: + from kyber_py.ml_kem import ML_KEM_768, ML_KEM_1024 + + class KeyExchange(object): """ Common API for calculating Premaster secret @@ -1062,3 +1066,148 @@ def calc_shared_key(self, private, peer_share): S = ecdhYc * private return numberToByteArray(S.x(), getPointByteSize(ecdhYc)) + + +class KEMKeyExchange(object): + def __init__(self, group, version): + if not ML_KEM_AVAILABLE: + raise TLSInternalError("kyber-py library not installed!") + self.group = group + assert version == (3, 4) + del version + + if self.group not in GroupName.allKEM: + raise TLSInternalError("called with wrong group") + + if self.group == GroupName.secp256r1mlkem768: + self._classic_group = GroupName.secp256r1 + elif self.group == GroupName.x25519mlkem768: + self._classic_group = GroupName.x25519 + else: + assert self.group == GroupName.secp384r1mlkem1024 + self._classic_group = GroupName.secp384r1 + + def get_random_private_key(self): + """Generates a random value to be used as the private key in KEM.""" + + if self.group not in GroupName.allKEM: + raise TLSInternalError("called with wrong group") + if self.group in (GroupName.secp256r1mlkem768, + GroupName.x25519mlkem768): + pqc_pub_key, pqc_priv_key = ML_KEM_768.keygen() + else: + pqc_pub_key, pqc_priv_key = ML_KEM_1024.keygen() + + classic_kex = ECDHKeyExchange(self._classic_group, (3, 4)) + classic_key = classic_kex.get_random_private_key() + + return ((pqc_pub_key, pqc_priv_key), classic_key) + + def calc_public_value(self, private): + classic_kex = ECDHKeyExchange(self._classic_group, (3, 4)) + + classic_pub_key_share = classic_kex.calc_public_value(private[1]) + + if self.group == GroupName.x25519mlkem768: + return private[0][0] + classic_pub_key_share + return classic_pub_key_share + private[0][0] + + def encapsulate_key(self, public): + if self.group == GroupName.secp256r1mlkem768: + classic_key_len = 65 + pqc_key_len = 1184 + pqc_first = False + ml_kem = ML_KEM_768 + elif self.group == GroupName.x25519mlkem768: + classic_key_len = 32 + pqc_key_len = 1184 + pqc_first = True + ml_kem = ML_KEM_768 + else: + assert self.group == GroupName.secp384r1mlkem1024 + classic_key_len = 97 + pqc_key_len = 1568 + pqc_first = False + ml_kem = ML_KEM_1024 + + if len(public) != classic_key_len + pqc_key_len: + raise TLSIllegalParameterException( + "Invalid key size for the selected group") + + if pqc_first: + pqc_key = public[:pqc_key_len] + classic_key_share = bytearray(public[pqc_key_len:]) + else: + classic_key_share = bytearray(public[:classic_key_len]) + pqc_key = public[classic_key_len:] + + classic_kex = ECDHKeyExchange(self._classic_group, (3, 4)) + classic_key = classic_kex.get_random_private_key() + classic_my_key_share = classic_kex.calc_public_value(classic_key) + classic_shared_secret = classic_kex.calc_shared_key( + classic_key, classic_key_share) + + try: + pqc_shared_secret, pqc_encaps = ml_kem.encaps(pqc_key) + except ValueError: + raise TLSIllegalParameterException( + "Invalid PQC key from peer") + + if pqc_first: + shared_secret = pqc_shared_secret + classic_shared_secret + key_encapsulation = pqc_encaps + classic_my_key_share + else: + shared_secret = classic_shared_secret + pqc_shared_secret + key_encapsulation = classic_my_key_share + pqc_encaps + + return shared_secret, key_encapsulation + + def calc_shared_key(self, private, key_encaps): + if self.group == GroupName.secp256r1mlkem768: + classic_key_len = 65 + pqc_key_len = 1088 + pqc_first = False + ml_kem = ML_KEM_768 + elif self.group == GroupName.x25519mlkem768: + classic_key_len = 32 + pqc_key_len = 1088 + pqc_first = True + ml_kem = ML_KEM_768 + else: + assert self.group == GroupName.secp384r1mlkem1024 + classic_key_len = 97 + pqc_key_len = 1568 + pqc_first = False + ml_kem = ML_KEM_1024 + + if len(key_encaps) != classic_key_len + pqc_key_len: + raise TLSIllegalParameterException( + "Invalid key size for the selected group. " + "Expected {0}, received {1}".format( + classic_key_len + pqc_key_len, + len(key_encaps))) + + if pqc_first: + pqc_key = key_encaps[:pqc_key_len] + classic_key_share = bytearray(key_encaps[pqc_key_len:]) + else: + classic_key_share = bytearray(key_encaps[:classic_key_len]) + pqc_key = key_encaps[classic_key_len:] + + classic_kex = ECDHKeyExchange(self._classic_group, (3, 4)) + classic_shared_secret = classic_kex.calc_shared_key( + private[1], classic_key_share) + + try: + pqc_shared_secret = ml_kem.decaps(private[0][1], pqc_key) + except ValueError: + raise TLSIllegalParameterException( + "Error in KEM decapsulation") + + if pqc_first: + shared_secret = pqc_shared_secret + classic_shared_secret + else: + shared_secret = classic_shared_secret + pqc_shared_secret + + return shared_secret + diff --git a/tlslite/tlsconnection.py b/tlslite/tlsconnection.py index 7abfe2e3..7fcce7ea 100644 --- a/tlslite/tlsconnection.py +++ b/tlslite/tlsconnection.py @@ -35,7 +35,7 @@ from .utils.deprecations import deprecated_params from .keyexchange import KeyExchange, RSAKeyExchange, DHE_RSAKeyExchange, \ ECDHE_RSAKeyExchange, SRPKeyExchange, ADHKeyExchange, \ - AECDHKeyExchange, FFDHKeyExchange, ECDHKeyExchange + AECDHKeyExchange, FFDHKeyExchange, ECDHKeyExchange, KEMKeyExchange from .handshakehelpers import HandshakeHelpers from .utils.cipherfactory import createAESCCM, createAESCCM_8, \ createAESGCM, createCHACHA20 @@ -1196,6 +1196,8 @@ def _clientGetServerHello(self, settings, session, clientHello): @staticmethod def _getKEX(group, version): """Get object for performing key exchange.""" + if group in GroupName.allKEM: + return KEMKeyExchange(group, version) if group in GroupName.allFF: return FFDHKeyExchange(group, version) return ECDHKeyExchange(group, version) @@ -1209,6 +1211,15 @@ def _genKeyShareEntry(cls, group, version): share = kex.calc_public_value(private) return KeyShareEntry().create(group, share, private) + @classmethod + def _KEMEncaps(cls, group, public): + """Generate the server's KeyShareEntry object with encapsulated secret. + """ + kex = cls._getKEX(group, (3, 4)) + shared_sec, key_share_value = kex.encapsulate_key(public) + key_share = KeyShareEntry().create(group, key_share_value, None) + return shared_sec, key_share + @staticmethod def _getPRFParams(cipher_suite): """Return name of hash used for PRF and the hash output size.""" @@ -2803,16 +2814,21 @@ def _serverTLS13Handshake(self, settings, clientHello, cipherSuite, (psk is None and privateKey): self.ecdhCurve = selected_group kex = self._getKEX(selected_group, version) - key_share = self._genKeyShareEntry(selected_group, version) + if selected_group in GroupName.allKEM: + shared_sec, key_share = self._KEMEncaps( + selected_group, + cl_key_share.key_exchange) + else: + key_share = self._genKeyShareEntry(selected_group, version) - try: - shared_sec = kex.calc_shared_key(key_share.private, - cl_key_share.key_exchange) - except TLSIllegalParameterException as alert: - for result in self._sendError( - AlertDescription.illegal_parameter, - str(alert)): - yield result + try: + shared_sec = kex.calc_shared_key(key_share.private, + cl_key_share.key_exchange) + except TLSIllegalParameterException as alert: + for result in self._sendError( + AlertDescription.illegal_parameter, + str(alert)): + yield result sh_extensions.append(ServerKeyShareExtension().create(key_share)) elif (psk is not None and @@ -4915,7 +4931,11 @@ def _sigHashesToList(settings, privateKey=None, certList=None, @staticmethod def _curveNamesToList(settings): """Convert list of acceptable curves to array identifiers""" - return [getattr(GroupName, val) for val in settings.eccCurves] + ret = [getattr(GroupName, val) for val in settings.eccCurves] + if settings.maxVersion < (3, 4) and (3, 4) not in settings.versions: + # if we don't support TLS 1.3, filter out KEMs + ret = [i for i in ret if i not in GroupName.allKEM] + return ret @staticmethod def _groupNamesToList(settings): diff --git a/tlslite/utils/compat.py b/tlslite/utils/compat.py index 359de7f5..71945d67 100644 --- a/tlslite/utils/compat.py +++ b/tlslite/utils/compat.py @@ -235,3 +235,14 @@ def byte_length(val): ecdsaAllCurves = False else: ecdsaAllCurves = True + + +# kyber-py is an optional dependency +try: + from kyber_py.ml_kem import ML_KEM_768, ML_KEM_1024 + del ML_KEM_768 + del ML_KEM_1024 +except ImportError: + ML_KEM_AVAILABLE = False +else: + ML_KEM_AVAILABLE = True diff --git a/unit_tests/test_tlslite_keyexchange.py b/unit_tests/test_tlslite_keyexchange.py index c6e91076..3e89c022 100644 --- a/unit_tests/test_tlslite_keyexchange.py +++ b/unit_tests/test_tlslite_keyexchange.py @@ -35,7 +35,7 @@ from tlslite import VerifierDB from tlslite.extensions import SupportedGroupsExtension, SNIExtension from tlslite.utils.ecc import getCurveByName, getPointByteSize -from tlslite.utils.compat import a2b_hex +from tlslite.utils.compat import a2b_hex, ML_KEM_AVAILABLE import ecdsa from operator import mul try: @@ -45,7 +45,7 @@ from tlslite.keyexchange import KeyExchange, RSAKeyExchange, \ DHE_RSAKeyExchange, SRPKeyExchange, ECDHE_RSAKeyExchange, \ - RawDHKeyExchange, FFDHKeyExchange + RawDHKeyExchange, FFDHKeyExchange, KEMKeyExchange from tlslite.utils.x25519 import x25519, X25519_G, x448, X448_G from tlslite.mathtls import RFC7919_GROUPS from tlslite.utils.python_key import Python_Key @@ -2583,3 +2583,155 @@ def test_calc_shared_secret_for_invalid_sized_input(self): key_share = bytearray(b'\x00' * 10 + b'\x04') with self.assertRaises(TLSIllegalParameterException): kex.calc_shared_key(private, key_share) + + +@unittest.skipIf(not ML_KEM_AVAILABLE, "Kyber-py not installed") +class TestKEMKeyExchange(unittest.TestCase): + def test_init_with_wrong_group(self): + with self.assertRaises(TLSInternalError): + KEMKeyExchange(GroupName.x25519, (3, 4)) + + def test_with_wrong_key_share_size(self): + group = GroupName.x25519mlkem768 + version = (3, 4) + + kex = KEMKeyExchange(group, version) + + with self.assertRaises(TLSIllegalParameterException) as e: + # one byte too long + kex.encapsulate_key(bytearray(32 + 1184 + 1)) + + self.assertIn("Invalid key size", str(e.exception)) + + def test_with_invalid_classic_key_share(self): + group = GroupName.secp256r1mlkem768 + version = (3, 4) + + kex = KEMKeyExchange(group, version) + + alice_private_key = kex.get_random_private_key() + alice_key_share = kex.calc_public_value(alice_private_key) + alice_key_share = bytearray(alice_key_share) + + alice_key_share[1] ^= 0xff + + with self.assertRaises(TLSIllegalParameterException) as e: + kex.encapsulate_key(alice_key_share) + + self.assertIn("Invalid ECC", str(e.exception)) + + def test_with_invalid_pqc_key_share(self): + group = GroupName.secp256r1mlkem768 + version = (3, 4) + + kex = KEMKeyExchange(group, version) + + alice_private_key = kex.get_random_private_key() + alice_key_share = kex.calc_public_value(alice_private_key) + alice_key_share = bytearray(alice_key_share) + + alice_key_share[67] = 0xff + + with self.assertRaises(TLSIllegalParameterException) as e: + kex.encapsulate_key(alice_key_share) + + self.assertIn("Invalid PQC", str(e.exception)) + + def test_with_modified_pqc_key_share(self): + group = GroupName.secp256r1mlkem768 + version = (3, 4) + + kex = KEMKeyExchange(group, version) + + alice_private_key = kex.get_random_private_key() + alice_key_share = kex.calc_public_value(alice_private_key) + alice_key_share = bytearray(alice_key_share) + + alice_key_share[67] = 0x01 + + bob_shared_secret, bob_key_share = kex.encapsulate_key(alice_key_share) + + alice_shared_secret = kex.calc_shared_key( + alice_private_key, bob_key_share) + + self.assertNotEqual(alice_shared_secret, bob_shared_secret) + + def test_decaps_with_wrong_size_of_share(self): + group = GroupName.secp256r1mlkem768 + version = (3, 4) + + kex = KEMKeyExchange(group, version) + + alice_private_key = kex.get_random_private_key() + + with self.assertRaises(TLSIllegalParameterException) as e: + kex.calc_shared_key(alice_private_key, bytearray(65 + 1088 + 1)) + + self.assertIn("Invalid key size", str(e.exception)) + + def test_decaps_with_invalid_classical_share(self): + group = GroupName.secp256r1mlkem768 + version = (3, 4) + + kex = KEMKeyExchange(group, version) + + alice_private_key = kex.get_random_private_key() + alice_key_share = kex.calc_public_value(alice_private_key) + + bob_shared_secret, bob_key_share = kex.encapsulate_key(alice_key_share) + bob_key_share = bytearray(bob_key_share) + + bob_key_share[2] ^= 0xff + + with self.assertRaises(TLSIllegalParameterException) as e: + kex.calc_shared_key(alice_private_key, bob_key_share) + + self.assertIn("Invalid ECC", str(e.exception)) + + def test_decaps_with_invalid_pqc_share(self): + group = GroupName.secp256r1mlkem768 + version = (3, 4) + + kex = KEMKeyExchange(group, version) + + alice_private_key = kex.get_random_private_key() + alice_key_share = kex.calc_public_value(alice_private_key) + + bob_shared_secret, bob_key_share = kex.encapsulate_key(alice_key_share) + bob_key_share = bytearray(bob_key_share) + + bob_key_share[68] ^= 0xff + + alice_shared_secret = kex.calc_shared_key( + alice_private_key, bob_key_share) + + self.assertNotEqual(alice_shared_secret, bob_shared_secret) + + def do_kex(self, group): + version = (3, 4) + + alice_kex = KEMKeyExchange(group, version) + + alice_private_key = alice_kex.get_random_private_key() + alice_key_share = alice_kex.calc_public_value(alice_private_key) + + bob_kex = KEMKeyExchange(group, version) + bob_shared_secret, bob_key_share = \ + bob_kex.encapsulate_key(alice_key_share) + + alice_shared_secret = alice_kex.calc_shared_key( + alice_private_key, bob_key_share) + + self.assertEqual(alice_shared_secret, bob_shared_secret) + + def test_x25519_ml_kem_768(self): + group = GroupName.x25519mlkem768 + self.do_kex(group) + + def test_p256_ml_kem_768(self): + group = GroupName.secp256r1mlkem768 + self.do_kex(group) + + def test_p384_ml_kem_1024(self): + group = GroupName.secp384r1mlkem1024 + self.do_kex(group)