Skip to content

Commit

Permalink
Fix retrying logic (#480)
Browse files Browse the repository at this point in the history
  • Loading branch information
Fokko authored Feb 29, 2024
1 parent 6a34421 commit 1f433a4
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 23 deletions.
8 changes: 4 additions & 4 deletions pyiceberg/catalog/rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def _retry_hook(retry_state: RetryCallState) -> None:
_RETRY_ARGS = {
"retry": retry_if_exception_type(AuthorizationExpiredError),
"stop": stop_after_attempt(2),
"before": _retry_hook,
"before_sleep": _retry_hook,
"reraise": True,
}

Expand Down Expand Up @@ -446,10 +446,10 @@ def _response_to_table(self, identifier_tuple: Tuple[str, ...], table_response:
catalog=self,
)

def _refresh_token(self, session: Optional[Session] = None, new_token: Optional[str] = None) -> None:
def _refresh_token(self, session: Optional[Session] = None, initial_token: Optional[str] = None) -> None:
session = session or self._session
if new_token is not None:
self.properties[TOKEN] = new_token
if initial_token is not None:
self.properties[TOKEN] = initial_token
elif CREDENTIAL in self.properties:
self.properties[TOKEN] = self._fetch_access_token(session, self.properties[CREDENTIAL])

Expand Down
58 changes: 39 additions & 19 deletions tests/catalog/test_rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,24 +306,39 @@ def test_list_namespace_with_parent_200(rest_mock: Mocker) -> None:
]


def test_list_namespaces_419(rest_mock: Mocker) -> None:
def test_list_namespaces_token_expired(rest_mock: Mocker) -> None:
new_token = "new_jwt_token"
new_header = dict(TEST_HEADERS)
new_header["Authorization"] = f"Bearer {new_token}"

rest_mock.post(
namespaces = rest_mock.register_uri(
"GET",
f"{TEST_URI}v1/namespaces",
json={
"error": {
"message": "Authorization expired.",
"type": "AuthorizationExpiredError",
"code": 419,
}
},
status_code=419,
request_headers=TEST_HEADERS,
)
rest_mock.post(
[
{
"status_code": 419,
"json": {
"error": {
"message": "Authorization expired.",
"type": "AuthorizationExpiredError",
"code": 419,
}
},
"headers": TEST_HEADERS,
},
{
"status_code": 200,
"json": {"namespaces": [["default"], ["examples"], ["fokko"], ["system"]]},
"headers": new_header,
},
{
"status_code": 200,
"json": {"namespaces": [["default"], ["examples"], ["fokko"], ["system"]]},
"headers": new_header,
},
],
)
tokens = rest_mock.post(
f"{TEST_URI}v1/oauth/tokens",
json={
"access_token": new_token,
Expand All @@ -333,19 +348,24 @@ def test_list_namespaces_419(rest_mock: Mocker) -> None:
},
status_code=200,
)
rest_mock.get(
f"{TEST_URI}v1/namespaces",
json={"namespaces": [["default"], ["examples"], ["fokko"], ["system"]]},
status_code=200,
request_headers=new_header,
)
catalog = RestCatalog("rest", uri=TEST_URI, token=TEST_TOKEN, credential=TEST_CREDENTIALS)
assert catalog.list_namespaces() == [
("default",),
("examples",),
("fokko",),
("system",),
]
assert namespaces.call_count == 2
assert tokens.call_count == 1

assert catalog.list_namespaces() == [
("default",),
("examples",),
("fokko",),
("system",),
]
assert namespaces.call_count == 3
assert tokens.call_count == 1


def test_create_namespace_200(rest_mock: Mocker) -> None:
Expand Down

0 comments on commit 1f433a4

Please sign in to comment.