From 05bbfdf1fa275dc49990aa465ab3570f29f52736 Mon Sep 17 00:00:00 2001 From: npatel-cars <50211387+npatel-cars@users.noreply.github.com> Date: Thu, 27 Oct 2022 13:21:01 -0500 Subject: [PATCH] merge dev to main (#38) * Split normalize_config into two functions (#27) * Split normalize_config into two functions * Add test cases for normalize_config and parse_additional_config * Bug Fix: on test_parse_additional_config * Update test_parse_additional_config * Added new badges to the Readme (#30) * Move EC2 pricing calls to single function. (#29) * Added more clear SSH error message for improper credentials * Updated changelog * Fixed changelog * Updated SSH credential error message * Add tests for ssh.py module. * Add coverage as a test dependency. * Update changelog and fix style. * Add unittests for rsync module. (#33) * Add tests for yaml_loader.py to increase coverage (#34) * Add tests for yaml_loader.py to increase coverage * remove redundant imports Co-authored-by: ali * moved function outside for better testing (#35) Co-authored-by: ali * Added venv to .gitignore * bump version (#37) * Added venv to .gitignore * bump version bump version so we can merge with main * Update CHANGELOG.md update link for unreleased * Update CHANGELOG.md update link for unreleased Co-authored-by: Gabriele A. Ron Co-authored-by: Gabe Ron Co-authored-by: Heshanthaka Co-authored-by: Joao Moreira <13685125+jagmoreira@users.noreply.github.com> Co-authored-by: Gabriele A. Ron Co-authored-by: Mohammed Ali Zubair Co-authored-by: ali --- .bumpversion.cfg | 2 +- .gitignore | 2 + CHANGELOG.md | 14 +- README.md | 10 +- pyproject.toml | 8 ++ src/forge/__init__.py | 2 +- src/forge/common.py | 91 +++++++++++-- src/forge/create.py | 65 +++------ src/forge/destroy.py | 51 +------ src/forge/main.py | 4 +- src/forge/rsync.py | 9 +- src/forge/ssh.py | 7 +- src/forge/start.py | 67 +++++----- src/forge/yaml_loader.py | 4 +- tests/test_common.py | 88 +++++++++++- tests/test_rsync.py | 273 ++++++++++++++++++++++++++++++++++++++ tests/test_run.py | 1 + tests/test_ssh.py | 129 ++++++++++++++++++ tests/test_start.py | 69 ++++++++++ tests/test_yaml_loader.py | 29 +++- 20 files changed, 767 insertions(+), 158 deletions(-) create mode 100644 tests/test_rsync.py create mode 100644 tests/test_ssh.py create mode 100644 tests/test_start.py diff --git a/.bumpversion.cfg b/.bumpversion.cfg index ab2691e..a5a17bb 100644 --- a/.bumpversion.cfg +++ b/.bumpversion.cfg @@ -1,5 +1,5 @@ [bumpversion] -current_version = 1.0.1 +current_version = 1.0.2 commit = True tag = False parse = (?P\d+)\.(?P\d+)\.(?P\d+) diff --git a/.gitignore b/.gitignore index 170d0e8..5a40770 100644 --- a/.gitignore +++ b/.gitignore @@ -15,3 +15,5 @@ __pycache__/ *.egg-info/ build/ dist/ +.idea/ +venv/ \ No newline at end of file diff --git a/CHANGELOG.md b/CHANGELOG.md index 944c002..6caa440 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,9 +6,19 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ## [Unreleased] + +## [1.0.2] - 2022-10-27 + ### Added +- **Tests** - Test for `ssh.py` and `rsync.py` module. - **GitHub** - Workflow to run unittests on every PR and push to main. +### Changed +- **SSH** - Add error to show when SSH credentials are invalid. +- **Dependencies** - Add `coverage` as a test dependency. +- **Readme** - Add new badges to the Readme. +- **Common** - Move EC2 pricing calls to single function in `common.py`. 
+ ## [1.0.1] - 2022-09-28 @@ -23,12 +33,12 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), - **GitHub** - Update action to build and publish package only when version is bumped. - **Forge** - Added automatic tag `forge-name` to allow `Name` tag to be changed. - ## [1.0.0] - 2022-09-27 ### Added - **Initial commit** - Forge source code, unittests, docs, pyproject.toml, README.md, and LICENSE files. -[unreleased]: https://github.com/carsdotcom/cars-forge/compare/v1.0.1...HEAD +[unreleased]: https://github.com/carsdotcom/cars-forge/compare/v1.0.2...HEAD +[1.0.2]: https://github.com/carsdotcom/cars-forge/compare/v1.0.1...v1.0.2 [1.0.1]: https://github.com/carsdotcom/cars-forge/compare/v1.0.0...v1.0.1 [1.0.0]: https://github.com/carsdotcom/cars-forge/releases/tag/v1.0.0 diff --git a/README.md b/README.md index 45d83e1..39d80bc 100755 --- a/README.md +++ b/README.md @@ -1,5 +1,11 @@ -


+ +[![GitHub license](https://img.shields.io/github/license/carsdotcom/cars-forge?color=navy&label=License&logo=License&style=flat-square)](https://github.com/carsdotcom/cars-forge/blob/main/LICENSE)
+[![PyPI](https://img.shields.io/pypi/v/cars-forge?color=navy&style=flat-square)](https://pypi.org/project/cars-forge/)
+![hacktoberfest](https://img.shields.io/github/issues/carsdotcom/cars-forge?color=orange&label=Hacktoberfest%202022&style=flat-square&?labelColor=black)
+![PyPI - Downloads](https://img.shields.io/pypi/dm/cars-forge?color=navy&style=flat-square)
+![GitHub Workflow Status (branch)](https://img.shields.io/github/workflow/status/carsdotcom/cars-forge/Publish%20Package/main?color=navy&style=flat-square)
+![GitHub contributors](https://img.shields.io/github/contributors/carsdotcom/cars-forge?color=navy&style=flat-square) --- ## About diff --git a/pyproject.toml b/pyproject.toml index ad6989d..981805f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,6 +44,7 @@ dependencies = [ [project.optional-dependencies] test = [ "pytest~=7.1.0", + "pytest-cov~=4.0" ] dev = [ "bump2version~=1.0", @@ -72,3 +73,10 @@ packages = ["src/forge"] [tool.hatch.version] path = "src/forge/__init__.py" + +### # Pytest settings ### +[tool.pytest.ini_options] +# Show coverage report with missing lines when running `pytest` +addopts = "--cov=forge --cov-report term-missing" diff --git a/src/forge/__init__.py b/src/forge/__init__.py index b383aca..5c92f2c 100755 --- a/src/forge/__init__.py +++ b/src/forge/__init__.py @@ -1,4 +1,4 @@ -__version__ = "1.0.1" +__version__ = "1.0.2" # Default values for forge's essential arguments DEFAULT_ARG_VALS = { diff --git a/src/forge/common.py b/src/forge/common.py index 6e31351..02ea5c3 100755 --- a/src/forge/common.py +++ b/src/forge/common.py @@ -7,6 +7,7 @@ import tempfile import sys import os +from datetime import datetime from numbers import Number import boto3 @@ -290,10 +291,9 @@ def _parse_list(option): return option -def normalize_config(config, additional_config=None): +def normalize_config(config): """normalizes the Forge configuration data - If additional_config is present, normalize_config will purely parse additional_config and add it to config. If it detects environment config data (determined by a lack of ram or cpu data), it processes the ratio and updates DEFAULT_ARG_VALS['ratio']. If it detects a user configuration option, it will parse the ram, cpu, ratio, and market data so that it conforms to Forge's expectations. 
In either scenario, if it detects aws_az it will update @@ -303,8 +303,6 @@ def normalize_config(config, additional_config=None): ---------- config : dict Forge configuration data - additional_config : dict - Additional Forge use configuration options Notes ----- @@ -319,12 +317,6 @@ def normalize_config(config, additional_config=None): """ config = dict(config) - if additional_config: - additional_config = {x['name']: x['default'] for x in additional_config if x['default']} - config = {**config, **additional_config} - - return config - if config.get('aws_az'): config['region'] = config['aws_az'][:-1] @@ -346,6 +338,29 @@ def normalize_config(config, additional_config=None): return config +def parse_additional_config(config, additional_config): + """parse additional configuration data + + Parameters + ---------- + config : dict + Forge configuration data + additional_config : dict + Additional Forge use configuration options + + Returns + ------- + dict + The additional Forge configuration data + """ + config = dict(config) + + additional_config = {x['name']: x['default'] for x in additional_config if x['default']} + config = {**config, **additional_config} + + return config + + def set_boto_session(region, profile=None): """set the default Boto3 session @@ -458,3 +473,59 @@ def user_accessible_vars(config, **kwargs): user_vars.update({k: v for k, v in config.items() if k in ADDITIONAL_KEYS}) return user_vars + + +def get_ec2_pricing(ec2_type, market, config): + """Get the hourly spot or on-demand price of given EC2 instance type. + + Parameters + ---------- + ec2_type : str + EC2 type to get pricing for. + market : str + Whether EC2 is a `'spot'` or `'on-demand'` instance. + config : dict + Forge configuration data. + + Returns + ------- + float + Hourly price of given EC2 type in given market. + """ + region = config.get('region') + az = config.get('aws_az') + + if market == 'spot': + client = boto3.client('ec2') + response = client.describe_spot_price_history( + StartTime=datetime.utcnow(), + ProductDescriptions=['Linux/UNIX (Amazon VPC)'], + AvailabilityZone=az, + InstanceTypes=[ec2_type] + ) + price = float(response['SpotPriceHistory'][0]['SpotPrice']) + + elif market == 'on-demand': + client = boto3.client('pricing', region_name='us-east-1') + + long_region = get_regions()[region] + op_sys = 'Linux' + + filters = [ + {'Field': 'tenancy', 'Value': 'shared', 'Type': 'TERM_MATCH'}, + {'Field': 'operatingSystem', 'Value': op_sys, 'Type': 'TERM_MATCH'}, + {'Field': 'preInstalledSw', 'Value': 'NA', 'Type': 'TERM_MATCH'}, + {'Field': 'location', 'Value': long_region, 'Type': 'TERM_MATCH'}, + {'Field': 'capacitystatus', 'Value': 'Used', 'Type': 'TERM_MATCH'}, + {'Field': 'instanceType', 'Value': ec2_type, 'Type': 'TERM_MATCH'} + ] + response = client.get_products(ServiceCode='AmazonEC2', Filters=filters) + + results = response['PriceList'] + product = json.loads(results[0]) + od = product['terms']['OnDemand'] + price_details = list(od.values())[0]['priceDimensions'] + price = list(price_details.values())[0]['pricePerUnit']['USD'] + price = float(price) + + return price diff --git a/src/forge/create.py b/src/forge/create.py index 194929f..1ec4ccc 100755 --- a/src/forge/create.py +++ b/src/forge/create.py @@ -4,7 +4,6 @@ import sys import os import time -import json from datetime import datetime, timedelta import boto3 @@ -12,7 +11,8 @@ from . 
import DEFAULT_ARG_VALS, REQUIRED_ARGS from .parser import add_basic_args, add_job_args, add_env_args, add_general_args -from .common import ec2_ip, get_regions, destroy_hook, set_boto_session, user_accessible_vars, FormatEmpty +from .common import (ec2_ip, destroy_hook, set_boto_session, + user_accessible_vars, FormatEmpty, get_ec2_pricing) from .destroy import destroy logger = logging.getLogger(__name__) @@ -251,59 +251,32 @@ def pricing(n, config, fleet_id): set_boto_session(region, profile) - az = config.get('aws_az') region = config.get('region') ec2_client = boto3.client('ec2') - pricing_client = boto3.client('pricing', region_name='us-east-1') - # get list of EC2s - ec2_list = [] + + # Get list of active fleet EC2s + fleet_types = [] fleet_request_configs = ec2_client.describe_fleet_instances(FleetId=fleet_id) - active_instances_list = fleet_request_configs.get('ActiveInstances') - if active_instances_list is None: # Consider changing to if not active_instances_list: - return None + for i in fleet_request_configs.get('ActiveInstances', []): + fleet_types.append(i['InstanceType']) - for i in active_instances_list: - ec2_list.append(i['InstanceType']) + if not fleet_types: + return - # on-demand pricing - long_region = get_regions()[region] - op_sys = 'Linux' + # Get on-demand prices regardless of market total_on_demand_cost = 0 - for ec2 in ec2_list: - filters = [ - {'Field': 'tenancy', 'Value': 'shared', 'Type': 'TERM_MATCH'}, - {'Field': 'operatingSystem', 'Value': op_sys, 'Type': 'TERM_MATCH'}, - {'Field': 'preInstalledSw', 'Value': 'NA', 'Type': 'TERM_MATCH'}, - {'Field': 'location', 'Value': long_region, 'Type': 'TERM_MATCH'}, - {'Field': 'capacitystatus', 'Value': 'Used', 'Type': 'TERM_MATCH'}, - {'Field': 'instanceType', 'Value': ec2, 'Type': 'TERM_MATCH'} - ] - response = pricing_client.get_products(ServiceCode='AmazonEC2', Filters=filters) - results = response['PriceList'] - product = json.loads(results[0]) - od = product['terms']['OnDemand'] - price_details = list(od.values())[0]['priceDimensions'] - on_demand_price = list(price_details.values())[0]['pricePerUnit']['USD'] - on_demand_price = float(on_demand_price) - total_on_demand_cost = total_on_demand_cost + on_demand_price - total_on_demand_cost = round(total_on_demand_cost, 2) - # get spot pricing + for ec2_type in fleet_types: + total_on_demand_cost += get_ec2_pricing(ec2_type, 'on-demand', config) + + # If using spot instances get spot pricing to show savings over on-demand if market == 'spot': total_spot_cost = 0 - for ec2 in ec2_list: - describe_result = ec2_client.describe_spot_price_history( - StartTime=datetime.utcnow(), - ProductDescriptions=['Linux/UNIX (Amazon VPC)'], - AvailabilityZone=az, - InstanceTypes=[ec2] - ) - spot_price = float(describe_result['SpotPriceHistory'][0]['SpotPrice']) - total_spot_cost = total_spot_cost + spot_price - total_spot_cost = round(total_spot_cost, 2) - saving = round(100 * (1 - (total_spot_cost / total_on_demand_cost)), 2) - logger.info(f'Hourly price is ${total_spot_cost}. Savings of {saving}%') + for ec2_type in fleet_types: + total_spot_cost += get_ec2_pricing(ec2_type, market, config) + saving = 100 * (1 - (total_spot_cost / total_on_demand_cost)) + logger.info('Hourly price is $%.2f. 
Savings of %.2f%%', total_spot_cost, saving) elif market == 'on-demand': - logger.info(f'Hourly price is ${total_on_demand_cost}') + logger.info('Hourly price is $%.2f', total_on_demand_cost) def create_template(n, config, task): diff --git a/src/forge/destroy.py b/src/forge/destroy.py index 989e2f1..a1e53e6 100755 --- a/src/forge/destroy.py +++ b/src/forge/destroy.py @@ -8,7 +8,7 @@ from . import DEFAULT_ARG_VALS, REQUIRED_ARGS from .parser import add_basic_args, add_general_args, add_env_args -from .common import ec2_ip, get_regions, set_boto_session +from .common import ec2_ip, set_boto_session, get_ec2_pricing logger = logging.getLogger(__name__) @@ -44,17 +44,11 @@ def pricing(detail, config, market): The market the instance was created in """ logger.debug('config is %s', config) - env = config.get('forge_env') profile = config.get('aws_profile') - az = config.get('aws_az') region = config.get('region') set_boto_session(region, profile) - ec2_client = boto3.client('ec2') - pricing_client = boto3.client('pricing', region_name='us-east-1') - total_spot_cost = 0 - total_on_demand_cost = 0 total_cost = 0 now = datetime.now(timezone.utc) dif = timedelta() @@ -65,48 +59,15 @@ def pricing(detail, config, market): dif = (now - launch_time) if dif > max_dif: max_dif = dif - ec2 = e['instance_type'] - if market == 'spot': - describe_result = ec2_client.describe_spot_price_history( - StartTime=datetime.utcnow(), - ProductDescriptions=['Linux/UNIX (Amazon VPC)'], - AvailabilityZone=az, - InstanceTypes=[ec2] - ) - spot_price = float(describe_result['SpotPriceHistory'][0]['SpotPrice']) - total_cost = total_cost + spot_price - total_cost = round(total_cost, 2) - elif market == 'on-demand': - long_region = get_regions()[region] - op_sys = 'Linux' - filters = [ - {'Field': 'tenancy', 'Value': 'shared', 'Type': 'TERM_MATCH'}, - {'Field': 'operatingSystem', 'Value': op_sys, 'Type': 'TERM_MATCH'}, - {'Field': 'preInstalledSw', 'Value': 'NA', 'Type': 'TERM_MATCH'}, - {'Field': 'location', 'Value': long_region, 'Type': 'TERM_MATCH'}, - {'Field': 'capacitystatus', 'Value': 'Used', 'Type': 'TERM_MATCH'}, - {'Field': 'instanceType', 'Value': ec2, 'Type': 'TERM_MATCH'} - ] - response = pricing_client.get_products(ServiceCode='AmazonEC2', Filters=filters) - results = response['PriceList'] - product = json.loads(results[0]) - instance = (product['product']['attributes']['instanceType']) - od = product['terms']['OnDemand'] - price = float( - od[list(od)[0]]['priceDimensions'][list(od[list(od)[0]]['priceDimensions'])[0]]['pricePerUnit'][ - 'USD']) - ip = [instance, price] - on_demand_price = float(ip[1]) - total_cost = total_cost + on_demand_price - total_cost = round(total_cost, 2) + ec2_type = e['instance_type'] + total_cost = get_ec2_pricing(ec2_type, market, config) if total_cost > 0: time_d_float = max_dif.total_seconds() - d = {"days": max_dif.days} - d['hours'], rem = divmod(int(max_dif.total_seconds()), 3600) - d["minutes"], d["seconds"] = divmod(rem, 60) + hours, rem = divmod(int(time_d_float), 3600) + minutes = rem // 60 cost = round(total_cost * (time_d_float / 60 / 60), 2) - time_diff = "{hours} hours and {minutes} minutes".format(**d) + time_diff = f"{hours} hours and {minutes} minutes" logger.info('Total run time was %s. Total cost was $%s', time_diff, cost) diff --git a/src/forge/main.py b/src/forge/main.py index 111cad4..719ae23 100755 --- a/src/forge/main.py +++ b/src/forge/main.py @@ -5,7 +5,7 @@ import yaml from . 
import __version__, DEFAULT_ARG_VALS, REQUIRED_ARGS -from .common import check_keys, set_config_dir, normalize_config +from .common import check_keys, set_config_dir, normalize_config, parse_additional_config from .engine import cli_engine, engine from .create import create, cli_create from .destroy import cli_destroy, destroy @@ -143,7 +143,7 @@ def main(): with open(f'{args["config_dir"]}/{args["forge_env"]}.yaml') as handle: env_config = check_env_yaml(yaml.safe_load(handle)) if env_config.get('additional_config'): - env_config = normalize_config(env_config, env_config.pop('additional_config')) + env_config = parse_additional_config(env_config, env_config.pop('additional_config')) env_config = normalize_config(env_config) check_keys(args.get('region') or env_config['region'], env_config.get('aws_profile')) env_config = {k: v for k, v in env_config.items() if v} diff --git a/src/forge/rsync.py b/src/forge/rsync.py index f0fde29..3e5cd46 100755 --- a/src/forge/rsync.py +++ b/src/forge/rsync.py @@ -64,17 +64,20 @@ def _rsync(config, ip): with key_file(pem_secret, region, profile) as pem_path: if os.path.isdir(rsync_loc): logger.info('Copying folder %s to EC2.', rsync_loc) - c = f'rsync -rave "ssh -o UserKnownHostsFile=/dev/null -o StrictHostKeyChecking=no -i {pem_path}" {rsync_loc}/* root@{ip}:/root/' + rsync_loc += '/*' elif os.path.isfile(rsync_loc): logger.info('Copying file %s to EC2.', rsync_loc) - c = f'rsync -rave "ssh -o UserKnownHostsFile=/dev/null -o StrictHostKeyChecking=no -i {pem_path}" {rsync_loc} root@{ip}:/root/' else: logger.error("File or folder from 'rsync_path' parameter not found: %s", rsync_loc) sys.exit(1) + cmd = 'rsync -rave "ssh -o UserKnownHostsFile=/dev/null -o StrictHostKeyChecking=no' + cmd +=f' -i {pem_path}" {rsync_loc} root@{ip}:/root/' + try: output = subprocess.check_output( - c, stderr=subprocess.STDOUT, shell=True, universal_newlines=True) + cmd, stderr=subprocess.STDOUT, shell=True, universal_newlines=True + ) except subprocess.CalledProcessError as exc: logger.error('Rsync failed:\n%s', exc.output) else: diff --git a/src/forge/ssh.py b/src/forge/ssh.py index c4ae7a8..5d3f75c 100755 --- a/src/forge/ssh.py +++ b/src/forge/ssh.py @@ -67,7 +67,8 @@ def ssh(config): try: subprocess.run(shlex.split(cmd), check=True, universal_newlines=True) except subprocess.CalledProcessError as exc: - logger.error( - 'SSH failed with error code %d: %s', exc.returncode, exc.cmd - ) + if exc.returncode == 142: + logger.error('Missing proper SSH credentials to connect. 
Please check your user_data and/or AMI.') + else: + logger.error('SSH failed with error code %d: %s', exc.returncode, exc.cmd) sys.exit(exc.returncode) diff --git a/src/forge/start.py b/src/forge/start.py index b832088..3b18f79 100755 --- a/src/forge/start.py +++ b/src/forge/start.py @@ -33,6 +33,40 @@ def cli_start(subparsers): 'forge_env'] +def start_fleet(n_list, config): + """starts each fleet in n_list + + Parameters + ---------- + n_list : list + List of fleet names + config : dict + Forge configuration data + """ + profile = config.get('aws_profile') + region = config.get('region') + set_boto_session(region, profile) + client = boto3.client('ec2') + + details = {n: ec2_ip(n, config) for n in n_list} + targets = {k: get_ip(v, ('stopped', 'stopping')) for k, v in details.items()} + if not targets: + logger.error('Could not find any valid instances to start.') + sys.exit(1) + + for k, v in targets.items(): + if not v: + logger.error('Could not find any valid instances to start for %s', k) + continue + + logger.debug('Instance target details are %s', targets) + logger.info(f'{k} fleet is now starting.') + + for ec2 in v: + _, uid = ec2 + client.start_instances(InstanceIds=[uid]) + + def start(config): """start a stopped on-demand EC2 instance @@ -46,39 +80,6 @@ def start(config): service = config['service'] market = config.get('market') - def start_fleet (n_list, config): - """starts each fleet in n_list - - Parameters - ---------- - n_list : list - List of fleet names - config : dict - Forge configuration data - """ - profile = config.get('aws_profile') - region = config.get('region') - set_boto_session(region, profile) - client = boto3.client('ec2') - - details = {n: ec2_ip(n, config) for n in n_list} - targets = {k: get_ip(v, ('stopped', 'stopping')) for k, v in details.items()} - if not targets: - logger.error('Could not find any valid instances to start.') - sys.exit(1) - - for k, v in targets.items(): - if not v: - logger.error('Could not find any valid instances to start for %s', k) - continue - - logger.debug('Instance target details are %s', targets) - logger.info(f'{k} fleet is now starting.') - - for ec2 in v: - _, uid = ec2 - client.start_instances(InstanceIds=[uid]) - n_list = [] if service == "cluster": if market[0] == 'spot': diff --git a/src/forge/yaml_loader.py b/src/forge/yaml_loader.py index 1f41574..b9cd30f 100755 --- a/src/forge/yaml_loader.py +++ b/src/forge/yaml_loader.py @@ -9,7 +9,7 @@ from . 
import DEFAULT_ARG_VALS, ADDITIONAL_KEYS from .configure import check_env_yaml -from .common import normalize_config, set_config_dir, check_keys +from .common import normalize_config, set_config_dir, check_keys, parse_additional_config logger = logging.getLogger(__name__) @@ -229,7 +229,7 @@ def load_config(args): env = args['forge_env'] or user_config.get('forge_env') env_config['config_dir'] = config_dir - env_config = normalize_config(env_config, additional_config) + env_config = parse_additional_config(env_config, additional_config) logger.debug('Full user config options: %s', user_config) diff --git a/tests/test_common.py b/tests/test_common.py index d8488a1..f7c8a5f 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -1,4 +1,6 @@ """Tests for the common functions of Forge.""" +import json +from datetime import datetime from unittest import mock import pytest @@ -13,30 +15,38 @@ TEST_ADDITIONAL_KEYS = ['fake', 'not_real'] @mock.patch.dict('forge.common.DEFAULT_ARG_VALS', TEST_DEFAULT_ARG_VALS) -@pytest.mark.parametrize('config,additional_config,expected', [ +@pytest.mark.parametrize('config,expected', [ # Overriding default ratio ({'aws_az': 'us-east-1a', 'ratio': [6, 8]}, - None, {'aws_az': 'us-east-1a', 'region': 'us-east-1'}), # Regular config ({'ram': ['8', [256, 512]], 'cpu': ['1, 2', '7,8'], 'aws_az': 'testing', 'market': 'on-demand, spot'}, - None, {'ram': [[8], [256, 512]], 'cpu': [[1, 2], [7, 8]], 'aws_az': 'testing', 'region': 'testin', 'market': ['on-demand', 'spot'], 'ratio': None}), # No-market config ({'ram': ['8', [256, 512]], 'cpu': ['1, 2', '7,8']}, - None, {'ram': [[8], [256, 512]], 'cpu': [[1, 2], [7, 8]], 'ratio': None}), - # only additional configs +]) +def test_normalize_config(config, expected): + """Test the normalization of config options.""" + actual = common.normalize_config(config) + if config.get('ratio'): + assert config['ratio'] == common.DEFAULT_ARG_VALS['default_ratio'] + + assert actual == expected + + +@mock.patch.dict('forge.common.DEFAULT_ARG_VALS', TEST_DEFAULT_ARG_VALS) +@pytest.mark.parametrize('config,additional_config,expected', [ ({}, [{'name': 'pip', 'type': 'list', 'default': [], 'constraints': []}, {'name': 'version', 'type': 'float', 'default': 2.3, 'constraints': [2.3, 3.0, 3.1]}], {'version': 2.3}) ]) -def test_normalize_config(config, additional_config, expected): +def test_parse_additional_config(config, additional_config, expected): """Test the normalization of config options.""" - actual = common.normalize_config(config, additional_config) + actual = common.parse_additional_config(config, additional_config) if config.get('ratio'): assert config['ratio'] == common.DEFAULT_ARG_VALS['default_ratio'] @@ -79,3 +89,67 @@ def test_user_accessible_vars(config, kwargs, expected): """Test creating the dict of user-accessible variables.""" actual = common.user_accessible_vars(config, **kwargs) assert actual == expected + + +@mock.patch('forge.common.boto3') +@mock.patch('forge.common.datetime') +def test_get_ec2_pricing_spot(mock_dt, mock_boto): + """Test getting spot EC2 hourly pricing.""" + exp_price = 0.123 + response = {'SpotPriceHistory': [{'SpotPrice': str(exp_price)}]} + mock_client = mock_boto.client.return_value = mock.Mock() + mock_describe = mock_client.describe_spot_price_history + mock_describe.return_value = response + now = datetime(2022, 1, 1, 12, 0, 0) + mock_dt.utcnow.return_value = now + + config = {'aws_az': 'us-east-1a', 'region': 'us-east-1'} + ec2_type = 'r5.large' + act_price = 
common.get_ec2_pricing(ec2_type, 'spot', config) + assert act_price == exp_price + + mock_boto.client.assert_called_once_with('ec2') + mock_dt.utcnow.assert_called_once() + mock_describe.assert_called_once_with( + StartTime=now, + ProductDescriptions=['Linux/UNIX (Amazon VPC)'], + AvailabilityZone=config['aws_az'], + InstanceTypes=[ec2_type] + ) + + +@mock.patch('forge.common.boto3') +@mock.patch('forge.common.get_regions') +def test_get_ec2_pricing_ondemand(mock_regions, mock_boto): + """Test getting on-demand EC2 hourly pricing.""" + exp_price = 0.123 + region = 'us-east-1' + long_region = 'US East (N. Virginia)' + response = {'PriceList': [json.dumps( + {"terms": {"OnDemand": { + "XYZ": {"priceDimensions": {"XYZ.ABC": {"pricePerUnit": {"USD": "0.1230000000"}}}} + }}} + )]} + + mock_client = mock_boto.client.return_value = mock.Mock() + mock_products = mock_client.get_products + mock_products.return_value = response + mock_regions.return_value = {region: long_region} + + config = {'region': region} + ec2_type = 'r5.large' + act_price = common.get_ec2_pricing(ec2_type, 'on-demand', config) + assert act_price == exp_price + + mock_boto.client.assert_called_once_with('pricing', region_name=region) + mock_regions.assert_called_once() + mock_products.assert_called_once_with( + ServiceCode='AmazonEC2', Filters=[ + {'Field': 'tenancy', 'Value': 'shared', 'Type': 'TERM_MATCH'}, + {'Field': 'operatingSystem', 'Value': 'Linux', 'Type': 'TERM_MATCH'}, + {'Field': 'preInstalledSw', 'Value': 'NA', 'Type': 'TERM_MATCH'}, + {'Field': 'location', 'Value': long_region, 'Type': 'TERM_MATCH'}, + {'Field': 'capacitystatus', 'Value': 'Used', 'Type': 'TERM_MATCH'}, + {'Field': 'instanceType', 'Value': ec2_type, 'Type': 'TERM_MATCH'} + ] + ) diff --git a/tests/test_rsync.py b/tests/test_rsync.py new file mode 100644 index 0000000..c71dabe --- /dev/null +++ b/tests/test_rsync.py @@ -0,0 +1,273 @@ +"""Tests for the rsync module of Forge.""" +import subprocess +from unittest import mock + +import pytest + +from forge import rsync + + +@mock.patch('forge.rsync.os.path') +@mock.patch('forge.rsync.subprocess.check_output') +@mock.patch('forge.rsync.key_file') +@mock.patch('forge.rsync.get_ip') +@mock.patch('forge.rsync.ec2_ip') +@pytest.mark.parametrize('service,out', [('single', 'single'), ('cluster', 'cluster-master')]) +def test_rsync_file_success(mock_ec2_ip, mock_get_ip, mock_key_file, mock_sub_chk, + mock_os_path, service, out, caplog): + """Test a successful execution of the 'rsync' sub-command for a file.""" + ip = '123.456.789' + ec2_details = [{'ip': ip, 'spot_id': ['abc'], 'state': None}] + mock_ec2_ip.return_value = ec2_details + mock_get_ip.return_value = [(ip, None)] + key_path = '/dummy/key/path' + rsync_path = 'path/to/rsync/file.txt' + mock_key_file.return_value.__enter__.return_value = key_path + mock_os_path.isdir.return_value = False + mock_os_path.isfile.return_value = True + config = { + 'name': 'test-rsync', + 'date': '2021-02-01', + 'market': ['spot', 'spot'], + 'service': service, + 'forge_pem_secret': 'forge-test', + 'region': 'us-east-1', + 'aws_profile': 'dev', + 'rsync_path': rsync_path, + } + expected_cmd = 'rsync -rave "ssh -o UserKnownHostsFile=/dev/null -o' + expected_cmd += f' StrictHostKeyChecking=no -i {key_path}" {rsync_path} root@{ip}:/root/' + + rsync.rsync(config) + + mock_ec2_ip.assert_called_once_with( + f"{config['name']}-spot-{out}-{config['date']}", config + ) + mock_get_ip.assert_called_once_with(ec2_details, ('running',)) + 
mock_os_path.isdir.assert_called_once_with(rsync_path) + mock_os_path.isfile.assert_called_once_with(rsync_path) + mock_key_file.assert_called_once_with( + config['forge_pem_secret'], config['region'], config['aws_profile'] + ) + mock_sub_chk.assert_called_once_with( + expected_cmd, stderr=subprocess.STDOUT, shell=True, universal_newlines=True + ) + assert f'Copying file {rsync_path} to EC2.' in caplog.text + + +@mock.patch('forge.rsync.os.path') +@mock.patch('forge.rsync.subprocess.check_output') +@mock.patch('forge.rsync.key_file') +@mock.patch('forge.rsync.get_ip') +@mock.patch('forge.rsync.ec2_ip') +def test_rsync_dir_success(mock_ec2_ip, mock_get_ip, mock_key_file, mock_sub_chk, + mock_os_path, caplog): + """Test a successful execution of the 'rsync' sub-command for a directory.""" + ip = '123.456.789' + ec2_details = [{'ip': ip, 'spot_id': ['abc'], 'state': None}] + mock_ec2_ip.return_value = ec2_details + mock_get_ip.return_value = [(ip, None)] + key_path = '/dummy/key/path' + rsync_path = 'path/to/rsync/dir' + mock_key_file.return_value.__enter__.return_value = key_path + mock_os_path.isdir.return_value = True + config = { + 'name': 'test-rsync', + 'date': '2021-02-01', + 'market': ['spot', 'spot'], + 'service': 'single', + 'forge_pem_secret': 'forge-test', + 'region': 'us-east-1', + 'aws_profile': 'dev', + 'rsync_path': rsync_path, + } + expected_cmd = 'rsync -rave "ssh -o UserKnownHostsFile=/dev/null -o' + expected_cmd += f' StrictHostKeyChecking=no -i {key_path}" {rsync_path}/* root@{ip}:/root/' + + rsync.rsync(config) + + mock_ec2_ip.assert_called_once_with( + f"{config['name']}-spot-single-{config['date']}", config + ) + mock_get_ip.assert_called_once_with(ec2_details, ('running',)) + mock_os_path.isdir.assert_called_once_with(rsync_path) + mock_key_file.assert_called_once_with( + config['forge_pem_secret'], config['region'], config['aws_profile'] + ) + mock_sub_chk.assert_called_once_with( + expected_cmd, stderr=subprocess.STDOUT, shell=True, universal_newlines=True + ) + assert f'Copying folder {rsync_path} to EC2.' 
in caplog.text + assert 'Rsync successful:\n' in caplog.text + + +@mock.patch('forge.rsync.os.path') +@mock.patch('forge.rsync.key_file') +@mock.patch('forge.rsync.get_ip') +@mock.patch('forge.rsync.ec2_ip') +def test_rsync_no_paths(mock_ec2_ip, mock_get_ip, mock_key_file, mock_os_path, caplog): + """Test an execution of the 'rsync' sub-command with no valid file/folder.""" + ip = '123.456.789' + ec2_details = [{'ip': ip, 'spot_id': ['abc'], 'state': None}] + mock_ec2_ip.return_value = ec2_details + mock_get_ip.return_value = [(ip, None)] + key_path = '/dummy/key/path' + rsync_path = 'fake/dir' + mock_key_file.return_value.__enter__.return_value = key_path + mock_os_path.isdir.return_value = False + mock_os_path.isfile.return_value = False + config = { + 'name': 'test-rsync', + 'date': '2021-02-01', + 'market': ['spot', 'spot'], + 'service': 'single', + 'forge_pem_secret': 'forge-test', + 'region': 'us-east-1', + 'aws_profile': 'dev', + 'rsync_path': rsync_path, + } + + with pytest.raises(SystemExit, match='1'): + rsync.rsync(config) + + mock_ec2_ip.assert_called_once_with( + f"{config['name']}-spot-single-{config['date']}", config + ) + mock_get_ip.assert_called_once_with(ec2_details, ('running',)) + mock_os_path.isdir.assert_called_once_with(rsync_path) + mock_os_path.isfile.assert_called_once_with(rsync_path) + mock_key_file.assert_called_once_with( + config['forge_pem_secret'], config['region'], config['aws_profile'] + ) + assert "File or folder from 'rsync_path' parameter not found:" in caplog.text + + +@mock.patch('forge.rsync.os.path') +@mock.patch('forge.rsync.subprocess.check_output') +@mock.patch('forge.rsync.key_file') +@mock.patch('forge.rsync.get_ip') +@mock.patch('forge.rsync.ec2_ip') +def test_rsync_fail(mock_ec2_ip, mock_get_ip, mock_key_file, mock_sub_chk, + mock_os_path, caplog): + """Test an execution of the 'rsync' sub-command with subprocess errors.""" + ip = '123.456.789' + ec2_details = [{'ip': ip, 'spot_id': ['abc'], 'state': None}] + mock_ec2_ip.return_value = ec2_details + mock_get_ip.return_value = [(ip, None)] + key_path = '/dummy/key/path' + rsync_path = 'path/to/rsync/dir' + mock_key_file.return_value.__enter__.return_value = key_path + mock_os_path.isdir.return_value = True + config = { + 'name': 'test-rsync', + 'date': '2021-02-01', + 'market': ['spot', 'spot'], + 'service': 'single', + 'forge_pem_secret': 'forge-test', + 'region': 'us-east-1', + 'aws_profile': 'dev', + 'rsync_path': rsync_path, + } + expected_cmd = 'rsync -rave "ssh -o UserKnownHostsFile=/dev/null -o' + expected_cmd += f' StrictHostKeyChecking=no -i {key_path}" {rsync_path}/* root@{ip}:/root/' + + mock_sub_chk.side_effect = subprocess.CalledProcessError( + returncode=123, cmd=expected_cmd + ) + + rsync.rsync(config) + + mock_ec2_ip.assert_called_once_with( + f"{config['name']}-spot-single-{config['date']}", config + ) + mock_get_ip.assert_called_once_with(ec2_details, ('running',)) + mock_os_path.isdir.assert_called_once_with(rsync_path) + mock_key_file.assert_called_once_with( + config['forge_pem_secret'], config['region'], config['aws_profile'] + ) + assert 'Rsync failed:\nNone' in caplog.text + + +@mock.patch('forge.rsync.os.path') +@mock.patch('forge.rsync.subprocess.check_output') +@mock.patch('forge.rsync.key_file') +@mock.patch('forge.rsync.get_ip') +@mock.patch('forge.rsync.ec2_ip') +def test_rsync_multi(mock_ec2_ip, mock_get_ip, mock_key_file, mock_sub_chk, + mock_os_path, caplog): + """Test an execution of the 'rsync' sub-command to multiple instances.""" + ips = ['123.456.789', 
'987.654.321'] + ids = ['abc', 'def'] + ec2_details = [ + {'ip': ip, 'spot_id': [sid], 'state': None} for ip, sid in zip(ips, ids) + ] + mock_ec2_ip.side_effect = ec2_details + mock_get_ip.side_effect = [[(ip, None)] for ip in ips] + key_path = '/dummy/key/path' + rsync_path = 'path/to/rsync/dir' + mock_key_file.return_value.__enter__.return_value = key_path + mock_os_path.isdir.return_value = True + config = { + 'name': 'test-rsync', + 'date': '2021-02-01', + 'market': ['spot', 'spot'], + 'service': 'cluster', + 'forge_pem_secret': 'forge-test', + 'region': 'us-east-1', + 'aws_profile': 'dev', + 'rsync_path': rsync_path, + 'rr_all': True, + } + e_cmd = 'rsync -rave "ssh -o UserKnownHostsFile=/dev/null -o' + e_cmd += f' StrictHostKeyChecking=no -i {key_path}" {rsync_path}/* root@{{0}}:/root/' + expected_cmds = [e_cmd.format(ip) for ip in ips] + + rsync.rsync(config) + + mock_ec2_ip.assert_has_calls([ + mock.call(f"{config['name']}-spot-cluster-master-{config['date']}", config), + mock.call(f"{config['name']}-spot-cluster-worker-{config['date']}", config) + ]) + mock_get_ip.assert_has_calls([mock.call(d, ('running',)) for d in ec2_details]) + mock_os_path.isdir.assert_has_calls([mock.call(rsync_path) for _ in ips]) + mock_key_file.assert_has_calls([ + mock.call( + config['forge_pem_secret'], config['region'], config['aws_profile'] + ) for _ in ips + ], any_order=True) + mock_sub_chk.assert_has_calls([ + mock.call( + e_cmd, stderr=subprocess.STDOUT, shell=True, universal_newlines=True + ) for e_cmd in expected_cmds + ], any_order=True) + for ip in ips: + assert f'Rsync destination is {ip}' in caplog.text + + +@mock.patch('forge.rsync.get_ip') +@mock.patch('forge.rsync.ec2_ip') +@pytest.mark.parametrize('targets', [[], [tuple()]]) +def test_rsync_no_instances(mock_ec2_ip, mock_get_ip, targets, caplog): + """Test an execution of the 'rsync' sub-command with no valid instances.""" + ec2_details = [{'ip': None, 'spot_id': [None], 'state': None}] + mock_ec2_ip.return_value = ec2_details + mock_get_ip.return_value = targets + rsync_path = 'path/to/rsync/dir' + config = { + 'name': 'test-rsync', + 'date': '2021-02-01', + 'market': ['spot', 'spot'], + 'service': 'single', + 'forge_pem_secret': 'forge-test', + 'region': 'us-east-1', + 'aws_profile': 'dev', + 'rsync_path': rsync_path, + } + + rsync.rsync(config) + + mock_ec2_ip.assert_called_once_with( + f"{config['name']}-spot-single-{config['date']}", config + ) + mock_get_ip.assert_called_once_with(ec2_details, ('running',)) + assert 'Could not find any valid instances to rsync to' in caplog.text diff --git a/tests/test_run.py b/tests/test_run.py index 0bdc79b..c5340ce 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -42,6 +42,7 @@ def test_run_success(mock_ec2_ip, mock_get_ip, mock_key_file, mock_sub_run, serv mock_ec2_ip.assert_called_once_with( f"{config['name']}-spot-{out}-{config['date']}", config ) + mock_get_ip.assert_called_once_with(ec2_details, ('running',)) mock_key_file.assert_called_once_with( config['forge_pem_secret'], config['region'], config['aws_profile'] ) diff --git a/tests/test_ssh.py b/tests/test_ssh.py new file mode 100644 index 0000000..c217387 --- /dev/null +++ b/tests/test_ssh.py @@ -0,0 +1,129 @@ +"""Tests for the ssh module of Forge.""" +import logging +import subprocess +from unittest import mock + +import pytest + +from forge import ssh + + +@mock.patch('forge.ssh.subprocess.run') +@mock.patch('forge.ssh.key_file') +@mock.patch('forge.ssh.get_ip') +@mock.patch('forge.ssh.ec2_ip') 
+@pytest.mark.parametrize('service,out', [('single', 'single'), ('cluster', 'cluster-master')]) +def test_ssh_success(mock_ec2_ip, mock_get_ip, mock_key_file, mock_sub_run, service, out): + """Test a successful execution of the 'ssh' sub-command.""" + ip = '123.456.789' + ec2_details = [{'ip': ip, 'id': 'i-abc', 'fleet_id': ['fleet-def'], 'state': 'running'}] + mock_ec2_ip.return_value = ec2_details + mock_get_ip.return_value = [(ip, 'i-abc')] + key_path = '/dummy/key/path' + mock_key_file.return_value.__enter__.return_value = key_path + config = { + 'name': 'test-run', + 'date': '2021-02-01', + 'service': service, + 'forge_pem_secret': 'forge-test', + 'region': 'us-east-1', + 'aws_profile': 'dev', + 'market': ['spot', 'spot'], + } + expected_cmd = [ + 'ssh', '-t', '-o', 'UserKnownHostsFile=/dev/null', '-o', + 'StrictHostKeyChecking=no', '-i', key_path, f'root@{ip}', + ] + + ssh.ssh(config) + + mock_ec2_ip.assert_called_once_with( + f"{config['name']}-{config['market'][0]}-{out}-{config['date']}", config + ) + mock_key_file.assert_called_once_with( + config['forge_pem_secret'], config['region'], config['aws_profile'] + ) + mock_sub_run.assert_called_once_with( + expected_cmd, check=True, universal_newlines=True + ) + + +@mock.patch('forge.ssh.get_ip') +@mock.patch('forge.ssh.ec2_ip') +def test_ssh_no_instances(mock_ec2_ip, mock_get_ip, caplog): + """Test an execution of the 'ssh' sub-command with no valid instances.""" + ec2_details = [{'ip': None, 'id': None, 'fleet_id': [], 'state': None}] + mock_ec2_ip.return_value = ec2_details + mock_get_ip.return_value = [] + config = { + 'name': 'test-run', + 'date': '2021-02-01', + 'service': 'single', + 'forge_pem_secret': 'forge-test', + 'region': 'us-east-1', + 'aws_profile': 'dev', + 'forge_env': 'test', + 'market': ['spot', 'spot'], + } + + with pytest.raises(SystemExit): + ssh.ssh(config) + + n = f"{config['name']}-spot-single-{config['date']}" + + mock_ec2_ip.assert_called_once_with(n, config) + mock_get_ip.assert_called_once_with(ec2_details, ('running',)) + assert caplog.record_tuples == [ + ('forge.ssh', logging.ERROR, 'Could not find any valid instances to SSH to') + ] + + +@mock.patch('forge.ssh.subprocess.run') +@mock.patch('forge.ssh.key_file') +@mock.patch('forge.ssh.get_ip') +@mock.patch('forge.ssh.ec2_ip') +@pytest.mark.parametrize('err_code,err_msg', [ + (142, 'Missing proper SSH credentials'), + (111, 'SSH failed with error code 111') +]) +def test_ssh_connection_error(mock_ec2_ip, mock_get_ip, mock_key_file, mock_sub_run, + caplog, err_code, err_msg): + """Test an execution of the 'ssh' sub-command with connection errors.""" + ip = '123.456.789' + ec2_details = [{'ip': ip, 'id': 'i-abc', 'fleet_id': ['fleet-def'], 'state': 'running'}] + mock_ec2_ip.return_value = ec2_details + mock_get_ip.return_value = [(ip, 'i-abc')] + key_path = '/dummy/key/path' + mock_key_file.return_value.__enter__.return_value = key_path + config = { + 'name': 'test-run', + 'date': '2021-02-01', + 'service': 'single', + 'forge_pem_secret': 'forge-test', + 'region': 'us-east-1', + 'aws_profile': 'dev', + 'forge_env': 'test', + 'market': ['spot', 'spot'], + } + expected_cmd = [ + 'ssh', '-t', '-o', 'UserKnownHostsFile=/dev/null', '-o', + 'StrictHostKeyChecking=no', '-i', key_path, f'root@{ip}', + ] + mock_sub_run.side_effect = subprocess.CalledProcessError( + returncode=err_code, cmd=expected_cmd + ) + + with pytest.raises(SystemExit, match=str(err_code)): + ssh.ssh(config) + + n = f"{config['name']}-spot-single-{config['date']}" + + 
mock_ec2_ip.assert_called_once_with(n, config) + mock_get_ip.assert_called_once_with(ec2_details, ('running',)) + mock_key_file.assert_called_once_with( + config['forge_pem_secret'], config['region'], config['aws_profile'] + ) + mock_sub_run.assert_called_once_with( + expected_cmd, check=True, universal_newlines=True + ) + assert err_msg in caplog.text diff --git a/tests/test_start.py b/tests/test_start.py new file mode 100644 index 0000000..9c0abd5 --- /dev/null +++ b/tests/test_start.py @@ -0,0 +1,69 @@ +import logging +from unittest import mock + +from forge import start + +import pytest + +logger = logging.getLogger("start") + + +@mock.patch('forge.start.start_fleet') +@pytest.mark.parametrize( + 'service, markets', + [ + ('cluster', ['on-demand', 'on-demand']), + ('single', ['on-demand']), + ] +) +def test_start(mock_start_fleet, service, markets): + config = { + "name": "some name", + "date": "2022-01-01", + "market": markets, + "service": service, + "region": "us-east-1", + 'aws_profile': "dev" + } + start.start(config) + n_list = [] + if service == "cluster": + for index, market in enumerate(markets): + worker_name = "master" if index == 0 else "worker" + n_list.append(f'{config["name"]}-{market}-{service}-{worker_name}-{config["date"]}') + else: + for market in markets: + n_list.append(f'{config["name"]}-{market}-{service}-{config["date"]}') + + mock_start_fleet.assert_called_once_with(n_list, config) + + +@mock.patch('forge.start.start_fleet') +@pytest.mark.parametrize( + 'service, markets', + [ + ('cluster', ['spot']), + ('cluster', ['on-demand', 'spot']), + ('single', ['spot']), + ] +) +def test_start_error_in_spot_instance(mock_start_fleet, caplog, service, markets): + config = { + "name": "some name", + "date": "2022-01-01", + "market": markets, + "service": service, + } + error_msg = "" + if service == "cluster": + if markets[0] == "spot": + error_msg = "Master is a spot instance; you cannot start a spot instance" + elif markets[1] == "spot": + error_msg = "Worker is a spot fleet; you cannot start a spot fleet" + else: + if markets[0] == "spot": + error_msg = "The instance is a spot instance; you cannot start a spot instance" + + with caplog.at_level(logging.ERROR): + start.start(config) + assert error_msg in caplog.text diff --git a/tests/test_yaml_loader.py b/tests/test_yaml_loader.py index caed5d7..430f7d4 100644 --- a/tests/test_yaml_loader.py +++ b/tests/test_yaml_loader.py @@ -7,10 +7,10 @@ from forge import yaml_loader - TEST_DIR = os.path.dirname(os.path.realpath(__file__)) FORGE_DIR = os.path.dirname(os.path.realpath(yaml_loader.__file__)) + @pytest.mark.parametrize('user_yaml,expected', [ # Input/ouput pairs of config dicts # Single required-only configs with destroy @@ -209,3 +209,30 @@ def test_load_config_errors(mock_pass, args, exp_error, caplog): yaml_loader.load_config(args) assert caplog.record_tuples[-1] == ('forge.yaml_loader', logging.ERROR, exp_error) + + +@pytest.mark.parametrize( + "raw_x", [ + 1, + 1.1, + "1", + None, + [1, "two", 3.3], + [1, 2, 3], + [[-1, "two", 3.3]] + ] +) +def test_non_negative_list_failures(raw_x): + with pytest.raises(ValueError): + yaml_loader.non_negative_list_list_ints(raw_x) + + +@pytest.mark.parametrize( + "raw_x", [ + [[1, 2, 3]], + [[1, 2, 3], [4, 5, 6]], + ] +) +def test_non_negative_list_pass(raw_x): + result = yaml_loader.non_negative_list_list_ints(raw_x) + assert result == raw_x
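
A minimal usage sketch of how the reworked helpers in this patch fit together, for reviewers: the old `normalize_config(config, additional_config)` call becomes the two-step `parse_additional_config` + `normalize_config` sequence (as `main.py` and `yaml_loader.py` now do), and both `create.pricing` and `destroy.pricing` go through the shared `get_ec2_pricing`. The config values below are illustrative placeholders, not real Forge settings, and the pricing calls assume working AWS credentials for the boto3 session.

```python
from forge.common import (
    get_ec2_pricing,
    normalize_config,
    parse_additional_config,
    set_boto_session,
)

# Environment config with made-up placeholder values; the keys mirror the ones
# exercised in tests/test_common.py above.
env_config = {
    'aws_az': 'us-east-1a',
    'ram': ['8', [256, 512]],
    'cpu': ['1,2', '7,8'],
    'market': 'on-demand, spot',
}
additional_config = [
    {'name': 'pip', 'type': 'list', 'default': [], 'constraints': []},
    {'name': 'version', 'type': 'float', 'default': 2.3, 'constraints': [2.3, 3.0, 3.1]},
]

# Previously a single normalize_config(env_config, additional_config) call;
# now the defaults from additional_config are merged first, then normalized.
env_config = parse_additional_config(env_config, additional_config)
env_config = normalize_config(env_config)  # derives 'region' from 'aws_az', parses ram/cpu/market

# Shared pricing helper now used by both create.pricing and destroy.pricing.
# Needs valid AWS credentials; profile=None falls back to the default chain.
set_boto_session(env_config['region'], profile=None)
on_demand = get_ec2_pricing('r5.large', 'on-demand', env_config)
spot = get_ec2_pricing('r5.large', 'spot', env_config)
print(f'r5.large: ${on_demand:.3f}/hr on-demand vs ${spot:.3f}/hr spot')
```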