-
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #30 from kiwix/worker-countries-web-api
update worker countries on startup
- Loading branch information
Showing
11 changed files
with
228 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
import pycountry | ||
from fastapi import APIRouter | ||
from fastapi import status as status_codes | ||
|
||
from mirrors_qa_backend.db.country import update_countries as update_db_countries | ||
from mirrors_qa_backend.db.exceptions import RecordDoesNotExistError | ||
from mirrors_qa_backend.db.worker import get_worker as get_db_worker | ||
from mirrors_qa_backend.db.worker import update_worker as update_db_worker | ||
from mirrors_qa_backend.routes.dependencies import CurrentWorker, DbSession | ||
from mirrors_qa_backend.routes.http_errors import ( | ||
BadRequestError, | ||
NotFoundError, | ||
UnauthorizedError, | ||
) | ||
from mirrors_qa_backend.schemas import UpdateWorkerCountries, WorkerCountries | ||
from mirrors_qa_backend.serializer import serialize_country | ||
|
||
router = APIRouter(prefix="/workers", tags=["workers"]) | ||
|
||
|
||
@router.get( | ||
"/{worker_id}/countries", | ||
status_code=status_codes.HTTP_200_OK, | ||
responses={ | ||
status_codes.HTTP_200_OK: { | ||
"description": "Return the list of countries the worker is assigned to." | ||
} | ||
}, | ||
) | ||
def list_countries(session: DbSession, worker_id: str) -> WorkerCountries: | ||
try: | ||
worker = get_db_worker(session, worker_id) | ||
except RecordDoesNotExistError as exc: | ||
raise NotFoundError(str(exc)) from exc | ||
|
||
return WorkerCountries( | ||
countries=[serialize_country(country) for country in worker.countries] | ||
) | ||
|
||
|
||
@router.put( | ||
"/{worker_id}/countries", | ||
status_code=status_codes.HTTP_200_OK, | ||
responses={ | ||
status_codes.HTTP_200_OK: { | ||
"description": "Return the updated list of countries the worker is assigned" | ||
} | ||
}, | ||
) | ||
def update_countries( | ||
session: DbSession, | ||
worker_id: str, | ||
current_worker: CurrentWorker, | ||
data: UpdateWorkerCountries, | ||
) -> WorkerCountries: | ||
if current_worker.id != worker_id: | ||
raise UnauthorizedError( | ||
"You do not have the required permissions to access this endpoint." | ||
) | ||
# Ensure all the country codes are valid country codes | ||
country_mapping: dict[str, str] = {} | ||
for country_code in data.country_codes: | ||
if country := pycountry.countries.get(alpha_2=country_code): | ||
country_mapping[country_code.lower()] = country.name | ||
else: | ||
raise BadRequestError(f"{country_code} is not a valid country code.") | ||
update_db_countries(session, country_mapping) | ||
updated_worker = update_db_worker(session, worker_id, list(country_mapping.keys())) | ||
|
||
return WorkerCountries( | ||
countries=[serialize_country(country) for country in updated_worker.countries] | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
File renamed without changes.
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
import pytest | ||
from fastapi import status as status_codes | ||
from fastapi.testclient import TestClient | ||
from sqlalchemy.orm import Session as OrmSession | ||
|
||
from mirrors_qa_backend.db.models import Worker | ||
from mirrors_qa_backend.db.worker import get_worker | ||
|
||
|
||
@pytest.fixture | ||
def auth_headers(access_token: str) -> dict[str, str]: | ||
return { | ||
"Content-type": "application/json", | ||
"Authorization": f"Bearer {access_token}", | ||
} | ||
|
||
|
||
def test_list_worker_countries(worker: Worker, client: TestClient) -> None: | ||
response = client.get(f"/workers/{worker.id}/countries") | ||
assert response.status_code == status_codes.HTTP_200_OK | ||
|
||
data = response.json() | ||
assert "countries" in data | ||
|
||
countries = data["countries"] | ||
assert len(worker.countries) == len(countries) | ||
|
||
worker_country_codes = [country.code for country in worker.countries] | ||
for country in countries: | ||
assert country["code"] in worker_country_codes | ||
|
||
|
||
def test_update_worker_with_non_existent_country_code( | ||
worker: Worker, auth_headers: dict[str, str], client: TestClient | ||
): | ||
# mixture of invalid country codes and valid country codes | ||
country_codes = ["us", "fr", "ca", "jj", "xx"] | ||
response = client.put( | ||
f"/workers/{worker.id}/countries", | ||
headers=auth_headers, | ||
json={"country_codes": country_codes}, | ||
) | ||
|
||
assert response.status_code == status_codes.HTTP_400_BAD_REQUEST | ||
|
||
|
||
@pytest.mark.parametrize( | ||
["country_codes"], | ||
[ | ||
(["nz", "us", "ng", "fr", "ca", "be", "bg", "md"],), | ||
(["ng"],), | ||
([],), | ||
], | ||
) | ||
def test_update_worker_countries( | ||
dbsession: OrmSession, | ||
worker: Worker, | ||
country_codes: list[str], | ||
auth_headers: dict[str, str], | ||
client: TestClient, | ||
) -> None: | ||
response = client.put( | ||
f"/workers/{worker.id}/countries", | ||
headers=auth_headers, | ||
json={"country_codes": country_codes}, | ||
) | ||
|
||
assert response.status_code == status_codes.HTTP_200_OK | ||
|
||
data = response.json() | ||
assert "countries" in data | ||
|
||
# reload the worker with the updated countries | ||
worker = get_worker(dbsession, worker.id) | ||
countries = data["countries"] | ||
assert len(worker.countries) == len(countries) | ||
|
||
worker_country_codes = [country.code for country in worker.countries] | ||
for country in countries: | ||
assert country["code"] in worker_country_codes |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters