Skip to content

Commit

Permalink
feat: implement retry policy
Browse files Browse the repository at this point in the history
  • Loading branch information
jooola committed Jul 31, 2024
1 parent 3799a6e commit ebd29dc
Show file tree
Hide file tree
Showing 2 changed files with 103 additions and 45 deletions.
100 changes: 61 additions & 39 deletions hcloud/_client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import time
from http import HTTPStatus
from random import uniform
from typing import Protocol

Expand Down Expand Up @@ -256,50 +257,71 @@ def request( # type: ignore[no-untyped-def]

retries = 0
while True:
response = self._requests_session.request(
method=method,
url=url,
headers=headers,
**kwargs,
)

correlation_id = response.headers.get("X-Correlation-Id")
payload = {}
try:
if len(response.content) > 0:
payload = response.json()
except (TypeError, ValueError) as exc:
raise APIException(
code=response.status_code,
message=response.reason,
details={"content": response.content},
correlation_id=correlation_id,
) from exc

if not response.ok:
if not payload or "error" not in payload:
raise APIException(
code=response.status_code,
message=response.reason,
details={"content": response.content},
correlation_id=correlation_id,
)

error: dict = payload["error"]

if (
error["code"] == "rate_limit_exceeded"
and retries < self._retry_max_retries
):
response = self._requests_session.request(
method=method,
url=url,
headers=headers,
**kwargs,
)
return self._read_response(response)
except APIException as exception:
if retries < self._retry_max_retries and self._retry_policy(exception):
time.sleep(self._retry_interval(retries))
retries += 1
continue

raise
except requests.exceptions.Timeout:
if retries < self._retry_max_retries:
time.sleep(self._retry_interval(retries))
retries += 1
continue
raise

def _read_response(self, response: requests.Response) -> dict:
correlation_id = response.headers.get("X-Correlation-Id")
payload = {}
try:
if len(response.content) > 0:
payload = response.json()
except (TypeError, ValueError) as exc:
raise APIException(
code=response.status_code,
message=response.reason,
details={"content": response.content},
correlation_id=correlation_id,
) from exc

if not response.ok:
if not payload or "error" not in payload:
raise APIException(
code=error["code"],
message=error["message"],
details=error.get("details"),
code=response.status_code,
message=response.reason,
details={"content": response.content},
correlation_id=correlation_id,
)

return payload
error: dict = payload["error"]
raise APIException(
code=error["code"],
message=error["message"],
details=error.get("details"),
correlation_id=correlation_id,
)

return payload

def _retry_policy(self, exception: APIException) -> bool:
if isinstance(exception.code, str):
return exception.code in (
"rate_limit_exceeded",
"conflict",
)

if isinstance(exception.code, int):
return exception.code in (
HTTPStatus.BAD_GATEWAY,
HTTPStatus.GATEWAY_TIMEOUT,
)

return False
48 changes: 42 additions & 6 deletions tests/unit/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,12 +109,12 @@ def test_request_fails(self, client, fail_response):

def test_request_fails_correlation_id(self, client, response):
response.headers["X-Correlation-Id"] = "67ed842dc8bc8673"
response.status_code = 409
response.status_code = 422
response._content = json.dumps(
{
"error": {
"code": "conflict",
"message": "some conflict",
"code": "service_error",
"message": "Something crashed",
}
}
).encode("utf-8")
Expand All @@ -125,11 +125,11 @@ def test_request_fails_correlation_id(self, client, response):
"POST", "http://url.com", params={"argument": "value"}, timeout=2
)
error = exception_info.value
assert error.code == "conflict"
assert error.message == "some conflict"
assert error.code == "service_error"
assert error.message == "Something crashed"
assert error.details is None
assert error.correlation_id == "67ed842dc8bc8673"
assert str(error) == "some conflict (conflict, 67ed842dc8bc8673)"
assert str(error) == "Something crashed (service_error, 67ed842dc8bc8673)"

def test_request_500(self, client, fail_response):
fail_response.status_code = 500
Expand Down Expand Up @@ -208,6 +208,42 @@ def test_request_limit_then_success(self, client, rate_limit_response):
)
assert client._requests_session.request.call_count == 2

@pytest.mark.parametrize(
("exception", "expected"),
[
(
APIException(code="rate_limit_exceeded", message="Error", details=None),
True,
),
(
APIException(code="conflict", message="Error", details=None),
True,
),
(
APIException(code=409, message="Conflict", details=None),
False,
),
(
APIException(code=429, message="Too Many Requests", details=None),
False,
),
(
APIException(code=502, message="Bad Gateway", details=None),
True,
),
(
APIException(code=503, message="Service Unavailable", details=None),
False,
),
(
APIException(code=504, message="Gateway Timeout", details=None),
True,
),
],
)
def test_retry_policy(self, client, exception, expected):
assert client._retry_policy(exception) == expected


def test_constant_backoff_function():
backoff = constant_backoff_function(interval=1.0)
Expand Down

0 comments on commit ebd29dc

Please sign in to comment.