Skip to content

Commit

Permalink
Merge pull request #30 from kiwix/worker-countries-web-api
Browse files Browse the repository at this point in the history
update worker countries on startup
  • Loading branch information
elfkuzco authored Aug 5, 2024
2 parents 2ab87d4 + 9d2946e commit 512c059
Show file tree
Hide file tree
Showing 11 changed files with 228 additions and 2 deletions.
3 changes: 2 additions & 1 deletion backend/src/mirrors_qa_backend/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from fastapi import FastAPI

from mirrors_qa_backend.db import initialize_mirrors, upgrade_db_schema
from mirrors_qa_backend.routes import auth, tests
from mirrors_qa_backend.routes import auth, tests, worker


@asynccontextmanager
Expand All @@ -18,6 +18,7 @@ def create_app(*, debug: bool = True):

app.include_router(router=tests.router)
app.include_router(router=auth.router)
app.include_router(router=worker.router)

return app

Expand Down
72 changes: 72 additions & 0 deletions backend/src/mirrors_qa_backend/routes/worker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import pycountry
from fastapi import APIRouter
from fastapi import status as status_codes

from mirrors_qa_backend.db.country import update_countries as update_db_countries
from mirrors_qa_backend.db.exceptions import RecordDoesNotExistError
from mirrors_qa_backend.db.worker import get_worker as get_db_worker
from mirrors_qa_backend.db.worker import update_worker as update_db_worker
from mirrors_qa_backend.routes.dependencies import CurrentWorker, DbSession
from mirrors_qa_backend.routes.http_errors import (
BadRequestError,
NotFoundError,
UnauthorizedError,
)
from mirrors_qa_backend.schemas import UpdateWorkerCountries, WorkerCountries
from mirrors_qa_backend.serializer import serialize_country

router = APIRouter(prefix="/workers", tags=["workers"])


@router.get(
"/{worker_id}/countries",
status_code=status_codes.HTTP_200_OK,
responses={
status_codes.HTTP_200_OK: {
"description": "Return the list of countries the worker is assigned to."
}
},
)
def list_countries(session: DbSession, worker_id: str) -> WorkerCountries:
try:
worker = get_db_worker(session, worker_id)
except RecordDoesNotExistError as exc:
raise NotFoundError(str(exc)) from exc

return WorkerCountries(
countries=[serialize_country(country) for country in worker.countries]
)


@router.put(
"/{worker_id}/countries",
status_code=status_codes.HTTP_200_OK,
responses={
status_codes.HTTP_200_OK: {
"description": "Return the updated list of countries the worker is assigned"
}
},
)
def update_countries(
session: DbSession,
worker_id: str,
current_worker: CurrentWorker,
data: UpdateWorkerCountries,
) -> WorkerCountries:
if current_worker.id != worker_id:
raise UnauthorizedError(
"You do not have the required permissions to access this endpoint."
)
# Ensure all the country codes are valid country codes
country_mapping: dict[str, str] = {}
for country_code in data.country_codes:
if country := pycountry.countries.get(alpha_2=country_code):
country_mapping[country_code.lower()] = country.name
else:
raise BadRequestError(f"{country_code} is not a valid country code.")
update_db_countries(session, country_mapping)
updated_worker = update_db_worker(session, worker_id, list(country_mapping.keys()))

return WorkerCountries(
countries=[serialize_country(country) for country in updated_worker.countries]
)
19 changes: 18 additions & 1 deletion backend/src/mirrors_qa_backend/schemas.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import datetime
import math
from ipaddress import IPv4Address
from typing import Annotated

import pydantic
from pydantic import UUID4, ConfigDict
from pydantic import UUID4, ConfigDict, Field

from mirrors_qa_backend.enums import StatusEnum

Expand Down Expand Up @@ -56,6 +57,22 @@ class Paginator(BaseModel):
last_page: int | None = None


ISOCountryCode = Annotated[str, Field(min_length=2, max_length=2)]


class Country(BaseModel):
code: ISOCountryCode # two-letter country code as defined in ISO 3166-1
name: str # full name of the country (in English)


class WorkerCountries(BaseModel):
countries: list[Country]


class UpdateWorkerCountries(BaseModel):
country_codes: list[ISOCountryCode]


class TestsList(BaseModel):
tests: list[Test]
metadata: Paginator
Expand Down
4 changes: 4 additions & 0 deletions backend/src/mirrors_qa_backend/serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,7 @@ def serialize_mirror(mirror: models.Mirror) -> schemas.Mirror:
as_only=mirror.as_only,
other_countries=mirror.other_countries,
)


def serialize_country(country: models.Country) -> schemas.Country:
return schemas.Country(code=country.code, name=country.name)
8 changes: 8 additions & 0 deletions backend/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
from mirrors_qa_backend import schemas
from mirrors_qa_backend.cryptography import sign_message
from mirrors_qa_backend.db import Session
from mirrors_qa_backend.db.country import create_country
from mirrors_qa_backend.db.models import Base, Mirror, Test, Worker
from mirrors_qa_backend.db.worker import update_worker_countries
from mirrors_qa_backend.enums import StatusEnum
from mirrors_qa_backend.serializer import serialize_mirror

Expand Down Expand Up @@ -125,6 +127,12 @@ def worker(public_key: RSAPublicKey, dbsession: OrmSession) -> Worker:
pubkey_pkcs8=pubkey_pkcs8,
)
dbsession.add(worker)

country_data = {"fr": "France", "ca": "Canada"}
for country_code, country_name in country_data.items():
create_country(dbsession, country_code=country_code, country_name=country_name)
update_worker_countries(dbsession, worker, list(country_data.keys()))

return worker


Expand Down
Empty file.
File renamed without changes.
File renamed without changes.
80 changes: 80 additions & 0 deletions backend/tests/routes/test_worker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
import pytest
from fastapi import status as status_codes
from fastapi.testclient import TestClient
from sqlalchemy.orm import Session as OrmSession

from mirrors_qa_backend.db.models import Worker
from mirrors_qa_backend.db.worker import get_worker


@pytest.fixture
def auth_headers(access_token: str) -> dict[str, str]:
return {
"Content-type": "application/json",
"Authorization": f"Bearer {access_token}",
}


def test_list_worker_countries(worker: Worker, client: TestClient) -> None:
response = client.get(f"/workers/{worker.id}/countries")
assert response.status_code == status_codes.HTTP_200_OK

data = response.json()
assert "countries" in data

countries = data["countries"]
assert len(worker.countries) == len(countries)

worker_country_codes = [country.code for country in worker.countries]
for country in countries:
assert country["code"] in worker_country_codes


def test_update_worker_with_non_existent_country_code(
worker: Worker, auth_headers: dict[str, str], client: TestClient
):
# mixture of invalid country codes and valid country codes
country_codes = ["us", "fr", "ca", "jj", "xx"]
response = client.put(
f"/workers/{worker.id}/countries",
headers=auth_headers,
json={"country_codes": country_codes},
)

assert response.status_code == status_codes.HTTP_400_BAD_REQUEST


@pytest.mark.parametrize(
["country_codes"],
[
(["nz", "us", "ng", "fr", "ca", "be", "bg", "md"],),
(["ng"],),
([],),
],
)
def test_update_worker_countries(
dbsession: OrmSession,
worker: Worker,
country_codes: list[str],
auth_headers: dict[str, str],
client: TestClient,
) -> None:
response = client.put(
f"/workers/{worker.id}/countries",
headers=auth_headers,
json={"country_codes": country_codes},
)

assert response.status_code == status_codes.HTTP_200_OK

data = response.json()
assert "countries" in data

# reload the worker with the updated countries
worker = get_worker(dbsession, worker.id)
countries = data["countries"]
assert len(worker.countries) == len(countries)

worker_country_codes = [country.code for country in worker.countries]
for country in countries:
assert country["code"] in worker_country_codes
1 change: 1 addition & 0 deletions worker/manager/src/mirrors_qa_manager/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def query_api(
"POST": requests.post,
"PATCH": requests.patch,
"DELETE": requests.delete,
"PUT": requests.put,
}.get(method.upper(), requests.get)

resp = func(url, headers=req_headers, json=payload)
Expand Down
43 changes: 43 additions & 0 deletions worker/manager/src/mirrors_qa_manager/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import datetime
import json
import random
import re
import shutil
import signal
import sys
Expand Down Expand Up @@ -93,6 +94,24 @@ def get_host_fpath(self, container_fpath: Path) -> Path:
"""Determine the host path of a path in the container."""
return self.host_workdir / container_fpath.relative_to(Settings.WORKDIR_FPATH)

def get_country_codes_from_config_files(self) -> list[str]:
"""Get the ISO country codes using configuration files in base_dir.
Finds all files ending with .conf and applies the following steps:
- take the first two letters of the config filename.
- add to output list if first two letters of config are valid country codes.
"""
conf_file_ptn = re.compile(r"^(?P<country_code>[a-z]{2})-")
country_codes = set()

for conf_file in self.base_dir.glob("*.conf"):
if match := conf_file_ptn.search(conf_file.stem):
country_code = match.groupdict()["country_code"]
if pycountry.countries.get(alpha_2=country_code):
country_codes.add(country_code)

return list(country_codes)

def copy_wireguard_conf_file(self, country_code: str | None = None) -> Path:
"""Path to copied-from-base-dir <country_code>.conf file in instance directory.
Expand Down Expand Up @@ -233,6 +252,27 @@ def merge_data(
"isp": ip_data["organization"],
}

def update_countries_list(self):
"""Update the list of countries from config files if there are any."""
country_codes = self.get_country_codes_from_config_files()
if not country_codes:
logger.info("No country codes inferred from configuration files.")
return

logger.info(
f"Found {len(country_codes)} country codes from configuration files."
)
logger.debug("Updating list of countries on Backend API.")
data = self.query_api(
f"/workers/{self.worker_id}/countries",
method="put",
payload={"country_codes": country_codes},
)
logger.info(
f"Updated the list of countries for worker to {len(data['countries'])} "
"countries."
)

def fetch_tests(self) -> list[dict[str, Any]]:
logger.debug("Fetching tasks from backend API")

Expand Down Expand Up @@ -269,6 +309,9 @@ def run(self) -> None:
):
self.wg_interface_status = WgInterfaceStatus.UP

# Update the worker list of countries using the configuration files
self.update_countries_list()

tests = self.fetch_tests()
for test in tests:
test_id = test["id"]
Expand Down

0 comments on commit 512c059

Please sign in to comment.