Skip to content

Commit

Permalink
Adding retry implementation.
Browse files Browse the repository at this point in the history
  • Loading branch information
vladbagmet committed Nov 16, 2023
1 parent a4dc307 commit b9f3d23
Show file tree
Hide file tree
Showing 2 changed files with 129 additions and 41 deletions.
32 changes: 24 additions & 8 deletions src/corva/api.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import json
import posixpath
import re
import time
from http import HTTPStatus
from typing import List, Optional, Sequence, Union

import requests
Expand All @@ -14,6 +16,14 @@ class Api:
"""

TIMEOUT_LIMITS = (3, 30) # seconds
DEFAULT_MAX_RETRIES = 5
HTTP_STATUS_CODES_TO_RETRY = [
HTTPStatus.TOO_MANY_REQUESTS, # 428
HTTPStatus.INTERNAL_SERVER_ERROR, # 500
HTTPStatus.BAD_GATEWAY, # 502
HTTPStatus.SERVICE_UNAVAILABLE, # 503
HTTPStatus.GATEWAY_TIMEOUT # 504
]

def __init__(
self,
Expand All @@ -31,6 +41,7 @@ def __init__(
self.app_key = app_key
self.app_connection_id = app_connection_id
self.timeout = timeout or self.TIMEOUT_LIMITS[1]
self.max_retries = self.DEFAULT_MAX_RETRIES

@property
def default_headers(self):
Expand Down Expand Up @@ -103,6 +114,7 @@ def _request(
requests.Response instance.
"""

response = requests.Response()
timeout = timeout or self.timeout
self._validate_timeout(timeout)

Expand All @@ -113,14 +125,18 @@ def _request(
**(headers or {}),
}

response = requests.request(
method=method,
url=url,
params=params,
json=data,
headers=headers,
timeout=timeout,
)
for retry_attempt in range(max(int(self.max_retries), 1)):
response = requests.request(
method=method,
url=url,
params=params,
json=data,
headers=headers,
timeout=timeout,
)
if response.status_code not in self.HTTP_STATUS_CODES_TO_RETRY:
break
time.sleep(2 ** retry_attempt / 4)

return response

Expand Down
138 changes: 105 additions & 33 deletions tests/unit/test_api.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import contextlib
import json
import time
import urllib.parse
from http import HTTPStatus

import pytest
import requests
Expand All @@ -17,7 +19,7 @@ def app(event, api):
return api


@pytest.fixture(scope='function')
@pytest.fixture(scope="function")
def api(app_runner) -> Api:
"""Returns Api instance from task app."""

Expand All @@ -29,23 +31,23 @@ def api(app_runner) -> Api:
def test_request_default_headers(api, requests_mock: RequestsMocker):
# do some api call
requests_mock.get(
'',
"",
request_headers={
'Authorization': f'API {api.api_key}',
'X-Corva-App': api.app_key,
"Authorization": f"API {api.api_key}",
"X-Corva-App": api.app_key,
},
)
api.get('')
api.get("")


@pytest.mark.parametrize(
'path,expected',
"path,expected",
[
['http://localhost', 'http://localhost'],
['api/v10/path', f'{SETTINGS.DATA_API_ROOT_URL}/api/v10/path'],
['/api/v10/path', f'{SETTINGS.DATA_API_ROOT_URL}/api/v10/path'],
['v10/path', f'{SETTINGS.API_ROOT_URL}/v10/path'],
['/v10/path', f'{SETTINGS.API_ROOT_URL}/v10/path'],
["http://localhost", "http://localhost"],
["api/v10/path", f"{SETTINGS.DATA_API_ROOT_URL}/api/v10/path"],
["/api/v10/path", f"{SETTINGS.DATA_API_ROOT_URL}/api/v10/path"],
["v10/path", f"{SETTINGS.API_ROOT_URL}/v10/path"],
["/v10/path", f"{SETTINGS.API_ROOT_URL}/v10/path"],
],
)
def test_request_url(path, expected, api, requests_mock: RequestsMocker):
Expand All @@ -54,20 +56,20 @@ def test_request_url(path, expected, api, requests_mock: RequestsMocker):


def test_request_data_param_passed_as_json(api, requests_mock: RequestsMocker):
post_mock = requests_mock.post('')
api.post('', data={})
assert post_mock.last_request._request.body.decode() == '{}'
post_mock = requests_mock.post("")
api.post("", data={})
assert post_mock.last_request._request.body.decode() == "{}"


def test_request_additional_headers(api, requests_mock: RequestsMocker):
custom_headers = {'custom': 'value'}
custom_headers = {"custom": "value"}

requests_mock.post('', request_headers={**api.default_headers, **custom_headers})
api.post('', headers={'custom': 'value'})
requests_mock.post("", request_headers={**api.default_headers, **custom_headers})
api.post("", headers={"custom": "value"})


@pytest.mark.parametrize(
'timeout, exc_ctx',
"timeout, exc_ctx",
[
(3, contextlib.nullcontext()),
(30, contextlib.nullcontext()),
Expand All @@ -76,22 +78,22 @@ def test_request_additional_headers(api, requests_mock: RequestsMocker):
],
)
def test_request_timeout_limits(timeout, exc_ctx, api, requests_mock: RequestsMocker):
requests_mock.post('')
requests_mock.post("")

with exc_ctx:
api.post('', timeout=timeout)
api.post("", timeout=timeout)


@pytest.mark.parametrize(
'fields,skip,limit,query,sort',
"fields,skip,limit,query,sort",
(
[None, 0, 1, {}, {}],
[
'_id',
"_id",
1,
2,
{'k1': 'v1'},
{'k2': 'v2'},
{"k1": "v1"},
{"k2": "v2"},
],
),
)
Expand All @@ -105,20 +107,20 @@ def test_get_dataset(
requests_mock: RequestsMocker,
):
provider = SETTINGS.PROVIDER
dataset = 'dataset'
dataset = "dataset"

qs = urllib.parse.urlencode(
{
'query': json.dumps(query),
'sort': json.dumps(sort),
**({'fields': fields} if fields else {}),
'limit': limit,
'skip': skip,
"query": json.dumps(query),
"sort": json.dumps(sort),
**({"fields": fields} if fields else {}),
"limit": limit,
"skip": skip,
}
)

requests_mock.get(
f'/api/v1/data/{provider}/{dataset}/?{qs}', complete_qs=True, text='[{}]'
f"/api/v1/data/{provider}/{dataset}/?{qs}", complete_qs=True, text="[{}]"
)

result = api.get_dataset(
Expand All @@ -136,10 +138,10 @@ def test_get_dataset(

def test_get_dataset_raises(api, requests_mock: RequestsMocker):
provider = SETTINGS.PROVIDER
dataset = 'dataset'
dataset = "dataset"

requests_mock.get(
f'/api/v1/data/{provider}/{dataset}/',
f"/api/v1/data/{provider}/{dataset}/",
status_code=400,
)

Expand All @@ -152,3 +154,73 @@ def test_get_dataset_raises(api, requests_mock: RequestsMocker):
sort={},
limit=1,
)


def test_retrying_logic_works_as_expected(api, requests_mock: RequestsMocker):
path = "/"
url = f"{SETTINGS.API_ROOT_URL}{path}"

bad_requests_statuses_codes = [
HTTPStatus.TOO_MANY_REQUESTS,
HTTPStatus.INTERNAL_SERVER_ERROR,
HTTPStatus.BAD_GATEWAY,
]
good_requests_statuses_codes = [HTTPStatus.OK]

requests_sequence_return_status_codes = (
bad_requests_statuses_codes + good_requests_statuses_codes
)

requests_mock.register_uri(
"GET",
url,
[
{"status_code": int(status_code)}
for status_code in requests_sequence_return_status_codes
],
)

start_time = time.time()
response = api.get(path)
end_time = time.time()

assert response.status_code in good_requests_statuses_codes
assert len(requests_mock.request_history) == len(
requests_sequence_return_status_codes
)
assert end_time - start_time > 1, (
f"At least 1 second retry delay should be applied for "
f"{len(bad_requests_statuses_codes)} retries."
)


def test_retrying_logic_for_custom_max_retries_makes_n_calls(
api, requests_mock: RequestsMocker
):
custom_max_retries = 2
path = "/"
url = f"{SETTINGS.API_ROOT_URL}{path}"

bad_requests_statuses_codes = [
HTTPStatus.TOO_MANY_REQUESTS,
HTTPStatus.INTERNAL_SERVER_ERROR,
HTTPStatus.BAD_GATEWAY,
]

requests_mock.register_uri(
"GET",
url,
[
{"status_code": int(status_code)}
for status_code in bad_requests_statuses_codes
],
)

assert len(bad_requests_statuses_codes) >= custom_max_retries

api.max_retries = custom_max_retries
api.get(path)

assert (
len(requests_mock.request_history) == custom_max_retries
), f"When all requests fail only {custom_max_retries} requests should be made."

0 comments on commit b9f3d23

Please sign in to comment.