Skip to content

Commit

Permalink
write keys to file if path exists and file does not exist (#21)
Browse files Browse the repository at this point in the history
* write keys to file if path exists and file does not exist
  • Loading branch information
GHadari authored Aug 25, 2021
1 parent d52f312 commit c6b77c3
Show file tree
Hide file tree
Showing 4 changed files with 162 additions and 17 deletions.
109 changes: 109 additions & 0 deletions jwthenticator/tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
from datetime import datetime
import os
from importlib import reload
from os import environ
from typing import Tuple, Optional, AsyncGenerator
import pytest

import aiofiles
from async_generator import asynccontextmanager
from Cryptodome.PublicKey import RSA

import jwthenticator.utils
from jwthenticator import consts
from jwthenticator.utils import get_rsa_key_pair
from jwthenticator.tests.utils import random_key, backup_environment

PUBLIC_KEY_PATH_ENV_KEY = "RSA_PUBLIC_KEY_PATH"
PRIVATE_KEY_PATH_ENV_KEY = "RSA_PRIVATE_KEY_PATH"
PUBLIC_KEY_VALUE_ENV_KEY = "RSA_PUBLIC_KEY"
PRIVATE_KEY_VALUE_ENV_KEY = "RSA_PRIVATE_KEY"


def _reload_env_vars_get_rsa_key_pair() -> Tuple[str, Optional[str]]:
reload(consts)
reload(jwthenticator.utils)
return get_rsa_key_pair()


@asynccontextmanager
async def _create_random_file() -> AsyncGenerator[Tuple[str, str], None]:
random_data = await random_key(8)
filename = f"test_tmp_file_{datetime.now()}.txt"
# ignore type due to mypy-aiofiles issues
async with aiofiles.open(filename, "w", encoding='utf8') as file: # type: ignore
await file.write(random_data)
try:
yield filename, random_data
finally:
os.remove(filename)


@backup_environment
@pytest.mark.asyncio
# Get key pair value from env (not the path)
async def test_get_rsa_key_pair_by_env_value() -> None:
generated_public_key = await random_key(8)
generated_private_key = await random_key(8)
environ[PUBLIC_KEY_PATH_ENV_KEY] = ""
environ[PRIVATE_KEY_PATH_ENV_KEY] = ""
environ[PUBLIC_KEY_VALUE_ENV_KEY] = generated_public_key
environ[PRIVATE_KEY_VALUE_ENV_KEY] = generated_private_key
public_key, private_key = _reload_env_vars_get_rsa_key_pair()
assert public_key == generated_public_key
assert private_key == generated_private_key


@backup_environment
@pytest.mark.asyncio
# No keys or path inputted - create keys
async def test_get_rsa_key_pair_no_input() -> None:
environ[PUBLIC_KEY_PATH_ENV_KEY] = ""
environ[PRIVATE_KEY_PATH_ENV_KEY] = ""
environ[PUBLIC_KEY_VALUE_ENV_KEY] = ""
environ[PRIVATE_KEY_VALUE_ENV_KEY] = ""
public_key, private_key = _reload_env_vars_get_rsa_key_pair()
assert RSA.import_key(str(private_key))
assert RSA.import_key(str(public_key))


@backup_environment
@pytest.mark.asyncio
# File exists - read keys
async def test_get_rsa_key_pair_from_file() -> None:
# Pylint sets a false positive
async with _create_random_file() as (private_file_name, private_file_data), \
_create_random_file() as (public_file_name, public_file_data): # pylint: disable=not-async-context-manager
environ[PUBLIC_KEY_PATH_ENV_KEY] = public_file_name
environ[PRIVATE_KEY_PATH_ENV_KEY] = private_file_name
public_key, private_key = _reload_env_vars_get_rsa_key_pair()
assert public_file_data in public_key

# Type can be ignored because a private key should be generated
assert private_file_data in private_key # type: ignore


@backup_environment
@pytest.mark.asyncio
# Path exists and files do not exist - create them
async def test_get_rsa_key_pair_create_file() -> None:
public_file_name = f"test_tmp_file_{datetime.now()}.txt"
private_file_name = f"test_tmp_file_{datetime.now()}.txt"
environ[PUBLIC_KEY_PATH_ENV_KEY] = public_file_name
environ[PRIVATE_KEY_PATH_ENV_KEY] = private_file_name
try:
public_key, private_key = _reload_env_vars_get_rsa_key_pair()
# ignore type due to mypy-aiofiles issues
async with aiofiles.open(public_file_name, 'r', encoding='utf8') as file: # type: ignore
public_key_from_file = await file.read()
async with aiofiles.open(private_file_name, 'r', encoding='utf8') as file: # type: ignore
private_key_from_file = await file.read()
assert public_key_from_file == public_key
assert private_key_from_file == private_key
assert RSA.import_key(public_key_from_file)
assert RSA.import_key(private_key_from_file)
finally:
try:
os.remove(public_file_name)
finally:
os.remove(private_file_name)
19 changes: 19 additions & 0 deletions jwthenticator/tests/utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
from __future__ import absolute_import

from importlib import reload
import functools
import random
from os import environ
from string import ascii_letters
from datetime import datetime, timedelta
from hashlib import sha512
from uuid import uuid4

from jwthenticator import consts, utils


async def random_key(length: int = 32) -> str:
return "".join([random.choice(ascii_letters) for i in range(length)])
Expand All @@ -28,3 +33,17 @@ async def hash_key(key: str) -> str:

async def future_datetime(seconds: int = 0) -> datetime:
return datetime.utcnow() + timedelta(seconds=seconds)


def backup_environment(func): # type: ignore
@functools.wraps(func)
async def wrapper(*args, **kwargs): # type: ignore
_environ_copy = environ.copy()
try:
return await func(*args, **kwargs)
finally:
environ.clear()
environ.update(_environ_copy)
reload(consts)
reload(utils)
return wrapper
48 changes: 31 additions & 17 deletions jwthenticator/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import absolute_import

from os.path import isfile
from typing import Tuple, Optional
from urllib.parse import urlparse

Expand All @@ -13,30 +14,43 @@
def get_rsa_key_pair() -> Tuple[str, Optional[str]]:
"""
Get RSA key pair.
Will try to get them by this order:
1. RSA_PUBLIC/PRIVATE_KEY_PATH
2. RSA_PUBLIC/PRIVATE_KEY
3. Generate new keys.
Will get RSA key pair depending on available ENV variables, in the following order:
1. Read file path from RSA_PUBLIC_PATH and PRIVATE_KEY_PATH and use the files there,
If a path is specified (in RSA_PUBLIC_PATH and PRIVATE_KEY_PATH) but the files do
not exist - they will be created and populated
2. Use data directly in the env vars RSA_PUBLIC_KEY and RSA_PRIVATE_KEY
3. Use stateless new keys
:return (public_key, private_key): A key pair tuple.
Will raise exception if key paths are given and fail to read.
"""
if RSA_PUBLIC_KEY_PATH is not None:
# Read public key.
with open(RSA_PUBLIC_KEY_PATH) as f_obj:
public_key = f_obj.read()
if RSA_PUBLIC_KEY_PATH:
if isfile(RSA_PUBLIC_KEY_PATH):
return _read_rsa_keys_from_file()
return _create_rsa_key_files()

# Read private key if given.
private_key = None
if RSA_PRIVATE_KEY_PATH is not None:
with open(RSA_PRIVATE_KEY_PATH) as f_obj:
private_key = f_obj.read()
if RSA_PUBLIC_KEY:
return RSA_PUBLIC_KEY, RSA_PRIVATE_KEY

return (public_key, private_key)
return create_rsa_key_pair()

if RSA_PUBLIC_KEY is not None:
return (RSA_PUBLIC_KEY, RSA_PRIVATE_KEY)

return create_rsa_key_pair()
def _read_rsa_keys_from_file() -> Tuple[str, Optional[str]]:
with open(RSA_PUBLIC_KEY_PATH, 'r', encoding='utf8') as f_obj:
public_key = f_obj.read()
private_key = None
if RSA_PRIVATE_KEY_PATH is not None:
with open(RSA_PRIVATE_KEY_PATH, 'r', encoding='utf8') as f_obj:
private_key = f_obj.read()
return public_key, private_key


def _create_rsa_key_files() -> Tuple[str, Optional[str]]:
public_key, private_key = create_rsa_key_pair()
with open(RSA_PUBLIC_KEY_PATH, 'w', encoding='utf8') as f_obj:
f_obj.write(public_key)
with open(RSA_PRIVATE_KEY_PATH, 'w', encoding='utf8') as f_obj:
f_obj.write(private_key)
return public_key, private_key


def create_rsa_key_pair() -> Tuple[str, str]:
Expand Down
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ pytest-aiohttp = "^0.3"
freezegun = "^1.0"
pyjwt = "^2.0"
diagrams = "^0.17.0"
mock = "^4.0.3"
async_generator = "^1.10"
aiofiles = "^0.7.0"


[tool.pylint.message_control]
Expand Down

0 comments on commit c6b77c3

Please sign in to comment.