Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: improve test readability #21

Merged
merged 19 commits into from
Jul 26, 2024
Merged
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 41 additions & 39 deletions tests/tasks/test_tasks.py
uniqueg marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
"""Tests for I/O and decryption tasks"""

from time import sleep
import json
import requests
import signal
from time import sleep

import pytest
import requests

from task_bodies import (
output_dir,
Expand All @@ -19,51 +21,57 @@
TIME_LIMIT = 60
athith-g marked this conversation as resolved.
Show resolved Hide resolved


def create_task(tasks_body):
"""Creates task with the given task body."""
return requests.post(
url=f"{TES_URL}/tasks", headers=HEADERS, json=tasks_body, timeout=TIME_LIMIT
)
def timeout(func):
"""Decorator that enforces a time limit on a function."""
def handler(signum, frame):
raise TimeoutError(f"Task did not complete within {TIME_LIMIT} seconds")

def wrapper(*args, **kwargs):
signal.signal(signal.SIGALRM, handler)
signal.alarm(TIME_LIMIT)
func(*args, **kwargs)
signal.alarm(0)

def get_task(task_id):
"""Retrieves list of tasks."""
return requests.get(
url=f"{TES_URL}/tasks/{task_id}", headers=HEADERS, timeout=TIME_LIMIT
)

return wrapper

def get_task_state(task_id):
"""Retrieves state of task until completion."""
def wait_for_task_completion():
nonlocal task_state
elapsed_seconds = 0
get_response = get_task(task_id)
task_state = json.loads(get_response.text)["state"]
while task_state in WAIT_STATUSES:
if elapsed_seconds >= TIME_LIMIT:
raise requests.Timeout(f"Task did not complete within {TIME_LIMIT} seconds.")
sleep(1)
elapsed_seconds += 1
get_response = get_task(task_id)
task_state = json.loads(get_response.text)["state"]

task_state = ""
wait_for_task_completion()
return task_state
@timeout
def wait_for_file_to_download(filename):
"""Waits for file with given filename to download."""
while not (output_dir/filename).exists():
sleep(1)


@pytest.fixture(name="post_response")
def fixture_post_response(request):
"""Returns response received after creating task."""
return create_task(request.param)
return requests.post(
url=f"{TES_URL}/tasks", headers=HEADERS, json=request.param, timeout=TIME_LIMIT
)


@pytest.fixture(name="task_state")
def fixture_task_state(post_response):
"""Returns state of task after completion."""
def get_task():
return requests.get(
url=f"{TES_URL}/tasks/{task_id}", headers=HEADERS, timeout=TIME_LIMIT
)

@timeout
athith-g marked this conversation as resolved.
Show resolved Hide resolved
def wait_for_task_completion():
nonlocal task_state
get_response = get_task()
task_state = json.loads(get_response.text)["state"]
while task_state in WAIT_STATUSES:
sleep(1)
get_response = get_task()
task_state = json.loads(get_response.text)["state"]

task_id = json.loads(post_response.text)["id"]
return get_task_state(task_id)
task_state = ""
wait_for_task_completion()
return task_state


@pytest.mark.parametrize("post_response,filename,expected_output", [
Expand All @@ -76,13 +84,7 @@ def test_task(post_response, task_state, filename, expected_output):
assert post_response.status_code == 200
assert task_state == "COMPLETE"

elapsed_seconds = 0
while not (output_dir/filename).exists():
if elapsed_seconds == TIME_LIMIT:
raise FileNotFoundError(f"{filename} did not download to {output_dir} "
f"within {TIME_LIMIT} seconds.")
sleep(1)
elapsed_seconds += 1
wait_for_file_to_download(filename)

with open(output_dir/filename, encoding="utf-8") as f:
output = f.read()
Expand Down