From b9f3d23a6060d1b9f88e62de9c77263e4ef64797 Mon Sep 17 00:00:00 2001 From: Vlad Bagmet Date: Thu, 16 Nov 2023 23:43:55 +0000 Subject: [PATCH] Adding retry implementation. --- src/corva/api.py | 32 +++++++--- tests/unit/test_api.py | 138 +++++++++++++++++++++++++++++++---------- 2 files changed, 129 insertions(+), 41 deletions(-) diff --git a/src/corva/api.py b/src/corva/api.py index 73e5fc89..96754a64 100644 --- a/src/corva/api.py +++ b/src/corva/api.py @@ -1,6 +1,8 @@ import json import posixpath import re +import time +from http import HTTPStatus from typing import List, Optional, Sequence, Union import requests @@ -14,6 +16,14 @@ class Api: """ TIMEOUT_LIMITS = (3, 30) # seconds + DEFAULT_MAX_RETRIES = 5 + HTTP_STATUS_CODES_TO_RETRY = [ + HTTPStatus.TOO_MANY_REQUESTS, # 429 + HTTPStatus.INTERNAL_SERVER_ERROR, # 500 + HTTPStatus.BAD_GATEWAY, # 502 + HTTPStatus.SERVICE_UNAVAILABLE, # 503 + HTTPStatus.GATEWAY_TIMEOUT # 504 + ] def __init__( self, @@ -31,6 +41,7 @@ def __init__( self.app_key = app_key self.app_connection_id = app_connection_id self.timeout = timeout or self.TIMEOUT_LIMITS[1] + self.max_retries = self.DEFAULT_MAX_RETRIES @property def default_headers(self): @@ -103,6 +114,7 @@ def _request( requests.Response instance. 
""" + response = requests.Response() timeout = timeout or self.timeout self._validate_timeout(timeout) @@ -113,14 +125,18 @@ def _request( **(headers or {}), } - response = requests.request( - method=method, - url=url, - params=params, - json=data, - headers=headers, - timeout=timeout, - ) + for retry_attempt in range(max(int(self.max_retries), 1)): + response = requests.request( + method=method, + url=url, + params=params, + json=data, + headers=headers, + timeout=timeout, + ) + if response.status_code not in self.HTTP_STATUS_CODES_TO_RETRY: + break + time.sleep(2 ** retry_attempt / 4) return response diff --git a/tests/unit/test_api.py b/tests/unit/test_api.py index e26f2baf..d764c39a 100644 --- a/tests/unit/test_api.py +++ b/tests/unit/test_api.py @@ -1,6 +1,8 @@ import contextlib import json +import time import urllib.parse +from http import HTTPStatus import pytest import requests @@ -17,7 +19,7 @@ def app(event, api): return api -@pytest.fixture(scope='function') +@pytest.fixture(scope="function") def api(app_runner) -> Api: """Returns Api instance from task app.""" @@ -29,23 +31,23 @@ def api(app_runner) -> Api: def test_request_default_headers(api, requests_mock: RequestsMocker): # do some api call requests_mock.get( - '', + "", request_headers={ - 'Authorization': f'API {api.api_key}', - 'X-Corva-App': api.app_key, + "Authorization": f"API {api.api_key}", + "X-Corva-App": api.app_key, }, ) - api.get('') + api.get("") @pytest.mark.parametrize( - 'path,expected', + "path,expected", [ - ['http://localhost', 'http://localhost'], - ['api/v10/path', f'{SETTINGS.DATA_API_ROOT_URL}/api/v10/path'], - ['/api/v10/path', f'{SETTINGS.DATA_API_ROOT_URL}/api/v10/path'], - ['v10/path', f'{SETTINGS.API_ROOT_URL}/v10/path'], - ['/v10/path', f'{SETTINGS.API_ROOT_URL}/v10/path'], + ["http://localhost", "http://localhost"], + ["api/v10/path", f"{SETTINGS.DATA_API_ROOT_URL}/api/v10/path"], + ["/api/v10/path", f"{SETTINGS.DATA_API_ROOT_URL}/api/v10/path"], + ["v10/path", 
f"{SETTINGS.API_ROOT_URL}/v10/path"], + ["/v10/path", f"{SETTINGS.API_ROOT_URL}/v10/path"], ], ) def test_request_url(path, expected, api, requests_mock: RequestsMocker): @@ -54,20 +56,20 @@ def test_request_url(path, expected, api, requests_mock: RequestsMocker): def test_request_data_param_passed_as_json(api, requests_mock: RequestsMocker): - post_mock = requests_mock.post('') - api.post('', data={}) - assert post_mock.last_request._request.body.decode() == '{}' + post_mock = requests_mock.post("") + api.post("", data={}) + assert post_mock.last_request._request.body.decode() == "{}" def test_request_additional_headers(api, requests_mock: RequestsMocker): - custom_headers = {'custom': 'value'} + custom_headers = {"custom": "value"} - requests_mock.post('', request_headers={**api.default_headers, **custom_headers}) - api.post('', headers={'custom': 'value'}) + requests_mock.post("", request_headers={**api.default_headers, **custom_headers}) + api.post("", headers={"custom": "value"}) @pytest.mark.parametrize( - 'timeout, exc_ctx', + "timeout, exc_ctx", [ (3, contextlib.nullcontext()), (30, contextlib.nullcontext()), @@ -76,22 +78,22 @@ def test_request_additional_headers(api, requests_mock: RequestsMocker): ], ) def test_request_timeout_limits(timeout, exc_ctx, api, requests_mock: RequestsMocker): - requests_mock.post('') + requests_mock.post("") with exc_ctx: - api.post('', timeout=timeout) + api.post("", timeout=timeout) @pytest.mark.parametrize( - 'fields,skip,limit,query,sort', + "fields,skip,limit,query,sort", ( [None, 0, 1, {}, {}], [ - '_id', + "_id", 1, 2, - {'k1': 'v1'}, - {'k2': 'v2'}, + {"k1": "v1"}, + {"k2": "v2"}, ], ), ) @@ -105,20 +107,20 @@ def test_get_dataset( requests_mock: RequestsMocker, ): provider = SETTINGS.PROVIDER - dataset = 'dataset' + dataset = "dataset" qs = urllib.parse.urlencode( { - 'query': json.dumps(query), - 'sort': json.dumps(sort), - **({'fields': fields} if fields else {}), - 'limit': limit, - 'skip': skip, + "query": 
json.dumps(query), + "sort": json.dumps(sort), + **({"fields": fields} if fields else {}), + "limit": limit, + "skip": skip, } ) requests_mock.get( - f'/api/v1/data/{provider}/{dataset}/?{qs}', complete_qs=True, text='[{}]' + f"/api/v1/data/{provider}/{dataset}/?{qs}", complete_qs=True, text="[{}]" ) result = api.get_dataset( @@ -136,10 +138,10 @@ def test_get_dataset( def test_get_dataset_raises(api, requests_mock: RequestsMocker): provider = SETTINGS.PROVIDER - dataset = 'dataset' + dataset = "dataset" requests_mock.get( - f'/api/v1/data/{provider}/{dataset}/', + f"/api/v1/data/{provider}/{dataset}/", status_code=400, ) @@ -152,3 +154,73 @@ def test_get_dataset_raises(api, requests_mock: RequestsMocker): sort={}, limit=1, ) + + +def test_retrying_logic_works_as_expected(api, requests_mock: RequestsMocker): + path = "/" + url = f"{SETTINGS.API_ROOT_URL}{path}" + + bad_requests_statuses_codes = [ + HTTPStatus.TOO_MANY_REQUESTS, + HTTPStatus.INTERNAL_SERVER_ERROR, + HTTPStatus.BAD_GATEWAY, + ] + good_requests_statuses_codes = [HTTPStatus.OK] + + requests_sequence_return_status_codes = ( + bad_requests_statuses_codes + good_requests_statuses_codes + ) + + requests_mock.register_uri( + "GET", + url, + [ + {"status_code": int(status_code)} + for status_code in requests_sequence_return_status_codes + ], + ) + + start_time = time.time() + response = api.get(path) + end_time = time.time() + + assert response.status_code in good_requests_statuses_codes + assert len(requests_mock.request_history) == len( + requests_sequence_return_status_codes + ) + assert end_time - start_time > 1, ( + f"At least 1 second retry delay should be applied for " + f"{len(bad_requests_statuses_codes)} retries." 
+ ) + + +def test_retrying_logic_for_custom_max_retries_makes_n_calls( + api, requests_mock: RequestsMocker +): + custom_max_retries = 2 + path = "/" + url = f"{SETTINGS.API_ROOT_URL}{path}" + + bad_requests_statuses_codes = [ + HTTPStatus.TOO_MANY_REQUESTS, + HTTPStatus.INTERNAL_SERVER_ERROR, + HTTPStatus.BAD_GATEWAY, + ] + + requests_mock.register_uri( + "GET", + url, + [ + {"status_code": int(status_code)} + for status_code in bad_requests_statuses_codes + ], + ) + + assert len(bad_requests_statuses_codes) >= custom_max_retries + + api.max_retries = custom_max_retries + api.get(path) + + assert ( + len(requests_mock.request_history) == custom_max_retries + ), f"When all requests fail only {custom_max_retries} requests should be made."