From c6b77c34551bad8940f04f082cade812a5b6486c Mon Sep 17 00:00:00 2001 From: GHadari Date: Wed, 25 Aug 2021 10:36:21 +0300 Subject: [PATCH] write keys to file if path exists and file does not exist (#21) * write keys to file if path exists and file does not exist --- jwthenticator/tests/test_utils.py | 109 ++++++++++++++++++++++++++++++ jwthenticator/tests/utils.py | 19 ++++++ jwthenticator/utils.py | 48 ++++++++----- pyproject.toml | 3 + 4 files changed, 162 insertions(+), 17 deletions(-) create mode 100644 jwthenticator/tests/test_utils.py diff --git a/jwthenticator/tests/test_utils.py b/jwthenticator/tests/test_utils.py new file mode 100644 index 0000000..1d96747 --- /dev/null +++ b/jwthenticator/tests/test_utils.py @@ -0,0 +1,109 @@ +from datetime import datetime +import os +from importlib import reload +from os import environ +from typing import Tuple, Optional, AsyncGenerator +import pytest + +import aiofiles +from async_generator import asynccontextmanager +from Cryptodome.PublicKey import RSA + +import jwthenticator.utils +from jwthenticator import consts +from jwthenticator.utils import get_rsa_key_pair +from jwthenticator.tests.utils import random_key, backup_environment + +PUBLIC_KEY_PATH_ENV_KEY = "RSA_PUBLIC_KEY_PATH" +PRIVATE_KEY_PATH_ENV_KEY = "RSA_PRIVATE_KEY_PATH" +PUBLIC_KEY_VALUE_ENV_KEY = "RSA_PUBLIC_KEY" +PRIVATE_KEY_VALUE_ENV_KEY = "RSA_PRIVATE_KEY" + + +def _reload_env_vars_get_rsa_key_pair() -> Tuple[str, Optional[str]]: + reload(consts) + reload(jwthenticator.utils) + return get_rsa_key_pair() + + +@asynccontextmanager +async def _create_random_file() -> AsyncGenerator[Tuple[str, str], None]: + random_data = await random_key(8) + filename = f"test_tmp_file_{datetime.now()}.txt" + # ignore type due to mypy-aiofiles issues + async with aiofiles.open(filename, "w", encoding='utf8') as file: # type: ignore + await file.write(random_data) + try: + yield filename, random_data + finally: + os.remove(filename) + + +@backup_environment +@pytest.mark.asyncio +# Get key pair value from env (not the path) +async def test_get_rsa_key_pair_by_env_value() -> None: + generated_public_key = await random_key(8) + generated_private_key = await random_key(8) + environ[PUBLIC_KEY_PATH_ENV_KEY] = "" + environ[PRIVATE_KEY_PATH_ENV_KEY] = "" + environ[PUBLIC_KEY_VALUE_ENV_KEY] = generated_public_key + environ[PRIVATE_KEY_VALUE_ENV_KEY] = generated_private_key + public_key, private_key = _reload_env_vars_get_rsa_key_pair() + assert public_key == generated_public_key + assert private_key == generated_private_key + + +@backup_environment +@pytest.mark.asyncio +# No keys or path inputted - create keys +async def test_get_rsa_key_pair_no_input() -> None: + environ[PUBLIC_KEY_PATH_ENV_KEY] = "" + environ[PRIVATE_KEY_PATH_ENV_KEY] = "" + environ[PUBLIC_KEY_VALUE_ENV_KEY] = "" + environ[PRIVATE_KEY_VALUE_ENV_KEY] = "" + public_key, private_key = _reload_env_vars_get_rsa_key_pair() + assert RSA.import_key(str(private_key)) + assert RSA.import_key(str(public_key)) + + +@backup_environment +@pytest.mark.asyncio +# File exists - read keys +async def test_get_rsa_key_pair_from_file() -> None: + # Pylint sets a false positive + async with _create_random_file() as (private_file_name, private_file_data), \ + _create_random_file() as (public_file_name, public_file_data): # pylint: disable=not-async-context-manager + environ[PUBLIC_KEY_PATH_ENV_KEY] = public_file_name + environ[PRIVATE_KEY_PATH_ENV_KEY] = private_file_name + public_key, private_key = _reload_env_vars_get_rsa_key_pair() + assert public_file_data in public_key + + # Type can be ignored because a private key should be generated + assert private_file_data in private_key # type: ignore + + +@backup_environment +@pytest.mark.asyncio +# Path exists and files do not exist - create them +async def test_get_rsa_key_pair_create_file() -> None: + public_file_name = f"test_tmp_file_{datetime.now()}.txt" + private_file_name = f"test_tmp_file_{datetime.now()}.txt" + environ[PUBLIC_KEY_PATH_ENV_KEY] = public_file_name + environ[PRIVATE_KEY_PATH_ENV_KEY] = private_file_name + try: + public_key, private_key = _reload_env_vars_get_rsa_key_pair() + # ignore type due to mypy-aiofiles issues + async with aiofiles.open(public_file_name, 'r', encoding='utf8') as file: # type: ignore + public_key_from_file = await file.read() + async with aiofiles.open(private_file_name, 'r', encoding='utf8') as file: # type: ignore + private_key_from_file = await file.read() + assert public_key_from_file == public_key + assert private_key_from_file == private_key + assert RSA.import_key(public_key_from_file) + assert RSA.import_key(private_key_from_file) + finally: + try: + os.remove(public_file_name) + finally: + os.remove(private_file_name) diff --git a/jwthenticator/tests/utils.py b/jwthenticator/tests/utils.py index b90778f..08b3433 100644 --- a/jwthenticator/tests/utils.py +++ b/jwthenticator/tests/utils.py @@ -1,11 +1,16 @@ from __future__ import absolute_import +from importlib import reload +import functools import random +from os import environ from string import ascii_letters from datetime import datetime, timedelta from hashlib import sha512 from uuid import uuid4 +from jwthenticator import consts, utils + async def random_key(length: int = 32) -> str: return "".join([random.choice(ascii_letters) for i in range(length)]) @@ -28,3 +33,17 @@ async def hash_key(key: str) -> str: async def future_datetime(seconds: int = 0) -> datetime: return datetime.utcnow() + timedelta(seconds=seconds) + + +def backup_environment(func): # type: ignore + @functools.wraps(func) + async def wrapper(*args, **kwargs): # type: ignore + _environ_copy = environ.copy() + try: + return await func(*args, **kwargs) + finally: + environ.clear() + environ.update(_environ_copy) + reload(consts) + reload(utils) + return wrapper diff --git a/jwthenticator/utils.py b/jwthenticator/utils.py index 777fe52..d18936c 100644 --- a/jwthenticator/utils.py +++ b/jwthenticator/utils.py @@ -1,5 +1,6 @@ from __future__ import absolute_import +from os.path import isfile from typing import Tuple, Optional from urllib.parse import urlparse @@ -13,30 +14,43 @@ def get_rsa_key_pair() -> Tuple[str, Optional[str]]: """ Get RSA key pair. - Will try to get them by this order: - 1. RSA_PUBLIC/PRIVATE_KEY_PATH - 2. RSA_PUBLIC/PRIVATE_KEY - 3. Generate new keys. + Will get RSA key pair depending on available ENV variables, in the following order: + 1. Read file path from RSA_PUBLIC_PATH and PRIVATE_KEY_PATH and use the files there, + If a path is specified (in RSA_PUBLIC_PATH and PRIVATE_KEY_PATH) but the files do + not exist - they will be created and populated + 2. Use data directly in the env vars RSA_PUBLIC_KEY and RSA_PRIVATE_KEY + 3. Use stateless new keys :return (public_key, private_key): A key pair tuple. Will raise exception if key paths are given and fail to read. """ - if RSA_PUBLIC_KEY_PATH is not None: - # Read public key. - with open(RSA_PUBLIC_KEY_PATH) as f_obj: - public_key = f_obj.read() + if RSA_PUBLIC_KEY_PATH: + if isfile(RSA_PUBLIC_KEY_PATH): + return _read_rsa_keys_from_file() + return _create_rsa_key_files() - # Read private key if given. - private_key = None - if RSA_PRIVATE_KEY_PATH is not None: - with open(RSA_PRIVATE_KEY_PATH) as f_obj: - private_key = f_obj.read() + if RSA_PUBLIC_KEY: + return RSA_PUBLIC_KEY, RSA_PRIVATE_KEY - return (public_key, private_key) + return create_rsa_key_pair() - if RSA_PUBLIC_KEY is not None: - return (RSA_PUBLIC_KEY, RSA_PRIVATE_KEY) - return create_rsa_key_pair() +def _read_rsa_keys_from_file() -> Tuple[str, Optional[str]]: + with open(RSA_PUBLIC_KEY_PATH, 'r', encoding='utf8') as f_obj: + public_key = f_obj.read() + private_key = None + if RSA_PRIVATE_KEY_PATH is not None: + with open(RSA_PRIVATE_KEY_PATH, 'r', encoding='utf8') as f_obj: + private_key = f_obj.read() + return public_key, private_key + + +def _create_rsa_key_files() -> Tuple[str, Optional[str]]: + public_key, private_key = create_rsa_key_pair() + with open(RSA_PUBLIC_KEY_PATH, 'w', encoding='utf8') as f_obj: + f_obj.write(public_key) + with open(RSA_PRIVATE_KEY_PATH, 'w', encoding='utf8') as f_obj: + f_obj.write(private_key) + return public_key, private_key def create_rsa_key_pair() -> Tuple[str, str]: diff --git a/pyproject.toml b/pyproject.toml index 2d8d721..948e29a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,6 +35,9 @@ pytest-aiohttp = "^0.3" freezegun = "^1.0" pyjwt = "^2.0" diagrams = "^0.17.0" +mock = "^4.0.3" +async_generator = "^1.10" +aiofiles = "^0.7.0" [tool.pylint.message_control]