Skip to content

Commit

Permalink
Replace bdd tests with parametrized tests for stations endpoints
Browse files Browse the repository at this point in the history
  • Loading branch information
conbrad committed Nov 6, 2024
1 parent c920992 commit e6275d5
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 242 deletions.
147 changes: 15 additions & 132 deletions api/app/tests/__init__.py
Original file line number Diff line number Diff line change
@@ -1,114 +1,51 @@
""" Util & common files for tests
"""
from typing import IO, Any, Callable, Optional, Tuple
"""Util & common files for tests"""

from typing import Callable, Optional
from dateutil import parser
import os
import datetime
import json
import importlib
import jsonpickle
from app.db.models.common import TZTimeStamp


def get_complete_filename(module_path: str, filename: str):
""" Get the full path of a filename, given it's module path """
"""Get the full path of a filename, given it's module path"""
dirname = os.path.dirname(os.path.realpath(module_path))
return os.path.join(dirname, filename)


def _load_json_file(module_path: str, filename: str) -> Optional[dict]:
""" Load json file given a module path and a filename """
if filename == 'None': # Not the best solution...
"""Load json file given a module path and a filename"""
if filename == "None": # Not the best solution...
return None
if filename:
with open(get_complete_filename(module_path, filename), encoding="utf-8") as file_pointer:
return json.load(file_pointer)
return None


def _load_json_file_with_name(module_path: str, filename: str) -> Tuple[Optional[dict], str]:
""" Load json file given a module path and a filename """
if filename == 'None': # Not the best solution...
return None, filename
if filename:
with open(get_complete_filename(module_path, filename), encoding="utf-8") as file_pointer:
return json.load(file_pointer), filename
return None, filename


def load_json_file(module_path: str) -> Callable[[str], dict]:
""" Return a function that can load json from a filename and return a dict """
"""Return a function that can load json from a filename and return a dict"""

def _json_loader(filename: str):
return _load_json_file(module_path, filename)
return _json_loader


def load_json_file_with_name(module_path: str) -> Callable[[str], dict]:
""" Return a function that can load a json from a filename and return a dict, but also the filename """
def _json_loader(filename: str):
return _load_json_file_with_name(module_path, filename)
return _json_loader


def json_converter(item: object):
""" Add datetime serialization """
if isinstance(item, datetime.datetime):
return item.isoformat()
return None


def dump_sqlalchemy_row_data_to_json(response, target: IO[Any]):
""" Useful for dumping sqlalchemy responses to json in for unit tests. """
result = []
for response_row in response:
result.append(jsonpickle.encode(response_row))
target.write(jsonpickle.encode(result))


def dump_sqlalchemy_mapped_object_response_to_json(response, target: IO[Any]):
""" Useful for dumping sqlalchemy responses to json in for unit tests.
e.g. if we want to store the response for GDPS predictions for two stations, we could write the
following code:
```python
query = get_station_model_predictions_order_by_prediction_timestamp(
session, [322, 838], ModelEnum.GDPS, back_5_days, now)
with open('tmp.json', 'w') as tmp:
dump_sqlalchemy_response_to_json(query, tmp)
```
"""
result = []
for row in response:
result_row = []
for record in row:
# Copy the dict so we can safely change it.
data = dict(record.__dict__)
# Pop internal value
data.pop('_sa_instance_state')
result_row.append(
{
'module': type(record).__module__,
'class': type(record).__name__,
'data': data
}
)
result.append(result_row)
json.dump(result, fp=target, default=json_converter, indent=3)


def load_sqlalchemy_response_from_json(filename):
""" Load a sqlalchemy response from a json file """
with open(filename, 'r', encoding="utf-8") as tmp:
"""Load a sqlalchemy response from a json file"""
with open(filename, "r", encoding="utf-8") as tmp:
data = json.load(tmp)
return load_sqlalchemy_response_from_object(data)


def de_serialize_record(record):
""" De-serailize a single sqlalchemy record """
module = importlib.import_module(record['module'])
class_ = getattr(module, record['class'])
"""De-serailize a single sqlalchemy record"""
module = importlib.import_module(record["module"])
class_ = getattr(module, record["class"])
record_data = {}
for key, value in record['data'].items():
for key, value in record["data"].items():
# Handle the special case, where the type is timestamp, converting the string to the
# correct data type.
if isinstance(getattr(class_, key).type, TZTimeStamp):
Expand All @@ -119,7 +56,7 @@ def de_serialize_record(record):


def load_sqlalchemy_response_from_object(data: object):
""" Load a sqlalchemy response from an object """
"""Load a sqlalchemy response from an object"""
# Usualy the data is a list of objects - or a list of list of objects.
# e.g.: [ { record }]
# e.g.: or [ [{record}, {record}]]
Expand All @@ -134,57 +71,3 @@ def load_sqlalchemy_response_from_object(data: object):
return result
# Sometimes though, we're only expecting a single record, not a list.
return de_serialize_record(data)


def apply_crud_mapping(monkeypatch, crud_mapping: dict, module_path: str):
""" Mock the sql response
The crud response was generated by temporarily introducing
"dump_sqlalchemy_row_data_to_json" and "dump_sqlalchemy_mapped_object_response_to_json"
in code - and saving the database responses.
"""

if crud_mapping:
for item in crud_mapping:
if item['serializer'] == "jsonpickle":
_jsonpickle_patch_function(monkeypatch,
item['module'], item['function'], item['json'], module_path)
else:
_json_patch_function(monkeypatch,
item['module'], item['function'], item['json'], module_path)

return {}


def _jsonpickle_patch_function(
monkeypatch,
module_name: str,
function_name: str,
json_filename: str,
module_path: str):
""" Patch module_name.function_name to return de-serialized json_filename """
def mock_get_data(*_):
filename = get_complete_filename(module_path, json_filename)
with open(filename, encoding="utf-8") as file_pointer:
rows = jsonpickle.decode(file_pointer.read())
for row in rows:
# Workaround to remain compatible with old tests. Ideally we would just always pickle the row.
if isinstance(row, str):
yield jsonpickle.decode(row)
continue
yield row

monkeypatch.setattr(importlib.import_module(module_name), function_name, mock_get_data)


def _json_patch_function(monkeypatch,
module_name: str,
function_name: str,
json_filename: str,
module_path: str):
""" Patch module_name.function_name to return de-serialized json_filename """
def mock_get_data(*_):
filename = get_complete_filename(module_path, json_filename)
with open(filename, encoding="utf-8") as file_pointer:
return json.load(file_pointer)

monkeypatch.setattr(importlib.import_module(module_name), function_name, mock_get_data)
26 changes: 0 additions & 26 deletions api/app/tests/test_stations.feature

This file was deleted.

84 changes: 0 additions & 84 deletions api/app/tests/test_stations.py

This file was deleted.

61 changes: 61 additions & 0 deletions api/app/tests/test_stations_new.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
from aiohttp import ClientSession
import pytest
import app.main
from datetime import datetime, timezone
from app.tests import load_json_file
from app.tests.common import default_mock_client_get
from fastapi.testclient import TestClient
from httpx import Response


@pytest.fixture()
def client():
from app.main import app as test_app

with TestClient(test_app) as test_client:
yield test_client


@pytest.mark.parametrize(
"url, status, code, name, lat, long",
[
("/api/stations/", 200, 331, "ASHNOLA", 49.13905, -120.1844),
("/api/stations/", 200, 322, "AFTON", 50.6733333, -120.4816667),
("/api/stations/", 200, 317, "ALLISON PASS", 49.0623139, -120.7674194),
],
)
@pytest.mark.usefixtures("mock_jwt_decode")
def test_get_stations(
client: TestClient,
monkeypatch,
url,
status,
code,
name,
lat,
long,
):
monkeypatch.setattr(ClientSession, "get", default_mock_client_get)
response: Response = client.get(url)
assert response.status_code == status
station = next(x for x in response.json()["features"] if x["properties"]["code"] == code)
assert station["properties"]["code"] == code, "Code"
assert station["properties"]["name"] == name, "Name"
assert station["geometry"]["coordinates"][1] == lat, "Latitude"
assert station["geometry"]["coordinates"][0] == long, "Longitude"
assert len(response.json()["features"]) >= 200


@pytest.mark.usefixtures("mock_jwt_decode")
def test_get_station_details(client: TestClient, monkeypatch):
monkeypatch.setattr(ClientSession, "get", default_mock_client_get)

def mock_get_utc_now():
return datetime.fromtimestamp(1618870929583 / 1000, tz=timezone.utc)

monkeypatch.setattr(app.routers.stations, "get_utc_now", mock_get_utc_now)
expected_response = load_json_file(__file__)("test_stations_details_expected_response.json")

response: Response = client.get("/api/stations/details/")
assert response.status_code == 200
assert response.json() == expected_response

0 comments on commit e6275d5

Please sign in to comment.